hypernetwork training mk1
This commit is contained in:
parent
f7c787eb7c
commit
12c4d5c6b5
12 changed files with 414 additions and 107 deletions
|
@ -1,88 +0,0 @@
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from ldm.util import default
|
|
||||||
from modules import devices, shared
|
|
||||||
import torch
|
|
||||||
from torch import einsum
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
    """Residual adapter made of two linear layers: ``x + linear2(linear1(x))``."""

    def __init__(self, dim, state_dict):
        """Create the ``dim -> 2*dim -> dim`` layers, load *state_dict*, move to ``devices.device``."""
        super().__init__()

        hidden = dim * 2
        self.linear1 = torch.nn.Linear(dim, hidden)
        self.linear2 = torch.nn.Linear(hidden, dim)

        # strict=True: the supplied state dict must match these two layers exactly.
        self.load_state_dict(state_dict, strict=True)
        self.to(devices.device)

    def forward(self, x):
        """Return *x* plus the output of the two stacked linear layers."""
        expanded = self.linear1(x)
        delta = self.linear2(expanded)
        return x + delta
|
|
||||||
|
|
||||||
|
|
||||||
class Hypernetwork:
    """Pairs of HypernetworkModule objects, keyed by size, read from one file."""

    # Class-level defaults; __init__ overwrites both on every instance.
    filename = None
    name = None

    def __init__(self, filename):
        """Load *filename* with torch.load (on CPU) and build one module pair per entry."""
        self.filename = filename
        # The name is the file's base name with its extension removed.
        stem, _extension = os.path.splitext(os.path.basename(filename))
        self.name = stem
        self.layers = {}

        checkpoint = torch.load(filename, map_location='cpu')
        # Every value is indexed as a pair of state dicts: [0] and [1].
        self.layers.update({
            width: (HypernetworkModule(width, pair[0]), HypernetworkModule(width, pair[1]))
            for width, pair in checkpoint.items()
        })
|
|
||||||
|
|
||||||
|
|
||||||
def load_hypernetworks(path):
    """Load every ``.pt`` file found under *path*, searching subdirectories too.

    Returns a dict mapping each hypernetwork's ``name`` to the loaded object.
    A file that fails to load is reported on stderr, with its traceback, and
    skipped; the remaining files are still loaded.
    """
    res = {}

    # Build the pattern with os.path.join. Plain concatenation produced
    # "<path>**/*.pt", where "**" is not a whole path component: glob then
    # does not recurse and also matches sibling directories sharing the prefix.
    pattern = os.path.join(path, '**', '*.pt')

    for filename in glob.iglob(pattern, recursive=True):
        try:
            hn = Hypernetwork(filename)
            res[hn.name] = hn
        except Exception:
            # Best effort: one broken file must not prevent loading the others.
            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return res
|
|
||||||
|
|
||||||
|
|
||||||
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Attention forward pass that routes the context through a hypernetwork pair.

    The pair is looked up in the selected hypernetwork's ``layers`` by
    ``context.shape[2]``. When a pair exists, element 0 transforms the context
    before ``to_k`` and element 1 before ``to_v``; otherwise the context is
    used unchanged.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)  # presumably falls back to x when context is None - confirm in ldm.util

    selected = shared.selected_hypernetwork()
    layers_by_width = {} if selected is None else selected.layers
    pair = layers_by_width.get(context.shape[2], None)

    if pair is None:
        key = self.to_k(context)
        value = self.to_v(context)
    else:
        key = self.to_k(pair[0](context))
        value = self.to_v(pair[1](context))

    def split_heads(t):
        # Fold the head axis into the batch axis.
        return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

    query, key, value = split_heads(query), split_heads(key), split_heads(value)

    scores = einsum('b i d, b j d -> b i j', query, key) * self.scale

    if mask is not None:
        flat_mask = rearrange(mask, 'b ... -> b (...)')
        most_negative = -torch.finfo(scores.dtype).max
        per_head_mask = repeat(flat_mask, 'b j -> (b h) () j', h=heads)
        # Masked-out positions get the most negative finite value of the dtype.
        scores.masked_fill_(~per_head_mask, most_negative)

    # attention, what we cannot get enough of
    weights = scores.softmax(dim=-1)

    mixed = einsum('b i j, b j d -> b i d', weights, value)
    merged = rearrange(mixed, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)
|
|
267
modules/hypernetwork/hypernetwork.py
Normal file
267
modules/hypernetwork/hypernetwork.py
Normal file
|
@ -0,0 +1,267 @@
|
||||||
|
import datetime
|
||||||
|
import glob
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ldm.util import default
|
||||||
|
from modules import devices, shared, processing, sd_models
|
||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
import modules.textual_inversion.dataset
|
||||||
|
|
||||||
|
|
||||||
|
class HypernetworkModule(torch.nn.Module):
    """Residual adapter made of two linear layers: ``x + linear2(linear1(x))``.

    With a *state_dict* the weights are loaded from it (strictly). Without one
    the module is freshly initialised for training.
    """

    def __init__(self, dim, state_dict=None):
        """Create the ``dim -> 2*dim -> dim`` layers and move them to ``devices.device``.

        Args:
            dim: input and output width of the adapter.
            state_dict: optional weights for ``linear1``/``linear2``; when
                omitted the layers get a small random initialisation.
        """
        super().__init__()

        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

        if state_dict is not None:
            self.load_state_dict(state_dict, strict=True)
        else:
            self._init_weights()

        self.to(devices.device)

    def _init_weights(self):
        """Give both layers small random weights and zero biases.

        A single constant for every weight makes all hidden units identical;
        they then receive identical gradients and never diverge, so the
        adapter could only ever learn a rank-1 update. Small random weights
        break that symmetry while keeping the residual close to identity.
        """
        for layer in (self.linear1, self.linear2):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.01)
            torch.nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return *x* plus the output of the two stacked linear layers."""
        return x + self.linear2(self.linear1(x))
|
||||||
|
|
||||||
|
|
||||||
|
class Hypernetwork:
    """A named, trainable set of HypernetworkModule pairs keyed by integer size.

    Besides the layers, the object tracks the training ``step`` and the hash
    and name of a checkpoint (``sd_checkpoint`` / ``sd_checkpoint_name``), all
    of which are written by ``save`` and restored by ``load``.
    """

    # Class-level defaults; __init__ overwrites both on every instance.
    filename = None
    name = None

    def __init__(self, name=None):
        """Create a fresh hypernetwork with one new module pair per built-in size."""
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None

        for size in [320, 640, 768, 1280]:
            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))

    def weights(self):
        """Return all weight and bias tensors of every layer as a flat list.

        Side effect: every module is switched to training mode.
        """
        res = []

        for layers in self.layers.values():
            for layer in layers:
                layer.train()
                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]

        return res

    def save(self, filename):
        """Write the layers and metadata to *filename* with ``torch.save``.

        Layer entries use their integer size as key; metadata uses the string
        keys 'step', 'name', 'sd_checkpoint' and 'sd_checkpoint_name'.
        """
        state_dict = {}

        for size, pair in self.layers.items():
            state_dict[size] = (pair[0].state_dict(), pair[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

        torch.save(state_dict, filename)

    def load(self, filename):
        """Restore layers and metadata from *filename*.

        Metadata keys missing from the file fall back to defaults; a missing
        'name' keeps the current name, which is derived from the file name
        when none was set.
        """
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        # NOTE(review): torch.load unpickles the file; only load hypernetworks
        # from sources you trust.
        state_dict = torch.load(filename, map_location='cpu')

        # NOTE(review): sizes absent from the file keep the freshly initialised
        # pairs created in __init__ - confirm this is intended for files that
        # only contain some sizes.
        for size, sd in state_dict.items():
            # Integer keys are layer entries; string keys are metadata.
            if isinstance(size, int):
                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||||
|
|
||||||
|
|
||||||
|
def load_hypernetworks(path):
    """Load every ``.pt`` file found under *path*, searching subdirectories too.

    Returns a dict mapping each hypernetwork's ``name`` to the loaded object.
    A file that fails to load is reported on stderr, with its traceback, and
    skipped; the remaining files are still loaded.
    """
    res = {}

    # Build the pattern with os.path.join. Plain concatenation produced
    # "<path>**/*.pt", where "**" is not a whole path component: glob then
    # does not recurse and also matches sibling directories sharing the prefix.
    pattern = os.path.join(path, '**', '*.pt')

    for filename in glob.iglob(pattern, recursive=True):
        try:
            hn = Hypernetwork()
            hn.load(filename)
            res[hn.name] = hn
        except Exception:
            # Best effort: one broken file must not prevent loading the others.
            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    return res
|
||||||
|
|
||||||
|
|
||||||
|
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Attention forward pass that routes the context through ``shared.hypernetwork``.

    The module pair is looked up in the active hypernetwork's ``layers`` by
    ``context.shape[2]``. When a pair exists, its two modules are stored on
    ``self`` as ``hypernetwork_k`` / ``hypernetwork_v`` and transform the
    context fed to ``to_k`` / ``to_v``; otherwise the context is used unchanged.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)  # presumably falls back to x when context is None - confirm in ldm.util

    active = shared.hypernetwork
    layers_by_width = {} if active is None else active.layers
    pair = layers_by_width.get(context.shape[2], None)

    if pair is None:
        key_input = context
        value_input = context
    else:
        module_k, module_v = pair

        # Keep references to the modules in use on this attention object.
        self.hypernetwork_k = module_k
        self.hypernetwork_v = module_v

        key_input = module_k(context)
        value_input = module_v(context)

    key = self.to_k(key_input)
    value = self.to_v(value_input)

    def split_heads(t):
        # Fold the head axis into the batch axis.
        return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

    query, key, value = split_heads(query), split_heads(key), split_heads(value)

    scores = einsum('b i d, b j d -> b i j', query, key) * self.scale

    if mask is not None:
        flat_mask = rearrange(mask, 'b ... -> b (...)')
        most_negative = -torch.finfo(scores.dtype).max
        per_head_mask = repeat(flat_mask, 'b j -> (b h) () j', h=heads)
        # Masked-out positions get the most negative finite value of the dtype.
        scores.masked_fill_(~per_head_mask, most_negative)

    # attention, what we cannot get enough of
    weights = scores.softmax(dim=-1)

    mixed = einsum('b i j, b j d -> b i d', weights, value)
    merged = rearrange(mixed, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)
|
||||||
|
|
||||||
|
|
||||||
|
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
    """Train the named hypernetwork on the images in *data_root*.

    The hypernetwork is taken from ``shared.hypernetworks`` and made the active
    ``shared.hypernetwork``. Progress is reported through ``shared.state``.
    Periodic copies and preview images go to
    ``<log_directory>/<YYYY-MM-DD>/<hypernetwork_name>/`` ("hypernetworks" and
    "images" subdirectories); the final result is saved to
    ``<shared.cmd_opts.hypernetwork_dir>/<hypernetwork_name>.pt``.

    Args:
        hypernetwork_name: key into ``shared.hypernetworks``; must be non-empty.
        learn_rate: learning rate passed to AdamW.
        data_root: dataset directory.
        log_directory: base directory for periodic outputs.
        steps: training stops once the step counter exceeds this value.
        create_image_every: write a preview every N steps; 0 disables.
        save_hypernetwork_every: write a copy every N steps; 0 disables.
        template_file: prompt template handed to the dataset.
        preview_image_prompt: preview prompt; "" reuses the last training text.

    Returns:
        Tuple ``(hypernetwork, filename)`` of the trained object and the final
        save path.

    Raises:
        AssertionError: if *hypernetwork_name* is empty.
        KeyError: if *hypernetwork_name* is not in ``shared.hypernetworks``.
    """
    assert hypernetwork_name, 'hypernetwork not selected'

    hypernetwork = shared.hypernetworks[hypernetwork_name]
    shared.hypernetwork = hypernetwork

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    cond_model = shared.sd_model.cond_stage_model

    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file)

    # weights() also switches every module to training mode.
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    optimizer = torch.optim.AdamW(weights, lr=learn_rate)

    # Ring buffer of recent losses; the displayed loss is its mean.
    losses = torch.zeros((32,))

    last_saved_file = "<none>"
    last_saved_image = "<none>"

    initial_step = hypernetwork.step or 0
    if initial_step > steps:
        return hypernetwork, filename

    pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
    for i, (x, text) in pbar:
        hypernetwork.step = i + initial_step

        # NOTE(review): step == steps is still trained because the loop only
        # stops once the counter exceeds `steps` - confirm this is intended.
        if hypernetwork.step > steps:
            break

        if shared.state.interrupted:
            break

        with torch.autocast("cuda"):
            c = cond_model([text])

            x = x.to(devices.device)
            loss = shared.sd_model(x.unsqueeze(0), c)[0]
            del x

            losses[hypernetwork.step % losses.shape[0]] = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f"loss: {losses.mean():.7f}")

        # hypernetwork_dir is only set when save_hypernetwork_every > 0, so the modulo is safe.
        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
            hypernetwork.save(last_saved_file)

        # images_dir is only set when create_image_every > 0, so the modulo is safe.
        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')

            preview_text = text if preview_image_prompt == "" else preview_image_prompt

            p = processing.StableDiffusionProcessingTxt2Img(
                sd_model=shared.sd_model,
                prompt=preview_text,
                steps=20,
                do_not_save_grid=True,
                do_not_save_samples=True,
            )

            processed = processing.process_images(p)
            image = processed.images[0]

            shared.state.current_image = image
            image.save(last_saved_image)

            last_saved_image += f", prompt: {preview_text}"

        shared.state.job_no = hypernetwork.step

        shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""

    checkpoint = sd_models.select_checkpoint()

    # Record the checkpoint reported by sd_models.select_checkpoint() in the saved file.
    hypernetwork.sd_checkpoint = checkpoint.hash
    hypernetwork.sd_checkpoint_name = checkpoint.model_name
    hypernetwork.save(filename)

    return hypernetwork, filename
|
||||||
|
|
||||||
|
|
43
modules/hypernetwork/ui.py
Normal file
43
modules/hypernetwork/ui.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
import html
|
||||||
|
import os
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
import modules.textual_inversion.textual_inversion
|
||||||
|
import modules.textual_inversion.preprocess
|
||||||
|
from modules import sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
|
def create_hypernetwork(name):
    """Create and save a new hypernetwork called *name*, then refresh the registry.

    Returns the three UI outputs: a dropdown update listing all hypernetwork
    names sorted, a "Created: ..." message, and an empty string.

    Raises:
        AssertionError: if *name* is empty or the target file already exists.
    """
    # An empty name would produce a file called ".pt".
    assert name, "hypernetwork name is empty"

    # Imported explicitly: this file's top-level imports do not include the
    # module, so using it relied on some other module having imported it first.
    import modules.hypernetwork.hypernetwork

    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
    assert not os.path.exists(fn), f"file {fn} already exists"

    hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name)
    hypernetwork.save(fn)

    shared.reload_hypernetworks()
    # Reloading replaces the registry's objects, so re-resolve the active one.
    shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None)

    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {fn}", ""
|
||||||
|
|
||||||
|
|
||||||
|
def train_hypernetwork(*args):
    """Run hypernetwork training from the UI and return ``(message, "")``.

    All positional arguments are forwarded unchanged to
    ``modules.hypernetwork.hypernetwork.train_hypernetwork``. Optimizations are
    undone via ``sd_hijack.undo_optimizations()`` before training. Afterwards,
    on success or failure alike, the previously active ``shared.hypernetwork``
    is restored and ``sd_hijack.apply_optimizations()`` is called. Exceptions
    raised by training propagate to the caller.
    """
    # Imported explicitly: this file's top-level imports do not include the
    # module, so using it relied on some other module having imported it first.
    import modules.hypernetwork.hypernetwork

    initial_hypernetwork = shared.hypernetwork

    try:
        sd_hijack.undo_optimizations()

        hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args)

        res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
Hypernetwork saved to {html.escape(filename)}
"""
        return res, ""
    finally:
        # Training replaces shared.hypernetwork; put the user's selection back.
        shared.hypernetwork = initial_hypernetwork
        sd_hijack.apply_optimizations()
|
||||||
|
|
|
@ -8,7 +8,7 @@ from torch import einsum
|
||||||
from torch.nn.functional import silu
|
from torch.nn.functional import silu
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
|
@ -32,6 +32,8 @@ def apply_optimizations():
|
||||||
|
|
||||||
|
|
||||||
def undo_optimizations():
|
def undo_optimizations():
|
||||||
|
from modules.hypernetwork import hypernetwork
|
||||||
|
|
||||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
|
@ -45,8 +45,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
hypernetwork = shared.selected_hypernetwork()
|
hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None)
|
||||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
|
||||||
|
|
||||||
if hypernetwork_layers is not None:
|
if hypernetwork_layers is not None:
|
||||||
k_in = self.to_k(hypernetwork_layers[0](context))
|
k_in = self.to_k(hypernetwork_layers[0](context))
|
||||||
|
|
|
@ -13,7 +13,7 @@ import modules.memmon
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import sd_samplers, hypernetwork
|
from modules import sd_samplers
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
|
@ -28,6 +28,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||||
|
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||||
|
@ -76,11 +77,15 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||||
|
|
||||||
config_filename = cmd_opts.ui_settings_file
|
config_filename = cmd_opts.ui_settings_file
|
||||||
|
|
||||||
hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
|
|
||||||
|
def reload_hypernetworks():
    """Empty the module-level ``hypernetworks`` dict and refill it from ``cmd_opts.hypernetwork_dir``."""
    from modules.hypernetwork import hypernetwork

    # The dict is updated in place, so existing references to it stay valid.
    hypernetworks.clear()
    found = hypernetwork.load_hypernetworks(cmd_opts.hypernetwork_dir)
    hypernetworks.update(found)
|
||||||
|
|
||||||
|
|
||||||
def selected_hypernetwork():
|
hypernetworks = {}
|
||||||
return hypernetworks.get(opts.sd_hypernetwork, None)
|
hypernetwork = None
|
||||||
|
|
||||||
|
|
||||||
class State:
|
class State:
|
||||||
|
|
|
@ -22,7 +22,6 @@ def preprocess(*args):
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(*args):
|
def train_embedding(*args):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sd_hijack.undo_optimizations()
|
sd_hijack.undo_optimizations()
|
||||||
|
|
||||||
|
|
|
@ -37,6 +37,7 @@ import modules.generation_parameters_copypaste
|
||||||
from modules import prompt_parser
|
from modules import prompt_parser
|
||||||
from modules.images import save_image
|
from modules.images import save_image
|
||||||
import modules.textual_inversion.ui
|
import modules.textual_inversion.ui
|
||||||
|
import modules.hypernetwork.ui
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
|
@ -965,6 +966,18 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
create_embedding = gr.Button(value="Create", variant='primary')
|
create_embedding = gr.Button(value="Create", variant='primary')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
|
||||||
|
|
||||||
|
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
create_hypernetwork = gr.Button(value="Create", variant='primary')
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
|
||||||
|
|
||||||
|
@ -986,6 +999,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
||||||
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
|
@ -993,15 +1007,12 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=2):
|
interrupt_training = gr.Button(value="Interrupt")
|
||||||
gr.HTML(value="")
|
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
|
||||||
|
train_embedding = gr.Button(value="Train Embedding", variant='primary')
|
||||||
with gr.Column():
|
|
||||||
with gr.Row():
|
|
||||||
interrupt_training = gr.Button(value="Interrupt")
|
|
||||||
train_embedding = gr.Button(value="Train", variant='primary')
|
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
progressbar = gr.HTML(elem_id="ti_progressbar")
|
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||||
|
@ -1027,6 +1038,18 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
create_hypernetwork.click(
|
||||||
|
fn=modules.hypernetwork.ui.create_hypernetwork,
|
||||||
|
inputs=[
|
||||||
|
new_hypernetwork_name,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
train_hypernetwork_name,
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
run_preprocess.click(
|
run_preprocess.click(
|
||||||
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
|
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
|
||||||
_js="start_training_textual_inversion",
|
_js="start_training_textual_inversion",
|
||||||
|
@ -1062,12 +1085,33 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
train_hypernetwork.click(
|
||||||
|
fn=wrap_gradio_gpu_call(modules.hypernetwork.ui.train_hypernetwork, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
train_hypernetwork_name,
|
||||||
|
learn_rate,
|
||||||
|
dataset_directory,
|
||||||
|
log_directory,
|
||||||
|
steps,
|
||||||
|
create_image_every,
|
||||||
|
save_embedding_every,
|
||||||
|
template_file,
|
||||||
|
preview_image_prompt,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
interrupt_training.click(
|
interrupt_training.click(
|
||||||
fn=lambda: shared.state.interrupt(),
|
fn=lambda: shared.state.interrupt(),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_setting_component(key):
|
def create_setting_component(key):
|
||||||
def fun():
|
def fun():
|
||||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||||
|
|
|
@ -78,8 +78,7 @@ def apply_checkpoint(p, x, xs):
|
||||||
|
|
||||||
|
|
||||||
def apply_hypernetwork(p, x, xs):
|
def apply_hypernetwork(p, x, xs):
|
||||||
hn = shared.hypernetworks.get(x, None)
|
shared.hypernetwork = shared.hypernetworks.get(x, None)
|
||||||
opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None'
|
|
||||||
|
|
||||||
|
|
||||||
def format_value_add_label(p, opt, x):
|
def format_value_add_label(p, opt, x):
|
||||||
|
@ -199,7 +198,7 @@ class Script(scripts.Script):
|
||||||
modules.processing.fix_seed(p)
|
modules.processing.fix_seed(p)
|
||||||
p.batch_size = 1
|
p.batch_size = 1
|
||||||
|
|
||||||
initial_hn = opts.sd_hypernetwork
|
initial_hn = shared.hypernetwork
|
||||||
|
|
||||||
def process_axis(opt, vals):
|
def process_axis(opt, vals):
|
||||||
if opt.label == 'Nothing':
|
if opt.label == 'Nothing':
|
||||||
|
@ -308,6 +307,6 @@ class Script(scripts.Script):
|
||||||
# restore checkpoint in case it was changed by axes
|
# restore checkpoint in case it was changed by axes
|
||||||
modules.sd_models.reload_model_weights(shared.sd_model)
|
modules.sd_models.reload_model_weights(shared.sd_model)
|
||||||
|
|
||||||
opts.data["sd_hypernetwork"] = initial_hn
|
shared.hypernetwork = initial_hn
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
27
textual_inversion_templates/hypernetwork.txt
Normal file
27
textual_inversion_templates/hypernetwork.txt
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
a photo of a [filewords]
|
||||||
|
a rendering of a [filewords]
|
||||||
|
a cropped photo of the [filewords]
|
||||||
|
the photo of a [filewords]
|
||||||
|
a photo of a clean [filewords]
|
||||||
|
a photo of a dirty [filewords]
|
||||||
|
a dark photo of the [filewords]
|
||||||
|
a photo of my [filewords]
|
||||||
|
a photo of the cool [filewords]
|
||||||
|
a close-up photo of a [filewords]
|
||||||
|
a bright photo of the [filewords]
|
||||||
|
a cropped photo of a [filewords]
|
||||||
|
a photo of the [filewords]
|
||||||
|
a good photo of the [filewords]
|
||||||
|
a photo of one [filewords]
|
||||||
|
a close-up photo of the [filewords]
|
||||||
|
a rendition of the [filewords]
|
||||||
|
a photo of the clean [filewords]
|
||||||
|
a rendition of a [filewords]
|
||||||
|
a photo of a nice [filewords]
|
||||||
|
a good photo of a [filewords]
|
||||||
|
a photo of the nice [filewords]
|
||||||
|
a photo of the small [filewords]
|
||||||
|
a photo of the weird [filewords]
|
||||||
|
a photo of the large [filewords]
|
||||||
|
a photo of a cool [filewords]
|
||||||
|
a photo of a small [filewords]
|
1
textual_inversion_templates/none.txt
Normal file
1
textual_inversion_templates/none.txt
Normal file
|
@ -0,0 +1 @@
|
||||||
|
picture
|
9
webui.py
9
webui.py
|
@ -74,6 +74,15 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||||
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def set_hypernetwork():
    """Set ``shared.hypernetwork`` to the entry named by the ``sd_hypernetwork`` option, or None if absent."""
    selected_name = shared.opts.sd_hypernetwork
    shared.hypernetwork = shared.hypernetworks.get(selected_name, None)
|
||||||
|
|
||||||
|
|
||||||
|
shared.reload_hypernetworks()
|
||||||
|
shared.opts.onchange("sd_hypernetwork", set_hypernetwork)
|
||||||
|
set_hypernetwork()
|
||||||
|
|
||||||
|
|
||||||
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
||||||
|
|
||||||
shared.sd_model = modules.sd_models.load_model()
|
shared.sd_model = modules.sd_models.load_model()
|
||||||
|
|
Loading…
Reference in a new issue