Add cleanup after training

parent ab27c111d0
commit 3ce2bfdf95

2 changed files with 182 additions and 168 deletions
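The commit wraps the body of both training loops (train_hypernetwork and train_embedding) in try/finally, so cleanup runs even when training ends early through break, an exception such as "Loss diverged.", or a user interrupt. Below is a minimal, runnable sketch of the pattern under assumed names (train, weights, steps are illustrative, not the webui code):

    import torch

    # Minimal sketch of the pattern this commit applies: the loop body moves
    # inside try, and cleanup that used to run only after a normal loop exit
    # moves into finally, so it also runs when the loop breaks, raises, or is
    # interrupted.
    def train(weights: list, steps: int) -> None:
        try:
            for step in range(steps):
                if step == 3:
                    raise RuntimeError("Loss diverged.")  # simulated mid-training failure
        finally:
            # Without this, a failure above leaves the tensors requiring grad.
            for weight in weights:
                weight.requires_grad = False

    weights = [torch.nn.Parameter(torch.zeros(4))]
    try:
        train(weights, steps=10)
    except RuntimeError:
        pass
    assert all(not w.requires_grad for w in weights)

The finally bodies below mirror this: trained tensors get requires_grad switched back off once the loop is over, however it ended.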
modules/hypernetworks/hypernetwork.py

@@ -398,110 +398,112 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     forced_filename = "<none>"
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    for i, entries in pbar:
-        hypernetwork.step = i + ititial_step
-        if len(loss_dict) > 0:
-            previous_mean_losses = [i[-1] for i in loss_dict.values()]
-            previous_mean_loss = mean(previous_mean_losses)
-
-        scheduler.apply(optimizer, hypernetwork.step)
-        if scheduler.finished:
-            break
-
-        if shared.state.interrupted:
-            break
-
-        with torch.autocast("cuda"):
-            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
-            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-            loss = shared.sd_model(x, c)[0]
-            del x
-            del c
-
-            losses[hypernetwork.step % losses.shape[0]] = loss.item()
-            for entry in entries:
-                loss_dict[entry.filename].append(loss.item())
-
-            optimizer.zero_grad()
-            weights[0].grad = None
-            loss.backward()
-
-            if weights[0].grad is None:
-                steps_without_grad += 1
-            else:
-                steps_without_grad = 0
-            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
-
-            optimizer.step()
-
-        steps_done = hypernetwork.step + 1
-
-        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
-            raise RuntimeError("Loss diverged.")
-
-        if len(previous_mean_losses) > 1:
-            std = stdev(previous_mean_losses)
-        else:
-            std = 0
-        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
-        pbar.set_description(dataset_loss_info)
-
-        if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
-            # Before saving, change name to match current checkpoint.
-            hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
-            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
-            hypernetwork.save(last_saved_file)
-
-        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
-            "loss": f"{previous_mean_loss:.7f}",
-            "learn_rate": scheduler.learn_rate
-        })
-
-        if images_dir is not None and steps_done % create_image_every == 0:
-            forced_filename = f'{hypernetwork_name}-{steps_done}'
-            last_saved_image = os.path.join(images_dir, forced_filename)
-
-            optimizer.zero_grad()
-            shared.sd_model.cond_stage_model.to(devices.device)
-            shared.sd_model.first_stage_model.to(devices.device)
-
-            p = processing.StableDiffusionProcessingTxt2Img(
-                sd_model=shared.sd_model,
-                do_not_save_grid=True,
-                do_not_save_samples=True,
-            )
-
-            if preview_from_txt2img:
-                p.prompt = preview_prompt
-                p.negative_prompt = preview_negative_prompt
-                p.steps = preview_steps
-                p.sampler_index = preview_sampler_index
-                p.cfg_scale = preview_cfg_scale
-                p.seed = preview_seed
-                p.width = preview_width
-                p.height = preview_height
-            else:
-                p.prompt = entries[0].cond_text
-                p.steps = 20
-
-            preview_text = p.prompt
-
-            processed = processing.process_images(p)
-            image = processed.images[0] if len(processed.images)>0 else None
-
-            if unload:
-                shared.sd_model.cond_stage_model.to(devices.cpu)
-                shared.sd_model.first_stage_model.to(devices.cpu)
-
-            if image is not None:
-                shared.state.current_image = image
-                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
-                last_saved_image += f", prompt: {preview_text}"
-
-        shared.state.job_no = hypernetwork.step
-
-        shared.state.textinfo = f"""
+    try:
+        for i, entries in pbar:
+            hypernetwork.step = i + ititial_step
+            if len(loss_dict) > 0:
+                previous_mean_losses = [i[-1] for i in loss_dict.values()]
+                previous_mean_loss = mean(previous_mean_losses)
+
+            scheduler.apply(optimizer, hypernetwork.step)
+            if scheduler.finished:
+                break
+
+            if shared.state.interrupted:
+                break
+
+            with torch.autocast("cuda"):
+                c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+                # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+                x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+                loss = shared.sd_model(x, c)[0]
+                del x
+                del c
+
+                losses[hypernetwork.step % losses.shape[0]] = loss.item()
+                for entry in entries:
+                    loss_dict[entry.filename].append(loss.item())
+
+                optimizer.zero_grad()
+                weights[0].grad = None
+                loss.backward()
+
+                if weights[0].grad is None:
+                    steps_without_grad += 1
+                else:
+                    steps_without_grad = 0
+                assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
+                optimizer.step()
+
+            steps_done = hypernetwork.step + 1
+
+            if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+                raise RuntimeError("Loss diverged.")
+
+            if len(previous_mean_losses) > 1:
+                std = stdev(previous_mean_losses)
+            else:
+                std = 0
+            dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+            pbar.set_description(dataset_loss_info)
+
+            if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
+                # Before saving, change name to match current checkpoint.
+                hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
+                last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
+                hypernetwork.save(last_saved_file)
+
+            textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+                "loss": f"{previous_mean_loss:.7f}",
+                "learn_rate": scheduler.learn_rate
+            })
+
+            if images_dir is not None and steps_done % create_image_every == 0:
+                forced_filename = f'{hypernetwork_name}-{steps_done}'
+                last_saved_image = os.path.join(images_dir, forced_filename)
+
+                optimizer.zero_grad()
+                shared.sd_model.cond_stage_model.to(devices.device)
+                shared.sd_model.first_stage_model.to(devices.device)
+
+                p = processing.StableDiffusionProcessingTxt2Img(
+                    sd_model=shared.sd_model,
+                    do_not_save_grid=True,
+                    do_not_save_samples=True,
+                )
+
+                if preview_from_txt2img:
+                    p.prompt = preview_prompt
+                    p.negative_prompt = preview_negative_prompt
+                    p.steps = preview_steps
+                    p.sampler_index = preview_sampler_index
+                    p.cfg_scale = preview_cfg_scale
+                    p.seed = preview_seed
+                    p.width = preview_width
+                    p.height = preview_height
+                else:
+                    p.prompt = entries[0].cond_text
+                    p.steps = 20
+
+                preview_text = p.prompt
+
+                processed = processing.process_images(p)
+                image = processed.images[0] if len(processed.images)>0 else None
+
+                if unload:
+                    shared.sd_model.cond_stage_model.to(devices.cpu)
+                    shared.sd_model.first_stage_model.to(devices.cpu)
+
+                if image is not None:
+                    shared.state.current_image = image
+                    last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                    last_saved_image += f", prompt: {preview_text}"
+
+            shared.state.job_no = hypernetwork.step
+
+            shared.state.textinfo = f"""
 <p>
 Loss: {previous_mean_loss:.7f}<br/>
 Step: {hypernetwork.step}<br/>

@@ -510,7 +512,14 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
+    finally:
+        if weights:
+            for weight in weights:
+                weight.requires_grad = False
+        if unload:
+            shared.sd_model.cond_stage_model.to(devices.device)
+            shared.sd_model.first_stage_model.to(devices.device)
 
     report_statistics(loss_dict)
     checkpoint = sd_models.select_checkpoint()
 
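The hypernetwork finally block also reverses the unload optimization: when unload is set, the text encoder (cond_stage_model) and VAE (first_stage_model) sit on the CPU during training, and the finally block now guarantees they come back to the work device regardless of how training ended. A runnable sketch of that guarantee, with FakeSDModel as an assumed stand-in for shared.sd_model:

    import torch

    # FakeSDModel is a stand-in; in webui the real submodules hang off
    # shared.sd_model.
    class FakeSDModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.cond_stage_model = torch.nn.Linear(4, 4)   # stands in for the text encoder
            self.first_stage_model = torch.nn.Linear(4, 4)  # stands in for the VAE

    def train_with_unload(sd_model: FakeSDModel, device: torch.device) -> None:
        # With unload enabled, both submodules are parked on the CPU while
        # training runs.
        sd_model.cond_stage_model.to("cpu")
        sd_model.first_stage_model.to("cpu")
        try:
            raise KeyboardInterrupt  # simulate the user stopping training
        finally:
            # The restore now runs on every exit path, not just a clean one.
            sd_model.cond_stage_model.to(device)
            sd_model.first_stage_model.to(device)

    model = FakeSDModel()
    try:
        train_with_unload(model, torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    except KeyboardInterrupt:
        pass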
modules/textual_inversion/textual_inversion.py

@@ -283,111 +283,113 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     embedding_yet_to_be_embedded = False
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
-    for i, entries in pbar:
-        embedding.step = i + ititial_step
-
-        scheduler.apply(optimizer, embedding.step)
-        if scheduler.finished:
-            break
-
-        if shared.state.interrupted:
-            break
-
-        with torch.autocast("cuda"):
-            c = cond_model([entry.cond_text for entry in entries])
-            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
-            loss = shared.sd_model(x, c)[0]
-            del x
-
-            losses[embedding.step % losses.shape[0]] = loss.item()
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-        steps_done = embedding.step + 1
-
-        epoch_num = embedding.step // len(ds)
-        epoch_step = embedding.step % len(ds)
-
-        pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
-
-        if embedding_dir is not None and steps_done % save_embedding_every == 0:
-            # Before saving, change name to match current checkpoint.
-            embedding.name = f'{embedding_name}-{steps_done}'
-            last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
-            embedding.save(last_saved_file)
-            embedding_yet_to_be_embedded = True
-
-        write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
-            "loss": f"{losses.mean():.7f}",
-            "learn_rate": scheduler.learn_rate
-        })
-
-        if images_dir is not None and steps_done % create_image_every == 0:
-            forced_filename = f'{embedding_name}-{steps_done}'
-            last_saved_image = os.path.join(images_dir, forced_filename)
-            p = processing.StableDiffusionProcessingTxt2Img(
-                sd_model=shared.sd_model,
-                do_not_save_grid=True,
-                do_not_save_samples=True,
-                do_not_reload_embeddings=True,
-            )
-
-            if preview_from_txt2img:
-                p.prompt = preview_prompt
-                p.negative_prompt = preview_negative_prompt
-                p.steps = preview_steps
-                p.sampler_index = preview_sampler_index
-                p.cfg_scale = preview_cfg_scale
-                p.seed = preview_seed
-                p.width = preview_width
-                p.height = preview_height
-            else:
-                p.prompt = entries[0].cond_text
-                p.steps = 20
-                p.width = training_width
-                p.height = training_height
-
-            preview_text = p.prompt
-
-            processed = processing.process_images(p)
-            image = processed.images[0]
-
-            shared.state.current_image = image
-
-            if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
-
-                last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
-
-                info = PngImagePlugin.PngInfo()
-                data = torch.load(last_saved_file)
-                info.add_text("sd-ti-embedding", embedding_to_b64(data))
-
-                title = "<{}>".format(data.get('name', '???'))
-
-                try:
-                    vectorSize = list(data['string_to_param'].values())[0].shape[0]
-                except Exception as e:
-                    vectorSize = '?'
-
-                checkpoint = sd_models.select_checkpoint()
-                footer_left = checkpoint.model_name
-                footer_mid = '[{}]'.format(checkpoint.hash)
-                footer_right = '{}v {}s'.format(vectorSize, steps_done)
-
-                captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
-                captioned_image = insert_image_data_embed(captioned_image, data)
-
-                captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
-                embedding_yet_to_be_embedded = False
-
-            last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
-            last_saved_image += f", prompt: {preview_text}"
-
-        shared.state.job_no = embedding.step
-
-        shared.state.textinfo = f"""
+    try:
+        for i, entries in pbar:
+            embedding.step = i + ititial_step
+
+            scheduler.apply(optimizer, embedding.step)
+            if scheduler.finished:
+                break
+
+            if shared.state.interrupted:
+                break
+
+            with torch.autocast("cuda"):
+                c = cond_model([entry.cond_text for entry in entries])
+                x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+                loss = shared.sd_model(x, c)[0]
+                del x
+
+                losses[embedding.step % losses.shape[0]] = loss.item()
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+            steps_done = embedding.step + 1
+
+            epoch_num = embedding.step // len(ds)
+            epoch_step = embedding.step % len(ds)
+
+            pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
+
+            if embedding_dir is not None and steps_done % save_embedding_every == 0:
+                # Before saving, change name to match current checkpoint.
+                embedding.name = f'{embedding_name}-{steps_done}'
+                last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
+                embedding.save(last_saved_file)
+                embedding_yet_to_be_embedded = True
+
+            write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+                "loss": f"{losses.mean():.7f}",
+                "learn_rate": scheduler.learn_rate
+            })
+
+            if images_dir is not None and steps_done % create_image_every == 0:
+                forced_filename = f'{embedding_name}-{steps_done}'
+                last_saved_image = os.path.join(images_dir, forced_filename)
+                p = processing.StableDiffusionProcessingTxt2Img(
+                    sd_model=shared.sd_model,
+                    do_not_save_grid=True,
+                    do_not_save_samples=True,
+                    do_not_reload_embeddings=True,
+                )
+
+                if preview_from_txt2img:
+                    p.prompt = preview_prompt
+                    p.negative_prompt = preview_negative_prompt
+                    p.steps = preview_steps
+                    p.sampler_index = preview_sampler_index
+                    p.cfg_scale = preview_cfg_scale
+                    p.seed = preview_seed
+                    p.width = preview_width
+                    p.height = preview_height
+                else:
+                    p.prompt = entries[0].cond_text
+                    p.steps = 20
+                    p.width = training_width
+                    p.height = training_height
+
+                preview_text = p.prompt
+
+                processed = processing.process_images(p)
+                image = processed.images[0]
+
+                shared.state.current_image = image
+
+                if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+
+                    last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+
+                    info = PngImagePlugin.PngInfo()
+                    data = torch.load(last_saved_file)
+                    info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+                    title = "<{}>".format(data.get('name', '???'))
+
+                    try:
+                        vectorSize = list(data['string_to_param'].values())[0].shape[0]
+                    except Exception as e:
+                        vectorSize = '?'
+
+                    checkpoint = sd_models.select_checkpoint()
+                    footer_left = checkpoint.model_name
+                    footer_mid = '[{}]'.format(checkpoint.hash)
+                    footer_right = '{}v {}s'.format(vectorSize, steps_done)
+
+                    captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+                    captioned_image = insert_image_data_embed(captioned_image, data)
+
+                    captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+                    embedding_yet_to_be_embedded = False
+
+                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+                last_saved_image += f", prompt: {preview_text}"
+
+            shared.state.job_no = embedding.step
+
+            shared.state.textinfo = f"""
 <p>
 Loss: {losses.mean():.7f}<br/>
 Step: {embedding.step}<br/>

@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}<br/>
 Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
+    finally:
+        if embedding and embedding.vec is not None:
+            embedding.vec.requires_grad = False
 
     checkpoint = sd_models.select_checkpoint()
 
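The embedding loop gets a smaller finally block, since only the trained vector's grad flag needs clearing; the embedding and embedding.vec is not None guard presumably keeps the cleanup safe if training stopped before the embedding was fully in place. A runnable sketch with a hypothetical Embedding stand-in for the webui class:

    import torch

    # Hypothetical minimal stand-in for the webui Embedding class.
    class Embedding:
        def __init__(self, vec: torch.Tensor) -> None:
            self.vec = vec

    embedding = Embedding(torch.zeros(8, requires_grad=True))
    try:
        try:
            raise RuntimeError("Loss diverged.")  # simulated mid-training failure
        finally:
            # Mirrors the new finally block: clear the grad flag no matter
            # how the loop ended, skipping it if setup never completed.
            if embedding and embedding.vec is not None:
                embedding.vec.requires_grad = False
    except RuntimeError:
        pass
    assert not embedding.vec.requires_grad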