statistics for pbar

This commit is contained in:
AngelBottomless 2022-10-23 21:29:53 +09:00 committed by AUTOMATIC1111
parent 40b56c9289
commit 348f89c8d4

View file

@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
size = len(ds.indexes) size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024)) loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,)) losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0 previous_mean_loss = 0
print("Mean loss of {} elements".format(size)) print("Mean loss of {} elements".format(size))
@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
for i, entries in pbar: for i, entries in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
if len(loss_dict) > 0: if len(loss_dict) > 0:
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) previous_mean_losses = [i[-1] for i in loss_dict.values()]
previous_mean_loss = mean(previous_mean_losses)
scheduler.apply(optimizer, hypernetwork.step) scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished: if scheduler.finished:
@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.") raise RuntimeError("Loss diverged.")
pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
if len(previous_mean_losses) > 1:
std = stdev(previous_mean_losses)
else:
std = 0
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.