initial support for training textual inversion

AUTOMATIC 2022-10-02 15:03:39 +03:00
parent 84e97a98c5
commit 820f1dc96b
19 changed files with 828 additions and 315 deletions

.gitignore vendored
View file

@ -25,3 +25,4 @@ __pycache__
/.idea
notification.mp3
/SwinIR
+/textual_inversion

View file

@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
+check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){

View file

@ -0,0 +1,8 @@
function start_training_textual_inversion(){
requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''
return args_to_array(arguments)
}

View file

@ -32,10 +32,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
device = get_optimal_device()
device_codeformer = cpu if has_mps else device
+dtype = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
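
The new devices.dtype global (set from the command line in the sd_models.py hunk further below) gives other modules one place to read the working precision from. A minimal sketch of how such a global can be consumed; the make_latent helper below is illustrative only and not part of this commit:

import torch
from modules import devices  # devices.device and devices.dtype are module-level globals

def make_latent(batch_size, channels=4, height=64, width=64):
    # Hypothetical helper: allocate a latent tensor with the globally configured device and precision.
    return torch.zeros((batch_size, channels, height, width), device=devices.device, dtype=devices.dtype)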

View file

@ -56,7 +56,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
-self.styles: str = styles
+self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
-"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
@ -295,7 +295,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
fix_seed(p)
+if p.outpath_samples is not None:
os.makedirs(p.outpath_samples, exist_ok=True)
+if p.outpath_grids is not None:
os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
-model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []

View file

@ -6,244 +6,41 @@ import torch
import numpy as np
from torch import einsum
-from modules import prompt_parser
+import modules.textual_inversion.textual_inversion
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
-from ldm.util import default
-from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
+diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
-# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
+def apply_optimizations():
+if cmd_opts.opt_split_attention_v1:
+ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
+ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
+ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
+ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
+def undo_optimizations():
+ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-# taken from https://github.com/Doggettx/stable-diffusion
-def split_cross_attention_forward(self, x, context=None, mask=None):
-h = self.heads
-q_in = self.to_q(x)
-context = default(context, x)
-k_in = self.to_k(context) * self.scale
-v_in = self.to_v(context)
-del context, x
-q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
class StableDiffusionModelHijack:
-ids_lookup = {}
-word_embeddings = {}
-word_embeddings_checksums = {}
fixes = None
comments = []
-dir_mtime = None
layers = None
circular_enabled = False
clip = None
+embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
-def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
tokenizer = model.cond_stage_model.tokenizer
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
self.word_embeddings[name] = emb.detach().to(device)
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dirname):
try:
fullfn = os.path.join(dirname, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@ -253,12 +50,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model
-if cmd_opts.opt_split_attention_v1:
-ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
-elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
-ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
-ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
-ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+apply_optimizations()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
-self.hijack = hijack
+self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
-possible_matches = self.hijack.ids_lookup.get(token, None)
-if possible_matches is None:
-remade_tokens.append(token)
-multipliers.append(weight)
-else:
-found = False
-for ids, word in possible_matches:
-if tokens[i:i + len(ids)] == ids:
-emb_len = int(self.hijack.word_embeddings[word].shape[0])
-fixes.append((len(remade_tokens), word))
-remade_tokens += [0] * emb_len
-multipliers += [weight] * emb_len
-i += len(ids) - 1
-found = True
-used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
-break
-if not found:
-remade_tokens.append(token)
-multipliers.append(weight)
-i += 1
+embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+if embedding is None:
+remade_tokens.append(token)
+multipliers.append(weight)
+i += 1
+else:
+emb_len = int(embedding.vec.shape[0])
+fixes.append((len(remade_tokens), embedding))
+remade_tokens += [0] * emb_len
+multipliers += [weight] * emb_len
+used_custom_terms.append((embedding.name, embedding.checksum()))
+i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
-possible_matches = self.hijack.ids_lookup.get(token, None)
-mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
-if mult_change is not None:
-mult *= mult_change
-elif possible_matches is None:
-remade_tokens.append(token)
-multipliers.append(mult)
-else:
-found = False
-for ids, word in possible_matches:
-if tokens[i:i+len(ids)] == ids:
-emb_len = int(self.hijack.word_embeddings[word].shape[0])
-fixes.append((len(remade_tokens), word))
-remade_tokens += [0] * emb_len
-multipliers += [mult] * emb_len
-i += len(ids) - 1
-found = True
-used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
-break
-if not found:
-remade_tokens.append(token)
-multipliers.append(mult)
-i += 1
+embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
+if mult_change is not None:
+mult *= mult_change
+i += 1
+elif embedding is None:
+remade_tokens.append(token)
+multipliers.append(mult)
+i += 1
+else:
+emb_len = int(embedding.vec.shape[0])
+fixes.append((len(remade_tokens), embedding))
+remade_tokens += [0] * emb_len
+multipliers += [mult] * emb_len
+used_custom_terms.append((embedding.name, embedding.checksum()))
+i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
@ -517,15 +290,20 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids)
-if batch_fixes is not None:
-for fixes, tensor in zip(batch_fixes, inputs_embeds):
-for offset, word in fixes:
-emb = self.embeddings.word_embeddings[word]
-emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
-tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
-return inputs_embeds
+if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+return inputs_embeds
+vecs = []
+for fixes, tensor in zip(batch_fixes, inputs_embeds):
+for offset, embedding in fixes:
+emb = embedding.vec
+emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
+tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+vecs.append(tensor)
+return torch.stack(vecs)
def add_circular_option_to_conv_2d():
conv2d_constructor = torch.nn.Conv2d.__init__
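
The rewritten EmbeddingsWithFixes.forward above splices learned vectors into the token embedding stream: for every (offset, embedding) fix it keeps the embeddings up to the offset, inserts the embedding's vectors, and drops the placeholder positions that tokenize_line reserved for them. A self-contained sketch of that splice, with made-up names and shapes ("my-style", 768-dim vectors) rather than anything taken from the repository:

import torch

def apply_fixes(inputs_embeds, batch_fixes, vectors_by_name):
    # inputs_embeds: (batch, tokens, dim); batch_fixes: per-sample lists of (offset, name) pairs.
    vecs = []
    for fixes, tensor in zip(batch_fixes, inputs_embeds):
        for offset, name in fixes:
            emb = vectors_by_name[name]  # (num_vectors, dim) learned embedding
            emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
            # Replace the emb_len positions after `offset` with the learned vectors.
            tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
        vecs.append(tensor)
    return torch.stack(vecs)

# Toy check: a 2-vector embedding spliced into positions 3-4 of a 10-token sequence.
x = torch.zeros(1, 10, 768)
out = apply_fixes(x, [[(2, "my-style")]], {"my-style": torch.ones(2, 768)})
print(out.shape, out[0, 3, 0].item())  # torch.Size([1, 10, 768]) 1.0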

View file

@ -0,0 +1,164 @@
import math
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3

View file

@ -8,7 +8,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
-from modules import shared, modelloader
+from modules import shared, modelloader, devices
from modules.paths import models_path
model_dir = "Stable-diffusion"
@ -134,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
if not shared.cmd_opts.no_half:
model.half()
+devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file

View file

@ -78,6 +78,7 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
+textinfo = None
def interrupt(self):
self.interrupted = True
@ -88,7 +89,7 @@
self.current_image_sampling_step = 0
def get_job_timestamp(self):
-return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+return datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # shouldn't this return job_timestamp?
state = State()

View file

@ -0,0 +1,76 @@
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
import tqdm
class PersonalizedBase(Dataset):
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
with open(template_file, "r") as file:
lines = [x.strip() for x in file.readlines()]
self.lines = lines
assert data_root, 'dataset directory not specified'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
image = Image.open(path)
image = image.convert('RGB')
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
filename_tokens = [token for token in filename_tokens if token.isalpha()]
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
self.dataset.append((init_latent, filename_tokens))
self.length = len(self.dataset) * repeats
self.initial_indexes = np.arange(self.length) % len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def __len__(self):
return self.length
def __getitem__(self, i):
if i % len(self.dataset) == 0:
self.shuffle()
index = self.indexes[i % len(self.indexes)]
x, filename_tokens = self.dataset[index]
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens))
return x, text
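
__getitem__ above pairs each cached latent with a template line in which [name] is replaced by the placeholder token and [filewords] by tokens derived from the image filename. A standalone sketch of just that substitution step; the example filename and template line are made up:

import os
import random

def build_prompt(template_lines, placeholder_token, image_path):
    # Derive the "filewords" from the filename the same way the dataset does: split on '-', '_' and spaces.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    filename_tokens = [t for t in stem.replace('_', '-').replace(' ', '-').split('-') if t.isalpha()]
    text = random.choice(template_lines)
    text = text.replace("[name]", placeholder_token)
    text = text.replace("[filewords]", ' '.join(filename_tokens))
    return text

print(build_prompt(["a photo of a [name], [filewords]"], "my-token", "dog_park_001.png"))
# -> "a photo of a my-token, dog park"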

View file

@ -0,0 +1,258 @@
import os
import sys
import traceback
import torch
import tqdm
import html
import datetime
from modules import shared, devices, sd_hijack, processing
import modules.textual_inversion.dataset
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
self.cached_checksum = None
def save(self, filename):
embedding_data = {
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
}
torch.save(embedding_data, filename)
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, embedding))
return embedding
def load_textual_inversion_embeddings(self):
mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
try:
fullfn = os.path.join(self.embeddings_dir, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
return None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
return embedding
return None
def create_embedding(name, num_vectors_per_token):
init_text = '*'
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
embedding.save(fn)
return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
os.makedirs(embedding_dir, exist_ok=True)
else:
embedding_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
if embedding.step > steps:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = cond_model([text])
loss = shared.sd_model(x.unsqueeze(0), c)[0]
losses[embedding.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=text,
steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
)
processed = processing.process_images(p)
image = processed.images[0]
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
embedding.cached_checksum = None
embedding.save(filename)
return embedding, filename
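
Taken together, the intended flow of this file is: create a small embedding file in the embeddings directory, reload the embedding database so the prompt hijack can resolve the new token, then run train_embedding, whose optimizer updates only embedding.vec while the Stable Diffusion weights are left untouched. A condensed sketch of that sequence using the functions defined above; the name, paths, step counts and learning rate are placeholder values, and a model must already be loaded into shared.sd_model:

import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack

# 1. Create a new 2-vector embedding named "my-style" in the embeddings directory.
filename = ti.create_embedding("my-style", num_vectors_per_token=2)

# 2. Reload the database so the new file is registered with the prompt hijack.
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()

# 3. Train it: only embedding.vec receives optimizer updates.
embedding, filename = ti.train_embedding(
    embedding_name="my-style",
    learn_rate=5e-3,
    data_root="/path/to/512x512/images",
    log_directory="textual_inversion",
    steps=5000,
    create_image_every=500,
    save_embedding_every=500,
    template_file="textual_inversion_templates/style_filewords.txt",
)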

View file

@ -0,0 +1,32 @@
import html
import gradio as gr
import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack, shared
def create_embedding(name, nvpt):
filename = ti.create_embedding(name, nvpt)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
def train_embedding(*args):
try:
sd_hijack.undo_optimizations()
embedding, filename = ti.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps.
Embedding saved to {html.escape(filename)}
"""
return res, ""
except Exception:
raise
finally:
sd_hijack.apply_optimizations()

View file

@ -21,6 +21,7 @@ import gradio as gr
import gradio.utils
import gradio.routes
+from modules import sd_hijack
from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
@ -32,6 +33,7 @@ import modules.gfpgan_model
import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
@ -142,8 +144,8 @@ def save_files(js_data, images, index):
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-def wrap_gradio_call(func):
+def wrap_gradio_call(func, extra_outputs=None):
-def f(*args, **kwargs):
+def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon:
shared.mem_mon.monitor()
@ -159,7 +161,10 @@ def wrap_gradio_call(func):
shared.state.job = ""
shared.state.job_count = 0
-res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
+if extra_outputs_array is None:
+extra_outputs_array = [None, '']
+res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
@ -179,6 +184,7 @@ def wrap_gradio_call(func):
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
shared.state.interrupted = False
+shared.state.job_count = 0
return tuple(res)
@ -187,7 +193,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part):
if shared.state.job_count == 0:
-return "", gr_show(False), gr_show(False)
+return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0
@ -219,13 +225,19 @@ def check_progress_call(id_part):
else:
preview_visibility = gr_show(True)
+if shared.state.textinfo is not None:
+textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
+else:
+textinfo_result = gr_show(False)
-return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
+return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part):
shared.state.job_count = -1
shared.state.current_latent = None
shared.state.current_image = None
+shared.state.textinfo = None
return check_progress_call(id_part)
@ -399,13 +411,16 @@ def create_toprow(is_img2img):
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
-def setup_progressbar(progressbar, preview, id_part):
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+if textinfo is None:
+textinfo = gr.HTML(visible=False)
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click(
fn=lambda: check_progress_call(id_part),
show_progress=False,
inputs=[],
-outputs=[progressbar, preview, preview],
+outputs=[progressbar, preview, preview, textinfo],
)
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part),
show_progress=False,
inputs=[],
-outputs=[progressbar, preview, preview],
+outputs=[progressbar, preview, preview, textinfo],
)
-def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
+def create_ui(wrap_gradio_gpu_call):
+import modules.img2img
+import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
-fn=txt2img,
+fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
txt2img_prompt,
@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
)
img2img_args = dict(
-fn=img2img,
+fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img",
inputs=[
dummy_component,
@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click(
-fn=run_extras,
+fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
inputs=[
dummy_component,
@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change(
-fn=wrap_gradio_call(run_pnginfo),
+fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
@ -900,6 +918,92 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks() as textual_inversion_interface:
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
new_embedding_name = gr.Textbox(label="Name")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_embedding = gr.Button(value="Create", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
with gr.Row():
with gr.Column(scale=2):
gr.HTML(value="")
with gr.Column():
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_embedding = gr.Button(value="Train", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
nvpt,
],
outputs=[
train_embedding_name,
ti_output,
ti_outcome,
]
)
train_embedding.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
learn_rate,
dataset_directory,
log_directory,
steps,
create_image_every,
save_embedding_every,
template_file,
],
outputs=[
ti_output,
ti_outcome,
]
)
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+(textual_inversion_interface, "Textual inversion", "ti"),
(settings_interface, "Settings", "settings"),
]
@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def modelmerger(*args):
try:
-results = run_modelmerger(*args)
+results = modules.extras.run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
-modules.sd_models.list_models() #To remove the potentially missing models from the list
+modules.sd_models.list_models()  # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results

View file

@ -157,7 +157,7 @@ button{
max-width: 10em;
}
-#txt2img_preview, #img2img_preview{
+#txt2img_preview, #img2img_preview, #ti_preview{
position: absolute;
width: 320px;
left: 0;
@ -172,18 +172,18 @@ button{
}
@media screen and (min-width: 768px) {
-#txt2img_preview, #img2img_preview {
+#txt2img_preview, #img2img_preview, #ti_preview {
position: absolute;
}
}
@media screen and (max-width: 767px) {
-#txt2img_preview, #img2img_preview {
+#txt2img_preview, #img2img_preview, #ti_preview {
position: relative;
}
}
-#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
+#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
display: none;
}
@ -247,7 +247,7 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{
}
-#txt2img_progressbar, #img2img_progressbar{
+#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute;
z-index: 1000;
right: 0;

View file

@ -0,0 +1,19 @@
a painting, art by [name]
a rendering, art by [name]
a cropped painting, art by [name]
the painting, art by [name]
a clean painting, art by [name]
a dirty painting, art by [name]
a dark painting, art by [name]
a picture, art by [name]
a cool painting, art by [name]
a close-up painting, art by [name]
a bright painting, art by [name]
a cropped painting, art by [name]
a good painting, art by [name]
a close-up painting, art by [name]
a rendition, art by [name]
a nice painting, art by [name]
a small painting, art by [name]
a weird painting, art by [name]
a large painting, art by [name]

View file

@ -0,0 +1,19 @@
a painting of [filewords], art by [name]
a rendering of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
the painting of [filewords], art by [name]
a clean painting of [filewords], art by [name]
a dirty painting of [filewords], art by [name]
a dark painting of [filewords], art by [name]
a picture of [filewords], art by [name]
a cool painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a bright painting of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
a good painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a rendition of [filewords], art by [name]
a nice painting of [filewords], art by [name]
a small painting of [filewords], art by [name]
a weird painting of [filewords], art by [name]
a large painting of [filewords], art by [name]

View file

@ -0,0 +1,27 @@
a photo of a [name]
a rendering of a [name]
a cropped photo of the [name]
the photo of a [name]
a photo of a clean [name]
a photo of a dirty [name]
a dark photo of the [name]
a photo of my [name]
a photo of the cool [name]
a close-up photo of a [name]
a bright photo of the [name]
a cropped photo of a [name]
a photo of the [name]
a good photo of the [name]
a photo of one [name]
a close-up photo of the [name]
a rendition of the [name]
a photo of the clean [name]
a rendition of a [name]
a photo of a nice [name]
a good photo of a [name]
a photo of the nice [name]
a photo of the small [name]
a photo of the weird [name]
a photo of the large [name]
a photo of a cool [name]
a photo of a small [name]

View file

@ -0,0 +1,27 @@
a photo of a [name], [filewords]
a rendering of a [name], [filewords]
a cropped photo of the [name], [filewords]
the photo of a [name], [filewords]
a photo of a clean [name], [filewords]
a photo of a dirty [name], [filewords]
a dark photo of the [name], [filewords]
a photo of my [name], [filewords]
a photo of the cool [name], [filewords]
a close-up photo of a [name], [filewords]
a bright photo of the [name], [filewords]
a cropped photo of a [name], [filewords]
a photo of the [name], [filewords]
a good photo of the [name], [filewords]
a photo of one [name], [filewords]
a close-up photo of the [name], [filewords]
a rendition of the [name], [filewords]
a photo of the clean [name], [filewords]
a rendition of a [name], [filewords]
a photo of a nice [name], [filewords]
a good photo of a [name], [filewords]
a photo of the nice [name], [filewords]
a photo of the small [name], [filewords]
a photo of the weird [name], [filewords]
a photo of the large [name], [filewords]
a photo of a cool [name], [filewords]
a photo of a small [name], [filewords]

View file

@ -12,7 +12,6 @@ import modules.bsrgan_model as bsrgan
import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
-import modules.img2img
import modules.ldsr_model as ldsr
import modules.lowvram
import modules.realesrgan_model as realesrgan
@ -21,7 +20,6 @@ import modules.sd_hijack
import modules.sd_models
import modules.shared as shared
import modules.swinir_model as swinir
-import modules.txt2img
import modules.ui
from modules import modelloader
from modules.paths import script_path
@ -46,7 +44,7 @@ def wrap_queued_call(func):
return f
-def wrap_gradio_gpu_call(func):
+def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
devices.torch_gc()
@ -58,6 +56,7 @@ def wrap_gradio_gpu_call(func):
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.interrupted = False
+shared.state.textinfo = None
with queue_lock:
res = func(*args, **kwargs)
@ -69,7 +68,7 @@ def wrap_gradio_gpu_call(func):
return res
-return modules.ui.wrap_gradio_call(f)
+return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
@ -86,13 +85,7 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler)
-demo = modules.ui.create_ui(
-txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
-img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
-run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
-run_pnginfo=modules.extras.run_pnginfo,
-run_modelmerger=modules.extras.run_modelmerger
-)
+demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
demo.launch(
share=cmd_opts.share,