Merge branch 'AUTOMATIC1111:master' into master
This commit is contained in:
commit
97749b7c7d
2 changed files with 14 additions and 0 deletions
|
@ -325,6 +325,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||||
|
|
||||||
|
steps_without_grad = 0
|
||||||
|
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
for i, entries in pbar:
|
for i, entries in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
|
@ -347,8 +349,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
weights[0].grad = None
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
if weights[0].grad is None:
|
||||||
|
steps_without_grad += 1
|
||||||
|
else:
|
||||||
|
steps_without_grad = 0
|
||||||
|
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
mean_loss = losses.mean()
|
mean_loss = losses.mean()
|
||||||
if torch.isnan(mean_loss):
|
if torch.isnan(mean_loss):
|
||||||
raise RuntimeError("Loss diverged.")
|
raise RuntimeError("Loss diverged.")
|
||||||
|
|
|
@ -1648,6 +1648,9 @@ Requested path was: {f}
|
||||||
css = ""
|
css = ""
|
||||||
|
|
||||||
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
||||||
|
if not os.path.isfile(cssfile):
|
||||||
|
continue
|
||||||
|
|
||||||
with open(cssfile, "r", encoding="utf8") as file:
|
with open(cssfile, "r", encoding="utf8") as file:
|
||||||
css += file.read() + "\n"
|
css += file.read() + "\n"
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue