Merge pull request #4271 from MarkovInequality/racecond_fix
Fixes #4137 caused by race condition in training when VAE is unloaded
This commit is contained in:
commit
5267414319
2 changed files with 8 additions and 0 deletions
|
@ -433,7 +433,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
|||
|
||||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
if unload:
|
||||
shared.parallel_processing_allowed = False
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
|
@ -612,10 +615,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
if shared.opts.save_optimizer_state:
|
||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||
|
||||
del optimizer
|
||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
|
||||
return hypernetwork, filename
|
||||
|
||||
|
|
|
@ -269,6 +269,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||
|
||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
old_parallel_processing_allowed = shared.parallel_processing_allowed
|
||||
|
||||
pin_memory = shared.opts.pin_memory
|
||||
|
||||
|
@ -279,6 +280,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
|||
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
|
||||
|
||||
if unload:
|
||||
shared.parallel_processing_allowed = False
|
||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||
|
||||
embedding.vec.requires_grad = True
|
||||
|
@ -450,6 +452,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||
pbar.leave = False
|
||||
pbar.close()
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
shared.parallel_processing_allowed = old_parallel_processing_allowed
|
||||
|
||||
return embedding, filename
|
||||
|
||||
|
|
Loading…
Reference in a new issue