add options to custom hypernetwork layer structure

This commit is contained in:
discus0434 2022-10-19 00:51:36 +09:00
parent c1093b8051
commit 6021f7a75f
4 changed files with 76 additions and 25 deletions

1
.gitignore vendored
View file

@ -27,3 +27,4 @@ __pycache__
notification.mp3 notification.mp3
/SwinIR /SwinIR
/textual_inversion /textual_inversion
/hypernetwork

View file

@ -1,52 +1,98 @@
import csv
import datetime import datetime
import glob import glob
import html import html
import os import os
import sys import sys
import traceback import traceback
import tqdm
import csv
import torch
from ldm.util import default
from modules import devices, shared, processing, sd_models
import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
def parse_layer_structure(dim, state_dict):
i = 0
res = [1]
while (key := "linear.{}.weight".format(i)) in state_dict:
weight = state_dict[key]
res.append(len(weight) // dim)
i += 1
return res
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
layer_structure = None
add_layer_norm = False
def __init__(self, dim, state_dict=None): def __init__(self, dim, state_dict=None):
super().__init__() super().__init__()
if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
layer_structure = (1, 2, 1)
else:
if self.layer_structure is not None:
assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
layer_structure = self.layer_structure
else:
layer_structure = parse_layer_structure(dim, state_dict)
self.linear1 = torch.nn.Linear(dim, dim * 2) linears = []
self.linear2 = torch.nn.Linear(dim * 2, dim) for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if self.add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
self.linear = torch.nn.Sequential(*linears)
if state_dict is not None: if state_dict is not None:
self.load_state_dict(state_dict, strict=True) try:
self.load_state_dict(state_dict)
except RuntimeError:
self.try_load_previous(state_dict)
else: else:
for layer in self.linear:
self.linear1.weight.data.normal_(mean=0.0, std=0.01) layer.weight.data.normal_(mean = 0.0, std = 0.01)
self.linear1.bias.data.zero_() layer.bias.data.zero_()
self.linear2.weight.data.normal_(mean=0.0, std=0.01)
self.linear2.bias.data.zero_()
self.to(devices.device) self.to(devices.device)
def try_load_previous(self, state_dict):
states = self.state_dict()
states['linear.0.bias'].copy_(state_dict['linear1.bias'])
states['linear.0.weight'].copy_(state_dict['linear1.weight'])
states['linear.1.bias'].copy_(state_dict['linear2.bias'])
states['linear.1.weight'].copy_(state_dict['linear2.weight'])
def forward(self, x): def forward(self, x):
return x + (self.linear2(self.linear1(x))) * self.multiplier return x + self.linear(x) * self.multiplier
def trainables(self):
res = []
for layer in self.linear:
res += [layer.weight, layer.bias]
return res
def apply_strength(value=None): def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
def apply_layer_structure(value=None):
HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
def apply_layer_norm(value=None):
HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
class Hypernetwork: class Hypernetwork:
filename = None filename = None
name = None name = None
@ -68,7 +114,7 @@ class Hypernetwork:
for k, layers in self.layers.items(): for k, layers in self.layers.items():
for layer in layers: for layer in layers:
layer.train() layer.train()
res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] res += layer.trainables()
return res return res
@ -226,7 +272,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
assert ds.length > 1, "Dataset should contain more than 1 images"
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
@ -261,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
with torch.autocast("cuda"): with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device) c = stack_conds([entry.cond for entry in entries]).to(devices.device)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device) x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0] loss = shared.sd_model(x, c)[0]
del x del x
@ -283,7 +329,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{mean_loss:.7f}", "loss": f"{mean_loss:.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": f"{scheduler.learn_rate:.7f}"
}) })
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:

View file

@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import sd_samplers, sd_models, localization from modules import sd_models, sd_samplers, localization
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path from modules.paths import models_path, script_path, sd_path
@ -258,6 +258,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
"sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),

View file

@ -86,11 +86,13 @@ def initialize():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure)
shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm)
def webui(): def webui():
initialize() initialize()
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
@ -101,7 +103,7 @@ def webui():
while 1: while 1:
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
app, local_url, share_url = demo.launch( app, local_url, share_url = demo.launch(
share=cmd_opts.share, share=cmd_opts.share,
server_name="0.0.0.0" if cmd_opts.listen else None, server_name="0.0.0.0" if cmd_opts.listen else None,