check NaN for hypernetwork tuning
This commit is contained in:
parent
5fd638f14d
commit
703e6d9e4e
1 changed files with 6 additions and 4 deletions
|
@ -272,15 +272,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
mean_loss = losses.mean()
|
||||||
pbar.set_description(f"loss: {losses.mean():.7f}")
|
if torch.isnan(mean_loss):
|
||||||
|
raise RuntimeError("Loss diverged.")
|
||||||
|
pbar.set_description(f"loss: {mean_loss:.7f}")
|
||||||
|
|
||||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
||||||
hypernetwork.save(last_saved_file)
|
hypernetwork.save(last_saved_file)
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||||
"loss": f"{losses.mean():.7f}",
|
"loss": f"{mean_loss:.7f}",
|
||||||
"learn_rate": scheduler.learn_rate
|
"learn_rate": scheduler.learn_rate
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -328,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
shared.state.textinfo = f"""
|
shared.state.textinfo = f"""
|
||||||
<p>
|
<p>
|
||||||
Loss: {losses.mean():.7f}<br/>
|
Loss: {mean_loss:.7f}<br/>
|
||||||
Step: {hypernetwork.step}<br/>
|
Step: {hypernetwork.step}<br/>
|
||||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
|
|
Loading…
Reference in a new issue