Allow tracking real-time loss
Someone had 6000 images in their dataset, and it was shown as 0, which was confusing. This will allow tracking real time dataset-average loss for registered objects.
This commit is contained in:
parent
ca5a9e79dc
commit
48dbf99e84
1 changed files with 1 additions and 1 deletions
|
@ -360,7 +360,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
for i, entries in pbar:
|
for i, entries in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
if loss_dict and i % size == 0:
|
if len(loss_dict) > 0:
|
||||||
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
|
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
|
||||||
|
|
||||||
scheduler.apply(optimizer, hypernetwork.step)
|
scheduler.apply(optimizer, hypernetwork.step)
|
||||||
|
|
Loading…
Reference in a new issue