Save a csv containing the loss while training
This commit is contained in:
parent
698d303b04
commit
1cfc2a1898
3 changed files with 35 additions and 2 deletions
|
@ -5,6 +5,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import tqdm
|
import tqdm
|
||||||
|
import csv
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -174,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
|
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, write_csv_every, template_file, preview_image_prompt):
|
||||||
assert hypernetwork_name, 'hypernetwork not selected'
|
assert hypernetwork_name, 'hypernetwork not selected'
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
|
@ -256,6 +257,20 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
||||||
hypernetwork.save(last_saved_file)
|
hypernetwork.save(last_saved_file)
|
||||||
|
|
||||||
|
print(f"{write_csv_every} > {hypernetwork.step % write_csv_every == 0}, {write_csv_every}")
|
||||||
|
if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
|
||||||
|
write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
|
||||||
|
|
||||||
|
with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
|
||||||
|
|
||||||
|
csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss"])
|
||||||
|
|
||||||
|
if write_csv_header:
|
||||||
|
csv_writer.writeheader()
|
||||||
|
|
||||||
|
csv_writer.writerow({"step": hypernetwork.step,
|
||||||
|
"loss": f"{losses.mean():.7f}"})
|
||||||
|
|
||||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
import html
|
import html
|
||||||
import datetime
|
import datetime
|
||||||
|
import csv
|
||||||
|
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
|
@ -172,7 +173,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, write_csv_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
|
||||||
assert embedding_name, 'embedding not selected'
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
shared.state.textinfo = "Initializing textual inversion training..."
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
@ -256,6 +257,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||||
embedding.save(last_saved_file)
|
embedding.save(last_saved_file)
|
||||||
|
|
||||||
|
if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
|
||||||
|
write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
|
||||||
|
|
||||||
|
with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
|
||||||
|
|
||||||
|
csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss"])
|
||||||
|
|
||||||
|
if write_csv_header:
|
||||||
|
csv_writer.writeheader()
|
||||||
|
|
||||||
|
csv_writer.writerow({"epoch": epoch_num + 1,
|
||||||
|
"epoch_step": epoch_step - 1,
|
||||||
|
"loss": f"{losses.mean():.7f}"})
|
||||||
|
|
||||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
|
||||||
|
|
|
@ -1096,6 +1096,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
write_csv_every = gr.Number(label='Save an csv containing the loss to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||||
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
|
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
|
||||||
|
@ -1174,6 +1175,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
steps,
|
steps,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
|
write_csv_every,
|
||||||
template_file,
|
template_file,
|
||||||
save_image_with_stored_embedding,
|
save_image_with_stored_embedding,
|
||||||
preview_image_prompt,
|
preview_image_prompt,
|
||||||
|
@ -1195,6 +1197,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
steps,
|
steps,
|
||||||
create_image_every,
|
create_image_every,
|
||||||
save_embedding_every,
|
save_embedding_every,
|
||||||
|
write_csv_every,
|
||||||
template_file,
|
template_file,
|
||||||
preview_image_prompt,
|
preview_image_prompt,
|
||||||
],
|
],
|
||||||
|
|
Loading…
Reference in a new issue