Simplify grad clip

This commit is contained in:
Muhammad Rizqi Nur 2022-11-05 11:48:38 +07:00
parent 3277f90e93
commit bb832d7725
2 changed files with 14 additions and 18 deletions

View file

@ -385,10 +385,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
clip_grad_mode_value = clip_grad_mode == "value" clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
clip_grad_mode_norm = clip_grad_mode == "norm" torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm None
if clip_grad_enabled: if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
@ -433,7 +433,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if shared.state.interrupted: if shared.state.interrupted:
break break
if clip_grad_enabled: if clip_grad:
clip_grad_sched.step(hypernetwork.step) clip_grad_sched.step(hypernetwork.step)
with torch.autocast("cuda"): with torch.autocast("cuda"):
@ -458,10 +458,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
steps_without_grad = 0 steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
if clip_grad_mode_value: if clip_grad:
torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate) clip_grad(weights, clip_grad_sched.learn_rate)
elif clip_grad_mode_norm:
torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
optimizer.step() optimizer.step()

View file

@ -269,10 +269,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
clip_grad_mode_value = clip_grad_mode == "value" clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
clip_grad_mode_norm = clip_grad_mode == "norm" torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm None
if clip_grad_enabled: if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@ -302,7 +302,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if shared.state.interrupted: if shared.state.interrupted:
break break
if clip_grad_enabled: if clip_grad:
clip_grad_sched.step(embedding.step) clip_grad_sched.step(embedding.step)
with torch.autocast("cuda"): with torch.autocast("cuda"):
@ -316,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
if clip_grad_mode_value: if clip_grad:
torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate) clip_grad(embedding.vec, clip_grad_sched.learn_rate)
elif clip_grad_mode_norm:
torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
optimizer.step() optimizer.step()