fix missing 'mean loss' for tensorboard integration

This commit is contained in:
AngelBottomless 2023-01-16 02:08:47 +09:00 committed by GitHub
parent ce13ced5dc
commit 16f410893e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -644,7 +644,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if shared.opts.training_enable_tensorboard:
epoch_num = hypernetwork.step // len(ds)
epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values())
textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {