Add learn_rate to csv and removed a left-over debug statement

This commit is contained in:
Melan 2022-10-13 12:37:58 +02:00
parent 1cfc2a1898
commit 8636b50aea
2 changed files with 6 additions and 5 deletions

View file

@ -257,19 +257,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file) hypernetwork.save(last_saved_file)
print(f"{write_csv_every} > {hypernetwork.step % write_csv_every == 0}, {write_csv_every}")
if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0: if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout: with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss"]) csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
if write_csv_header: if write_csv_header:
csv_writer.writeheader() csv_writer.writeheader()
csv_writer.writerow({"step": hypernetwork.step, csv_writer.writerow({"step": hypernetwork.step,
"loss": f"{losses.mean():.7f}"}) "loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')

View file

@ -262,14 +262,15 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout: with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss"]) csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"])
if write_csv_header: if write_csv_header:
csv_writer.writeheader() csv_writer.writeheader()
csv_writer.writerow({"epoch": epoch_num + 1, csv_writer.writerow({"epoch": epoch_num + 1,
"epoch_step": epoch_step - 1, "epoch_step": epoch_step - 1,
"loss": f"{losses.mean():.7f}"}) "loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')