Allow tracking real-time loss

Someone had 6000 images in their dataset, and it was shown as 0, which was confusing.
This will allow tracking real time dataset-average loss for registered objects.
This commit is contained in:
AngelBottomless 2022-10-23 04:17:16 +09:00 committed by AUTOMATIC1111
parent ca5a9e79dc
commit 48dbf99e84

View file

@ -360,7 +360,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar: for i, entries in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
if loss_dict and i % size == 0: if len(loss_dict) > 0:
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict) previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
scheduler.apply(optimizer, hypernetwork.step) scheduler.apply(optimizer, hypernetwork.step)