Add cleanup after training
This commit is contained in:
parent
ab27c111d0
commit
3ce2bfdf95
2 changed files with 182 additions and 168 deletions
|
@ -398,6 +398,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
forced_filename = "<none>"
|
forced_filename = "<none>"
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
|
|
||||||
|
try:
|
||||||
for i, entries in pbar:
|
for i, entries in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
if len(loss_dict) > 0:
|
if len(loss_dict) > 0:
|
||||||
|
@ -510,6 +512,13 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
|
finally:
|
||||||
|
if weights:
|
||||||
|
for weight in weights:
|
||||||
|
weight.requires_grad = False
|
||||||
|
if unload:
|
||||||
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
shared.sd_model.first_stage_model.to(devices.device)
|
||||||
|
|
||||||
report_statistics(loss_dict)
|
report_statistics(loss_dict)
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
|
@ -283,6 +283,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
||||||
embedding_yet_to_be_embedded = False
|
embedding_yet_to_be_embedded = False
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
|
|
||||||
|
try:
|
||||||
for i, entries in pbar:
|
for i, entries in pbar:
|
||||||
embedding.step = i + ititial_step
|
embedding.step = i + ititial_step
|
||||||
|
|
||||||
|
@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
|
finally:
|
||||||
|
if embedding and embedding.vec is not None:
|
||||||
|
embedding.vec.requires_grad = False
|
||||||
|
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue