Gradient clipping for textual embedding
parent a133042c66
commit 1618df41ba
2 changed files with 12 additions and 1 deletion
@@ -206,7 +206,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
         })
 
 
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     assert embedding_name, 'embedding not selected'
 
     shared.state.textinfo = "Initializing textual inversion training..."
@@ -256,6 +256,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     if ititial_step > steps:
         return embedding, filename
 
+    clip_grad_mode_value = clip_grad_mode == "value"
+    clip_grad_mode_norm = clip_grad_mode == "norm"
+
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
 
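The mode string coming from the UI is reduced to two booleans once, before the training loop starts, so the per-step check added in the next hunk is a flag test rather than a repeated string comparison. A minimal standalone sketch of that mapping; the mode string used here is only an example:

    # Sketch only: "norm" is an illustrative value; the real string is the clip_grad_mode argument.
    clip_grad_mode = "norm"
    clip_grad_mode_value = clip_grad_mode == "value"   # False in this example
    clip_grad_mode_norm = clip_grad_mode == "norm"     # True in this example
    # Any other string (for example a "disabled" choice) leaves both flags False, so no clipping is applied.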
@@ -280,6 +283,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
 
             optimizer.zero_grad()
             loss.backward()
+
+            if clip_grad_mode_value:
+                torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value)
+            elif clip_grad_mode_norm:
+                torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value)
+
             optimizer.step()
 
 
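For reference, a self-contained sketch of the two PyTorch clipping calls used above, run on a stand-in parameter rather than embedding.vec; the shape and threshold are illustrative only:

    import torch

    vec = torch.nn.Parameter(torch.randn(4, 768))        # stand-in for embedding.vec
    loss = (vec ** 2).sum()
    loss.backward()

    clip_grad_value = 0.1                                 # example threshold
    # "value" mode: clamp each gradient element into [-clip_grad_value, clip_grad_value]
    torch.nn.utils.clip_grad_value_(vec, clip_value=clip_grad_value)
    # "norm" mode (alternative): rescale the whole gradient so its total norm is at most clip_grad_value
    # torch.nn.utils.clip_grad_norm_(vec, max_norm=clip_grad_value)

Either call mutates the gradients in place between loss.backward() and optimizer.step(), which is exactly where the diff inserts them.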
@@ -1409,6 +1409,8 @@ def create_ui(wrap_gradio_gpu_call):
                     training_width,
                     training_height,
                     steps,
+                    clip_grad_mode,
+                    clip_grad_value,
                     create_image_every,
                     save_embedding_every,
                     template_file,
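The two names added to this input list presumably refer to Gradio components defined earlier in create_ui; those definitions are outside this hunk, so the following is only a hypothetical sketch of what they might look like (labels, choices, and defaults are assumptions, not part of this commit):

    import gradio as gr

    clip_grad_mode = gr.Dropdown(label="Gradient clipping mode",
                                 choices=["disabled", "value", "norm"], value="disabled")
    clip_grad_value = gr.Number(label="Gradient clipping value", value=0.1)

Whatever values these components hold when training starts arrive as the clip_grad_mode and clip_grad_value arguments of train_embedding.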