add options to custom hypernetwork layer structure

This commit is contained in:
discus0434 2022-10-19 00:51:36 +09:00
parent c1093b8051
commit 6021f7a75f
4 changed files with 76 additions and 25 deletions

1
.gitignore vendored
View file

@ -27,3 +27,4 @@ __pycache__
notification.mp3
/SwinIR
/textual_inversion
/hypernetwork

View file

@ -1,52 +1,98 @@
import csv
import datetime
import glob
import html
import os
import sys
import traceback
import tqdm
import csv
import torch
from ldm.util import default
from modules import devices, shared, processing, sd_models
import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
def parse_layer_structure(dim, state_dict):
i = 0
res = [1]
while (key := "linear.{}.weight".format(i)) in state_dict:
weight = state_dict[key]
res.append(len(weight) // dim)
i += 1
return res
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
layer_structure = None
add_layer_norm = False
def __init__(self, dim, state_dict=None):
super().__init__()
if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
layer_structure = (1, 2, 1)
else:
if self.layer_structure is not None:
assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
layer_structure = self.layer_structure
else:
layer_structure = parse_layer_structure(dim, state_dict)
self.linear1 = torch.nn.Linear(dim, dim * 2)
self.linear2 = torch.nn.Linear(dim * 2, dim)
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if self.add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
self.load_state_dict(state_dict, strict=True)
try:
self.load_state_dict(state_dict)
except RuntimeError:
self.try_load_previous(state_dict)
else:
self.linear1.weight.data.normal_(mean=0.0, std=0.01)
self.linear1.bias.data.zero_()
self.linear2.weight.data.normal_(mean=0.0, std=0.01)
self.linear2.bias.data.zero_()
for layer in self.linear:
layer.weight.data.normal_(mean = 0.0, std = 0.01)
layer.bias.data.zero_()
self.to(devices.device)
def try_load_previous(self, state_dict):
states = self.state_dict()
states['linear.0.bias'].copy_(state_dict['linear1.bias'])
states['linear.0.weight'].copy_(state_dict['linear1.weight'])
states['linear.1.bias'].copy_(state_dict['linear2.bias'])
states['linear.1.weight'].copy_(state_dict['linear2.weight'])
def forward(self, x):
return x + (self.linear2(self.linear1(x))) * self.multiplier
return x + self.linear(x) * self.multiplier
def trainables(self):
res = []
for layer in self.linear:
res += [layer.weight, layer.bias]
return res
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
def apply_layer_structure(value=None):
HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
def apply_layer_norm(value=None):
HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
class Hypernetwork:
filename = None
name = None
@ -68,7 +114,7 @@ class Hypernetwork:
for k, layers in self.layers.items():
for layer in layers:
layer.train()
res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
res += layer.trainables()
return res
@ -226,7 +272,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
assert ds.length > 1, "Dataset should contain more than 1 images"
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@ -261,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
@ -283,7 +329,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
"learn_rate": f"{scheduler.learn_rate:.7f}"
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:

View file

@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers, sd_models, localization
from modules import sd_models, sd_samplers, localization
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@ -258,6 +258,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
"sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),

View file

@ -86,11 +86,13 @@ def initialize():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure)
shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm)
def webui():
initialize()
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
@ -101,7 +103,7 @@ def webui():
while 1:
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
app, local_url, share_url = demo.launch(
share=cmd_opts.share,
server_name="0.0.0.0" if cmd_opts.listen else None,