From 24694e5983d0944b901892cb101878e6dec89a20 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sun, 23 Oct 2022 01:57:58 +0900 Subject: [PATCH] Update hypernetwork.py --- modules/hypernetworks/hypernetwork.py | 55 +++++++++++++++++++++------ 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3bc71ee5..81132be4 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum +from statistics import stdev, mean class HypernetworkModule(torch.nn.Module): multiplier = 1.0 @@ -268,6 +269,32 @@ def stack_conds(conds): return torch.stack(conds) +def log_statistics(loss_info:dict, key, value): + if key not in loss_info: + loss_info[key] = [value] + else: + loss_info[key].append(value) + if len(loss_info) > 1024: + loss_info.pop(0) + + +def statistics(data): + total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})" + recent_data = data[-32:] + recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})" + return total_information, recent_information + + +def report_statistics(loss_info:dict): + keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) + for key in keys: + info, recent = statistics(loss_info[key]) + print("Loss statistics for file " + key) + print(info) + print(recent) + + + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -310,7 +337,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log for weight in weights: weight.requires_grad = True - losses = torch.zeros((32,)) + size = len(ds.indexes) + loss_dict = {} + losses = torch.zeros((size,)) + previous_mean_loss = 0 + print("Mean loss of {} elements".format(size)) last_saved_file = "" last_saved_image = "" @@ -329,7 +360,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step - + if loss_dict and i % size == 0: + previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) + scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: break @@ -346,7 +379,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log del c losses[hypernetwork.step % losses.shape[0]] = loss.item() - + for entry in entries: + log_statistics(loss_dict, entry.filename, loss.item()) + optimizer.zero_grad() weights[0].grad = None loss.backward() @@ -359,10 +394,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer.step() - mean_loss = losses.mean() - if torch.isnan(mean_loss): + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") - pbar.set_description(f"loss: {mean_loss:.7f}") + pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. @@ -371,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log hypernetwork.save(last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{mean_loss:.7f}", + "loss": f"{previous_mean_loss:.7f}", "learn_rate": scheduler.learn_rate }) @@ -420,14 +454,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"""

-Loss: {mean_loss:.7f}
+Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - + + report_statistics(loss_dict) checkpoint = sd_models.select_checkpoint() hypernetwork.sd_checkpoint = checkpoint.hash @@ -438,5 +473,3 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.save(filename) return hypernetwork, filename - -