add options to custom hypernetwork layer structure
This commit is contained in:
parent
c1093b8051
commit
6021f7a75f
4 changed files with 76 additions and 25 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -27,3 +27,4 @@ __pycache__
|
||||||
notification.mp3
|
notification.mp3
|
||||||
/SwinIR
|
/SwinIR
|
||||||
/textual_inversion
|
/textual_inversion
|
||||||
|
/hypernetwork
|
||||||
|
|
|
@ -1,52 +1,98 @@
|
||||||
|
import csv
|
||||||
import datetime
|
import datetime
|
||||||
import glob
|
import glob
|
||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import tqdm
|
|
||||||
import csv
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from ldm.util import default
|
|
||||||
from modules import devices, shared, processing, sd_models
|
|
||||||
import torch
|
|
||||||
from torch import einsum
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from ldm.util import default
|
||||||
|
from modules import devices, processing, sd_models, shared
|
||||||
from modules.textual_inversion import textual_inversion
|
from modules.textual_inversion import textual_inversion
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
|
||||||
|
def parse_layer_structure(dim, state_dict):
|
||||||
|
i = 0
|
||||||
|
res = [1]
|
||||||
|
while (key := "linear.{}.weight".format(i)) in state_dict:
|
||||||
|
weight = state_dict[key]
|
||||||
|
res.append(len(weight) // dim)
|
||||||
|
i += 1
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
|
layer_structure = None
|
||||||
|
add_layer_norm = False
|
||||||
|
|
||||||
def __init__(self, dim, state_dict=None):
|
def __init__(self, dim, state_dict=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None:
|
||||||
|
layer_structure = (1, 2, 1)
|
||||||
|
else:
|
||||||
|
if self.layer_structure is not None:
|
||||||
|
assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||||
|
assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||||
|
layer_structure = self.layer_structure
|
||||||
|
else:
|
||||||
|
layer_structure = parse_layer_structure(dim, state_dict)
|
||||||
|
|
||||||
self.linear1 = torch.nn.Linear(dim, dim * 2)
|
linears = []
|
||||||
self.linear2 = torch.nn.Linear(dim * 2, dim)
|
for i in range(len(layer_structure) - 1):
|
||||||
|
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||||
|
if self.add_layer_norm:
|
||||||
|
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
|
self.linear = torch.nn.Sequential(*linears)
|
||||||
|
|
||||||
if state_dict is not None:
|
if state_dict is not None:
|
||||||
self.load_state_dict(state_dict, strict=True)
|
try:
|
||||||
|
self.load_state_dict(state_dict)
|
||||||
|
except RuntimeError:
|
||||||
|
self.try_load_previous(state_dict)
|
||||||
else:
|
else:
|
||||||
|
for layer in self.linear:
|
||||||
self.linear1.weight.data.normal_(mean=0.0, std=0.01)
|
layer.weight.data.normal_(mean = 0.0, std = 0.01)
|
||||||
self.linear1.bias.data.zero_()
|
layer.bias.data.zero_()
|
||||||
self.linear2.weight.data.normal_(mean=0.0, std=0.01)
|
|
||||||
self.linear2.bias.data.zero_()
|
|
||||||
|
|
||||||
self.to(devices.device)
|
self.to(devices.device)
|
||||||
|
|
||||||
|
def try_load_previous(self, state_dict):
|
||||||
|
states = self.state_dict()
|
||||||
|
states['linear.0.bias'].copy_(state_dict['linear1.bias'])
|
||||||
|
states['linear.0.weight'].copy_(state_dict['linear1.weight'])
|
||||||
|
states['linear.1.bias'].copy_(state_dict['linear2.bias'])
|
||||||
|
states['linear.1.weight'].copy_(state_dict['linear2.weight'])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x + (self.linear2(self.linear1(x))) * self.multiplier
|
return x + self.linear(x) * self.multiplier
|
||||||
|
|
||||||
|
def trainables(self):
|
||||||
|
res = []
|
||||||
|
for layer in self.linear:
|
||||||
|
res += [layer.weight, layer.bias]
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def apply_strength(value=None):
|
def apply_strength(value=None):
|
||||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||||
|
|
||||||
|
|
||||||
|
def apply_layer_structure(value=None):
|
||||||
|
HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure
|
||||||
|
|
||||||
|
|
||||||
|
def apply_layer_norm(value=None):
|
||||||
|
HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm
|
||||||
|
|
||||||
|
|
||||||
class Hypernetwork:
|
class Hypernetwork:
|
||||||
filename = None
|
filename = None
|
||||||
name = None
|
name = None
|
||||||
|
@ -68,7 +114,7 @@ class Hypernetwork:
|
||||||
for k, layers in self.layers.items():
|
for k, layers in self.layers.items():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.train()
|
layer.train()
|
||||||
res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
|
res += layer.trainables()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -226,7 +272,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||||
|
assert ds.length > 1, "Dataset should contain more than 1 images"
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
shared.sd_model.first_stage_model.to(devices.cpu)
|
shared.sd_model.first_stage_model.to(devices.cpu)
|
||||||
|
@ -261,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||||
loss = shared.sd_model(x, c)[0]
|
loss = shared.sd_model(x, c)[0]
|
||||||
del x
|
del x
|
||||||
|
@ -283,7 +329,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||||
"loss": f"{mean_loss:.7f}",
|
"loss": f"{mean_loss:.7f}",
|
||||||
"learn_rate": scheduler.learn_rate
|
"learn_rate": f"{scheduler.learn_rate:.7f}"
|
||||||
})
|
})
|
||||||
|
|
||||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||||
|
|
|
@ -13,7 +13,7 @@ import modules.memmon
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import sd_samplers, sd_models, localization
|
from modules import sd_models, sd_samplers, localization
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
|
@ -258,6 +258,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
|
"sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}),
|
||||||
|
"sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."),
|
||||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
|
|
6
webui.py
6
webui.py
|
@ -86,11 +86,13 @@ def initialize():
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
||||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
||||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||||
|
shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure)
|
||||||
|
shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm)
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
def webui():
|
||||||
initialize()
|
initialize()
|
||||||
|
|
||||||
# make the program just exit at ctrl+c without waiting for anything
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
def sigint_handler(sig, frame):
|
def sigint_handler(sig, frame):
|
||||||
print(f'Interrupted with signal {sig} in {frame}')
|
print(f'Interrupted with signal {sig} in {frame}')
|
||||||
|
@ -101,7 +103,7 @@ def webui():
|
||||||
while 1:
|
while 1:
|
||||||
|
|
||||||
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
||||||
|
|
||||||
app, local_url, share_url = demo.launch(
|
app, local_url, share_url = demo.launch(
|
||||||
share=cmd_opts.share,
|
share=cmd_opts.share,
|
||||||
server_name="0.0.0.0" if cmd_opts.listen else None,
|
server_name="0.0.0.0" if cmd_opts.listen else None,
|
||||||
|
|
Loading…
Reference in a new issue