Merge branch 'AUTOMATIC1111:master' into master

This commit is contained in:
discus0434 2022-10-22 22:00:59 +09:00 committed by GitHub
commit 97749b7c7d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 0 deletions

View file

@ -325,6 +325,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar: for i, entries in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
@ -347,8 +349,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item() losses[hypernetwork.step % losses.shape[0]] = loss.item()
optimizer.zero_grad() optimizer.zero_grad()
weights[0].grad = None
loss.backward() loss.backward()
if weights[0].grad is None:
steps_without_grad += 1
else:
steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
optimizer.step() optimizer.step()
mean_loss = losses.mean() mean_loss = losses.mean()
if torch.isnan(mean_loss): if torch.isnan(mean_loss):
raise RuntimeError("Loss diverged.") raise RuntimeError("Loss diverged.")

View file

@ -1648,6 +1648,9 @@ Requested path was: {f}
css = "" css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"): for cssfile in modules.scripts.list_files_with_name("style.css"):
if not os.path.isfile(cssfile):
continue
with open(cssfile, "r", encoding="utf8") as file: with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n" css += file.read() + "\n"