Update hypernetwork.py
This commit is contained in:
parent
321bacc6a9
commit
24694e5983
1 changed files with 44 additions and 11 deletions
|
@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
|
from statistics import stdev, mean
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
|
@ -268,6 +269,32 @@ def stack_conds(conds):
|
||||||
return torch.stack(conds)
|
return torch.stack(conds)
|
||||||
|
|
||||||
|
|
||||||
|
def log_statistics(loss_info:dict, key, value):
|
||||||
|
if key not in loss_info:
|
||||||
|
loss_info[key] = [value]
|
||||||
|
else:
|
||||||
|
loss_info[key].append(value)
|
||||||
|
if len(loss_info) > 1024:
|
||||||
|
loss_info.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def statistics(data):
|
||||||
|
total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
|
||||||
|
recent_data = data[-32:]
|
||||||
|
recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})"
|
||||||
|
return total_information, recent_information
|
||||||
|
|
||||||
|
|
||||||
|
def report_statistics(loss_info:dict):
|
||||||
|
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
|
||||||
|
for key in keys:
|
||||||
|
info, recent = statistics(loss_info[key])
|
||||||
|
print("Loss statistics for file " + key)
|
||||||
|
print(info)
|
||||||
|
print(recent)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
from modules import images
|
from modules import images
|
||||||
|
@ -310,7 +337,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
for weight in weights:
|
for weight in weights:
|
||||||
weight.requires_grad = True
|
weight.requires_grad = True
|
||||||
|
|
||||||
losses = torch.zeros((32,))
|
size = len(ds.indexes)
|
||||||
|
loss_dict = {}
|
||||||
|
losses = torch.zeros((size,))
|
||||||
|
previous_mean_loss = 0
|
||||||
|
print("Mean loss of {} elements".format(size))
|
||||||
|
|
||||||
last_saved_file = "<none>"
|
last_saved_file = "<none>"
|
||||||
last_saved_image = "<none>"
|
last_saved_image = "<none>"
|
||||||
|
@ -329,7 +360,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||||
for i, entries in pbar:
|
for i, entries in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
|
if loss_dict and i % size == 0:
|
||||||
|
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
|
||||||
|
|
||||||
scheduler.apply(optimizer, hypernetwork.step)
|
scheduler.apply(optimizer, hypernetwork.step)
|
||||||
if scheduler.finished:
|
if scheduler.finished:
|
||||||
break
|
break
|
||||||
|
@ -346,7 +379,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
del c
|
del c
|
||||||
|
|
||||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||||
|
for entry in entries:
|
||||||
|
log_statistics(loss_dict, entry.filename, loss.item())
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
weights[0].grad = None
|
weights[0].grad = None
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -359,10 +394,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
mean_loss = losses.mean()
|
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||||
if torch.isnan(mean_loss):
|
|
||||||
raise RuntimeError("Loss diverged.")
|
raise RuntimeError("Loss diverged.")
|
||||||
pbar.set_description(f"loss: {mean_loss:.7f}")
|
pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
|
||||||
|
|
||||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||||
# Before saving, change name to match current checkpoint.
|
# Before saving, change name to match current checkpoint.
|
||||||
|
@ -371,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
hypernetwork.save(last_saved_file)
|
hypernetwork.save(last_saved_file)
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||||
"loss": f"{mean_loss:.7f}",
|
"loss": f"{previous_mean_loss:.7f}",
|
||||||
"learn_rate": scheduler.learn_rate
|
"learn_rate": scheduler.learn_rate
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -420,14 +454,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
shared.state.textinfo = f"""
|
shared.state.textinfo = f"""
|
||||||
<p>
|
<p>
|
||||||
Loss: {mean_loss:.7f}<br/>
|
Loss: {previous_mean_loss:.7f}<br/>
|
||||||
Step: {hypernetwork.step}<br/>
|
Step: {hypernetwork.step}<br/>
|
||||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||||
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
report_statistics(loss_dict)
|
||||||
checkpoint = sd_models.select_checkpoint()
|
checkpoint = sd_models.select_checkpoint()
|
||||||
|
|
||||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||||
|
@ -438,5 +473,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
hypernetwork.save(filename)
|
hypernetwork.save(filename)
|
||||||
|
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue