initial support for training textual inversion
This commit is contained in:
parent
84e97a98c5
commit
820f1dc96b
19 changed files with 828 additions and 315 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -25,3 +25,4 @@ __pycache__
|
||||||
/.idea
|
/.idea
|
||||||
notification.mp3
|
notification.mp3
|
||||||
/SwinIR
|
/SwinIR
|
||||||
|
/textual_inversion
|
||||||
|
|
|
@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||||
|
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
||||||
})
|
})
|
||||||
|
|
||||||
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
|
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
|
||||||
|
|
8
javascript/textualInversion.js
Normal file
8
javascript/textualInversion.js
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
|
||||||
|
|
||||||
|
function start_training_textual_inversion(){
|
||||||
|
requestProgress('ti')
|
||||||
|
gradioApp().querySelector('#ti_error').innerHTML=''
|
||||||
|
|
||||||
|
return args_to_array(arguments)
|
||||||
|
}
|
|
@ -32,10 +32,9 @@ def enable_tf32():
|
||||||
|
|
||||||
errors.run(enable_tf32, "Enabling TF32")
|
errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
|
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
device_codeformer = cpu if has_mps else device
|
device_codeformer = cpu if has_mps else device
|
||||||
|
dtype = torch.float16
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||||
|
|
|
@ -56,7 +56,7 @@ class StableDiffusionProcessing:
|
||||||
self.prompt: str = prompt
|
self.prompt: str = prompt
|
||||||
self.prompt_for_display: str = None
|
self.prompt_for_display: str = None
|
||||||
self.negative_prompt: str = (negative_prompt or "")
|
self.negative_prompt: str = (negative_prompt or "")
|
||||||
self.styles: str = styles
|
self.styles: list = styles or []
|
||||||
self.seed: int = seed
|
self.seed: int = seed
|
||||||
self.subseed: int = subseed
|
self.subseed: int = subseed
|
||||||
self.subseed_strength: float = subseed_strength
|
self.subseed_strength: float = subseed_strength
|
||||||
|
@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
|
||||||
fix_seed(p)
|
fix_seed(p)
|
||||||
|
|
||||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
if p.outpath_samples is not None:
|
||||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||||
|
|
||||||
|
if p.outpath_grids is not None:
|
||||||
|
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
|
|
||||||
|
@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
|
|
@ -6,244 +6,41 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
from modules import prompt_parser
|
import modules.textual_inversion.textual_inversion
|
||||||
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
|
|
||||||
from ldm.util import default
|
|
||||||
from einops import rearrange
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
|
|
||||||
|
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
h = self.heads
|
|
||||||
|
|
||||||
q = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k = self.to_k(context)
|
|
||||||
v = self.to_v(context)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
|
||||||
for i in range(0, q.shape[0], 2):
|
|
||||||
end = i + 2
|
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
|
||||||
s1 *= self.scale
|
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1)
|
|
||||||
del s1
|
|
||||||
|
|
||||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
|
||||||
del s2
|
|
||||||
|
|
||||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
del r1
|
|
||||||
|
|
||||||
return self.to_out(r2)
|
|
||||||
|
|
||||||
|
|
||||||
# taken from https://github.com/Doggettx/stable-diffusion
|
def apply_optimizations():
|
||||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
if cmd_opts.opt_split_attention_v1:
|
||||||
h = self.heads
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||||
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||||
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k_in = self.to_k(context) * self.scale
|
|
||||||
v_in = self.to_v(context)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
def undo_optimizations():
|
||||||
del q_in, k_in, v_in
|
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
gb = 1024 ** 3
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
|
||||||
mem_required = tensor_size * modifier
|
|
||||||
steps = 1
|
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
|
||||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
||||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
|
||||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
|
||||||
|
|
||||||
if steps > 64:
|
|
||||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
|
||||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
|
||||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
|
||||||
del s1
|
|
||||||
|
|
||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
|
||||||
del s2
|
|
||||||
|
|
||||||
del q, k, v
|
|
||||||
|
|
||||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
del r1
|
|
||||||
|
|
||||||
return self.to_out(r2)
|
|
||||||
|
|
||||||
def nonlinearity_hijack(x):
|
|
||||||
# swish
|
|
||||||
t = torch.sigmoid(x)
|
|
||||||
x *= t
|
|
||||||
del t
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def cross_attention_attnblock_forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q1 = self.q(h_)
|
|
||||||
k1 = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q1.shape
|
|
||||||
|
|
||||||
q2 = q1.reshape(b, c, h*w)
|
|
||||||
del q1
|
|
||||||
|
|
||||||
q = q2.permute(0, 2, 1) # b,hw,c
|
|
||||||
del q2
|
|
||||||
|
|
||||||
k = k1.reshape(b, c, h*w) # b,c,hw
|
|
||||||
del k1
|
|
||||||
|
|
||||||
h_ = torch.zeros_like(k, device=q.device)
|
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
|
||||||
mem_required = tensor_size * 2.5
|
|
||||||
steps = 1
|
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
|
||||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
|
|
||||||
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
|
||||||
w2 = w1 * (int(c)**(-0.5))
|
|
||||||
del w1
|
|
||||||
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
|
||||||
del w2
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v1 = v.reshape(b, c, h*w)
|
|
||||||
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
|
||||||
del w3
|
|
||||||
|
|
||||||
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
|
||||||
del v1, w4
|
|
||||||
|
|
||||||
h2 = h_.reshape(b, c, h, w)
|
|
||||||
del h_
|
|
||||||
|
|
||||||
h3 = self.proj_out(h2)
|
|
||||||
del h2
|
|
||||||
|
|
||||||
h3 += x
|
|
||||||
|
|
||||||
return h3
|
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
ids_lookup = {}
|
|
||||||
word_embeddings = {}
|
|
||||||
word_embeddings_checksums = {}
|
|
||||||
fixes = None
|
fixes = None
|
||||||
comments = []
|
comments = []
|
||||||
dir_mtime = None
|
|
||||||
layers = None
|
layers = None
|
||||||
circular_enabled = False
|
circular_enabled = False
|
||||||
clip = None
|
clip = None
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, dirname, model):
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||||
mt = os.path.getmtime(dirname)
|
|
||||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.dir_mtime = mt
|
|
||||||
self.ids_lookup.clear()
|
|
||||||
self.word_embeddings.clear()
|
|
||||||
|
|
||||||
tokenizer = model.cond_stage_model.tokenizer
|
|
||||||
|
|
||||||
def const_hash(a):
|
|
||||||
r = 0
|
|
||||||
for v in a:
|
|
||||||
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
|
||||||
return r
|
|
||||||
|
|
||||||
def process_file(path, filename):
|
|
||||||
name = os.path.splitext(filename)[0]
|
|
||||||
|
|
||||||
data = torch.load(path, map_location="cpu")
|
|
||||||
|
|
||||||
# textual inversion embeddings
|
|
||||||
if 'string_to_param' in data:
|
|
||||||
param_dict = data['string_to_param']
|
|
||||||
if hasattr(param_dict, '_parameters'):
|
|
||||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
|
||||||
emb = next(iter(param_dict.items()))[1]
|
|
||||||
# diffuser concepts
|
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
|
||||||
|
|
||||||
emb = next(iter(data.values()))
|
|
||||||
if len(emb.shape) == 1:
|
|
||||||
emb = emb.unsqueeze(0)
|
|
||||||
|
|
||||||
self.word_embeddings[name] = emb.detach().to(device)
|
|
||||||
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
|
|
||||||
|
|
||||||
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
|
||||||
|
|
||||||
first_id = ids[0]
|
|
||||||
if first_id not in self.ids_lookup:
|
|
||||||
self.ids_lookup[first_id] = []
|
|
||||||
self.ids_lookup[first_id].append((ids, name))
|
|
||||||
|
|
||||||
for fn in os.listdir(dirname):
|
|
||||||
try:
|
|
||||||
fullfn = os.path.join(dirname, fn)
|
|
||||||
|
|
||||||
if os.stat(fullfn).st_size == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
process_file(fullfn, fn)
|
|
||||||
except Exception:
|
|
||||||
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
|
@ -253,12 +50,7 @@ class StableDiffusionModelHijack:
|
||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
if cmd_opts.opt_split_attention_v1:
|
apply_optimizations()
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
|
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
|
||||||
|
|
||||||
def flatten(el):
|
def flatten(el):
|
||||||
flattened = [flatten(children) for children in el.children()]
|
flattened = [flatten(children) for children in el.children()]
|
||||||
|
@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.hijack = hijack
|
self.hijack: StableDiffusionModelHijack = hijack
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
self.max_length = wrapped.max_length
|
self.max_length = wrapped.max_length
|
||||||
self.token_mults = {}
|
self.token_mults = {}
|
||||||
|
@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
if mult != 1.0:
|
if mult != 1.0:
|
||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
|
|
||||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
|
@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
token = tokens[i]
|
token = tokens[i]
|
||||||
|
|
||||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
if possible_matches is None:
|
if embedding is None:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
multipliers.append(weight)
|
multipliers.append(weight)
|
||||||
|
i += 1
|
||||||
else:
|
else:
|
||||||
found = False
|
emb_len = int(embedding.vec.shape[0])
|
||||||
for ids, word in possible_matches:
|
fixes.append((len(remade_tokens), embedding))
|
||||||
if tokens[i:i + len(ids)] == ids:
|
remade_tokens += [0] * emb_len
|
||||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
multipliers += [weight] * emb_len
|
||||||
fixes.append((len(remade_tokens), word))
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
remade_tokens += [0] * emb_len
|
i += emb_len
|
||||||
multipliers += [weight] * emb_len
|
|
||||||
i += len(ids) - 1
|
|
||||||
found = True
|
|
||||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(weight)
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
if len(remade_tokens) > maxlen - 2:
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
token = tokens[i]
|
token = tokens[i]
|
||||||
|
|
||||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||||
if mult_change is not None:
|
if mult_change is not None:
|
||||||
mult *= mult_change
|
mult *= mult_change
|
||||||
elif possible_matches is None:
|
i += 1
|
||||||
|
elif embedding is None:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
multipliers.append(mult)
|
multipliers.append(mult)
|
||||||
|
i += 1
|
||||||
else:
|
else:
|
||||||
found = False
|
emb_len = int(embedding.vec.shape[0])
|
||||||
for ids, word in possible_matches:
|
fixes.append((len(remade_tokens), embedding))
|
||||||
if tokens[i:i+len(ids)] == ids:
|
remade_tokens += [0] * emb_len
|
||||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
multipliers += [mult] * emb_len
|
||||||
fixes.append((len(remade_tokens), word))
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
remade_tokens += [0] * emb_len
|
i += emb_len
|
||||||
multipliers += [mult] * emb_len
|
|
||||||
i += len(ids) - 1
|
|
||||||
found = True
|
|
||||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(mult)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
if len(remade_tokens) > maxlen - 2:
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
token_count = len(remade_tokens)
|
token_count = len(remade_tokens)
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||||
|
@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||||
|
|
||||||
|
|
||||||
self.hijack.fixes = hijack_fixes
|
self.hijack.fixes = hijack_fixes
|
||||||
self.hijack.comments = hijack_comments
|
self.hijack.comments = hijack_comments
|
||||||
|
|
||||||
|
@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
|
|
||||||
inputs_embeds = self.wrapped(input_ids)
|
inputs_embeds = self.wrapped(input_ids)
|
||||||
|
|
||||||
if batch_fixes is not None:
|
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
return inputs_embeds
|
||||||
for offset, word in fixes:
|
|
||||||
emb = self.embeddings.word_embeddings[word]
|
|
||||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
|
||||||
tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
|
|
||||||
|
|
||||||
return inputs_embeds
|
vecs = []
|
||||||
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
|
for offset, embedding in fixes:
|
||||||
|
emb = embedding.vec
|
||||||
|
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
||||||
|
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
||||||
|
|
||||||
|
vecs.append(tensor)
|
||||||
|
|
||||||
|
return torch.stack(vecs)
|
||||||
|
|
||||||
|
|
||||||
def add_circular_option_to_conv_2d():
|
def add_circular_option_to_conv_2d():
|
||||||
|
|
164
modules/sd_hijack_optimizations.py
Normal file
164
modules/sd_hijack_optimizations.py
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
from ldm.util import default
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||||
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
del context, x
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||||
|
|
||||||
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||||
|
for i in range(0, q.shape[0], 2):
|
||||||
|
end = i + 2
|
||||||
|
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||||
|
s1 *= self.scale
|
||||||
|
|
||||||
|
s2 = s1.softmax(dim=-1)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||||
|
del s2
|
||||||
|
|
||||||
|
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
del r1
|
||||||
|
|
||||||
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
|
||||||
|
# taken from https://github.com/Doggettx/stable-diffusion
|
||||||
|
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||||
|
h = self.heads
|
||||||
|
|
||||||
|
q_in = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
k_in = self.to_k(context) * self.scale
|
||||||
|
v_in = self.to_v(context)
|
||||||
|
del context, x
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
||||||
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
|
||||||
|
gb = 1024 ** 3
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||||
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
|
mem_required = tensor_size * modifier
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||||
|
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||||
|
|
||||||
|
if steps > 64:
|
||||||
|
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||||
|
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||||
|
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||||
|
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||||
|
|
||||||
|
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||||
|
del s1
|
||||||
|
|
||||||
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
|
del s2
|
||||||
|
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
|
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
del r1
|
||||||
|
|
||||||
|
return self.to_out(r2)
|
||||||
|
|
||||||
|
def nonlinearity_hijack(x):
|
||||||
|
# swish
|
||||||
|
t = torch.sigmoid(x)
|
||||||
|
x *= t
|
||||||
|
del t
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def cross_attention_attnblock_forward(self, x):
|
||||||
|
h_ = x
|
||||||
|
h_ = self.norm(h_)
|
||||||
|
q1 = self.q(h_)
|
||||||
|
k1 = self.k(h_)
|
||||||
|
v = self.v(h_)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
b, c, h, w = q1.shape
|
||||||
|
|
||||||
|
q2 = q1.reshape(b, c, h*w)
|
||||||
|
del q1
|
||||||
|
|
||||||
|
q = q2.permute(0, 2, 1) # b,hw,c
|
||||||
|
del q2
|
||||||
|
|
||||||
|
k = k1.reshape(b, c, h*w) # b,c,hw
|
||||||
|
del k1
|
||||||
|
|
||||||
|
h_ = torch.zeros_like(k, device=q.device)
|
||||||
|
|
||||||
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = mem_free_cuda + mem_free_torch
|
||||||
|
|
||||||
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||||
|
mem_required = tensor_size * 2.5
|
||||||
|
steps = 1
|
||||||
|
|
||||||
|
if mem_required > mem_free_total:
|
||||||
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
|
||||||
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
for i in range(0, q.shape[1], slice_size):
|
||||||
|
end = i + slice_size
|
||||||
|
|
||||||
|
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||||
|
w2 = w1 * (int(c)**(-0.5))
|
||||||
|
del w1
|
||||||
|
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
||||||
|
del w2
|
||||||
|
|
||||||
|
# attend to values
|
||||||
|
v1 = v.reshape(b, c, h*w)
|
||||||
|
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||||
|
del w3
|
||||||
|
|
||||||
|
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||||
|
del v1, w4
|
||||||
|
|
||||||
|
h2 = h_.reshape(b, c, h, w)
|
||||||
|
del h_
|
||||||
|
|
||||||
|
h3 = self.proj_out(h2)
|
||||||
|
del h2
|
||||||
|
|
||||||
|
h3 += x
|
||||||
|
|
||||||
|
return h3
|
|
@ -8,7 +8,7 @@ from omegaconf import OmegaConf
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader
|
from modules import shared, modelloader, devices
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
|
@ -134,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
|
||||||
if not shared.cmd_opts.no_half:
|
if not shared.cmd_opts.no_half:
|
||||||
model.half()
|
model.half()
|
||||||
|
|
||||||
|
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||||
|
|
||||||
model.sd_model_hash = sd_model_hash
|
model.sd_model_hash = sd_model_hash
|
||||||
model.sd_model_checkpint = checkpoint_file
|
model.sd_model_checkpint = checkpoint_file
|
||||||
|
|
||||||
|
|
|
@ -78,6 +78,7 @@ class State:
|
||||||
current_latent = None
|
current_latent = None
|
||||||
current_image = None
|
current_image = None
|
||||||
current_image_sampling_step = 0
|
current_image_sampling_step = 0
|
||||||
|
textinfo = None
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
self.interrupted = True
|
self.interrupted = True
|
||||||
|
@ -88,7 +89,7 @@ class State:
|
||||||
self.current_image_sampling_step = 0
|
self.current_image_sampling_step = 0
|
||||||
|
|
||||||
def get_job_timestamp(self):
|
def get_job_timestamp(self):
|
||||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||||
|
|
||||||
|
|
||||||
state = State()
|
state = State()
|
||||||
|
|
76
modules/textual_inversion/dataset.py
Normal file
76
modules/textual_inversion/dataset.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import random
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
|
||||||
|
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
|
||||||
|
|
||||||
|
self.placeholder_token = placeholder_token
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.width = width
|
||||||
|
self.height = height
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
self.dataset = []
|
||||||
|
|
||||||
|
with open(template_file, "r") as file:
|
||||||
|
lines = [x.strip() for x in file.readlines()]
|
||||||
|
|
||||||
|
self.lines = lines
|
||||||
|
|
||||||
|
assert data_root, 'dataset directory not specified'
|
||||||
|
|
||||||
|
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
||||||
|
print("Preparing dataset...")
|
||||||
|
for path in tqdm.tqdm(self.image_paths):
|
||||||
|
image = Image.open(path)
|
||||||
|
image = image.convert('RGB')
|
||||||
|
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
||||||
|
|
||||||
|
filename = os.path.basename(path)
|
||||||
|
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
|
||||||
|
filename_tokens = [token for token in filename_tokens if token.isalpha()]
|
||||||
|
|
||||||
|
npimage = np.array(image).astype(np.uint8)
|
||||||
|
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
||||||
|
|
||||||
|
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
||||||
|
torchdata = torch.moveaxis(torchdata, 2, 0)
|
||||||
|
|
||||||
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
||||||
|
|
||||||
|
self.dataset.append((init_latent, filename_tokens))
|
||||||
|
|
||||||
|
self.length = len(self.dataset) * repeats
|
||||||
|
|
||||||
|
self.initial_indexes = np.arange(self.length) % len(self.dataset)
|
||||||
|
self.indexes = None
|
||||||
|
self.shuffle()
|
||||||
|
|
||||||
|
def shuffle(self):
|
||||||
|
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
if i % len(self.dataset) == 0:
|
||||||
|
self.shuffle()
|
||||||
|
|
||||||
|
index = self.indexes[i % len(self.indexes)]
|
||||||
|
x, filename_tokens = self.dataset[index]
|
||||||
|
|
||||||
|
text = random.choice(self.lines)
|
||||||
|
text = text.replace("[name]", self.placeholder_token)
|
||||||
|
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
||||||
|
|
||||||
|
return x, text
|
258
modules/textual_inversion/textual_inversion.py
Normal file
258
modules/textual_inversion/textual_inversion.py
Normal file
|
@ -0,0 +1,258 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import html
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from modules import shared, devices, sd_hijack, processing
|
||||||
|
import modules.textual_inversion.dataset
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding:
|
||||||
|
def __init__(self, vec, name, step=None):
|
||||||
|
self.vec = vec
|
||||||
|
self.name = name
|
||||||
|
self.step = step
|
||||||
|
self.cached_checksum = None
|
||||||
|
|
||||||
|
def save(self, filename):
|
||||||
|
embedding_data = {
|
||||||
|
"string_to_token": {"*": 265},
|
||||||
|
"string_to_param": {"*": self.vec},
|
||||||
|
"name": self.name,
|
||||||
|
"step": self.step,
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.save(embedding_data, filename)
|
||||||
|
|
||||||
|
def checksum(self):
|
||||||
|
if self.cached_checksum is not None:
|
||||||
|
return self.cached_checksum
|
||||||
|
|
||||||
|
def const_hash(a):
|
||||||
|
r = 0
|
||||||
|
for v in a:
|
||||||
|
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
||||||
|
return r
|
||||||
|
|
||||||
|
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
||||||
|
return self.cached_checksum
|
||||||
|
|
||||||
|
class EmbeddingDatabase:
|
||||||
|
def __init__(self, embeddings_dir):
|
||||||
|
self.ids_lookup = {}
|
||||||
|
self.word_embeddings = {}
|
||||||
|
self.dir_mtime = None
|
||||||
|
self.embeddings_dir = embeddings_dir
|
||||||
|
|
||||||
|
def register_embedding(self, embedding, model):
|
||||||
|
|
||||||
|
self.word_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
|
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
|
||||||
|
|
||||||
|
first_id = ids[0]
|
||||||
|
if first_id not in self.ids_lookup:
|
||||||
|
self.ids_lookup[first_id] = []
|
||||||
|
self.ids_lookup[first_id].append((ids, embedding))
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def load_textual_inversion_embeddings(self):
|
||||||
|
mt = os.path.getmtime(self.embeddings_dir)
|
||||||
|
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.dir_mtime = mt
|
||||||
|
self.ids_lookup.clear()
|
||||||
|
self.word_embeddings.clear()
|
||||||
|
|
||||||
|
def process_file(path, filename):
|
||||||
|
name = os.path.splitext(filename)[0]
|
||||||
|
|
||||||
|
data = torch.load(path, map_location="cpu")
|
||||||
|
|
||||||
|
# textual inversion embeddings
|
||||||
|
if 'string_to_param' in data:
|
||||||
|
param_dict = data['string_to_param']
|
||||||
|
if hasattr(param_dict, '_parameters'):
|
||||||
|
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
|
emb = next(iter(param_dict.items()))[1]
|
||||||
|
# diffuser concepts
|
||||||
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||||
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||||
|
|
||||||
|
emb = next(iter(data.values()))
|
||||||
|
if len(emb.shape) == 1:
|
||||||
|
emb = emb.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||||
|
|
||||||
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = data.get('step', None)
|
||||||
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
|
|
||||||
|
for fn in os.listdir(self.embeddings_dir):
|
||||||
|
try:
|
||||||
|
fullfn = os.path.join(self.embeddings_dir, fn)
|
||||||
|
|
||||||
|
if os.stat(fullfn).st_size == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
process_file(fullfn, fn)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||||
|
|
||||||
|
def find_embedding_at_position(self, tokens, offset):
|
||||||
|
token = tokens[offset]
|
||||||
|
possible_matches = self.ids_lookup.get(token, None)
|
||||||
|
|
||||||
|
if possible_matches is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for ids, embedding in possible_matches:
|
||||||
|
if tokens[offset:offset + len(ids)] == ids:
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding(name, num_vectors_per_token):
|
||||||
|
init_text = '*'
|
||||||
|
|
||||||
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
|
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||||
|
|
||||||
|
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||||
|
embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
|
||||||
|
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||||
|
|
||||||
|
for i in range(num_vectors_per_token):
|
||||||
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
|
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = 0
|
||||||
|
embedding.save(fn)
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
||||||
|
assert embedding_name, 'embedding not selected'
|
||||||
|
|
||||||
|
shared.state.textinfo = "Initializing textual inversion training..."
|
||||||
|
shared.state.job_count = steps
|
||||||
|
|
||||||
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
|
|
||||||
|
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
|
||||||
|
|
||||||
|
if save_embedding_every > 0:
|
||||||
|
embedding_dir = os.path.join(log_directory, "embeddings")
|
||||||
|
os.makedirs(embedding_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
embedding_dir = None
|
||||||
|
|
||||||
|
if create_image_every > 0:
|
||||||
|
images_dir = os.path.join(log_directory, "images")
|
||||||
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
else:
|
||||||
|
images_dir = None
|
||||||
|
|
||||||
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
|
|
||||||
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
||||||
|
|
||||||
|
hijack = sd_hijack.model_hijack
|
||||||
|
|
||||||
|
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
||||||
|
embedding.vec.requires_grad = True
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
||||||
|
|
||||||
|
losses = torch.zeros((32,))
|
||||||
|
|
||||||
|
last_saved_file = "<none>"
|
||||||
|
last_saved_image = "<none>"
|
||||||
|
|
||||||
|
ititial_step = embedding.step or 0
|
||||||
|
if ititial_step > steps:
|
||||||
|
return embedding, filename
|
||||||
|
|
||||||
|
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||||
|
for i, (x, text) in pbar:
|
||||||
|
embedding.step = i + ititial_step
|
||||||
|
|
||||||
|
if embedding.step > steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
with torch.autocast("cuda"):
|
||||||
|
c = cond_model([text])
|
||||||
|
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
||||||
|
|
||||||
|
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
pbar.set_description(f"loss: {losses.mean():.7f}")
|
||||||
|
|
||||||
|
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
||||||
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||||
|
embedding.save(last_saved_file)
|
||||||
|
|
||||||
|
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||||
|
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
|
sd_model=shared.sd_model,
|
||||||
|
prompt=text,
|
||||||
|
steps=20,
|
||||||
|
do_not_save_grid=True,
|
||||||
|
do_not_save_samples=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
processed = processing.process_images(p)
|
||||||
|
image = processed.images[0]
|
||||||
|
|
||||||
|
shared.state.current_image = image
|
||||||
|
image.save(last_saved_image)
|
||||||
|
|
||||||
|
last_saved_image += f", prompt: {text}"
|
||||||
|
|
||||||
|
shared.state.job_no = embedding.step
|
||||||
|
|
||||||
|
shared.state.textinfo = f"""
|
||||||
|
<p>
|
||||||
|
Loss: {losses.mean():.7f}<br/>
|
||||||
|
Step: {embedding.step}<br/>
|
||||||
|
Last prompt: {html.escape(text)}<br/>
|
||||||
|
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||||
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
</p>
|
||||||
|
"""
|
||||||
|
|
||||||
|
embedding.cached_checksum = None
|
||||||
|
embedding.save(filename)
|
||||||
|
|
||||||
|
return embedding, filename
|
||||||
|
|
32
modules/textual_inversion/ui.py
Normal file
32
modules/textual_inversion/ui.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
import html
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
import modules.textual_inversion.textual_inversion as ti
|
||||||
|
from modules import sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding(name, nvpt):
|
||||||
|
filename = ti.create_embedding(name, nvpt)
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(*args):
|
||||||
|
|
||||||
|
try:
|
||||||
|
sd_hijack.undo_optimizations()
|
||||||
|
|
||||||
|
embedding, filename = ti.train_embedding(*args)
|
||||||
|
|
||||||
|
res = f"""
|
||||||
|
Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps.
|
||||||
|
Embedding saved to {html.escape(filename)}
|
||||||
|
"""
|
||||||
|
return res, ""
|
||||||
|
except Exception:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
sd_hijack.apply_optimizations()
|
135
modules/ui.py
135
modules/ui.py
|
@ -21,6 +21,7 @@ import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
import gradio.routes
|
import gradio.routes
|
||||||
|
|
||||||
|
from modules import sd_hijack
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
@ -32,6 +33,7 @@ import modules.gfpgan_model
|
||||||
import modules.codeformer_model
|
import modules.codeformer_model
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.generation_parameters_copypaste
|
import modules.generation_parameters_copypaste
|
||||||
|
import modules.textual_inversion.ui
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
|
@ -142,8 +144,8 @@ def save_files(js_data, images, index):
|
||||||
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func):
|
def wrap_gradio_call(func, extra_outputs=None):
|
||||||
def f(*args, **kwargs):
|
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||||
if run_memmon:
|
if run_memmon:
|
||||||
shared.mem_mon.monitor()
|
shared.mem_mon.monitor()
|
||||||
|
@ -159,7 +161,10 @@ def wrap_gradio_call(func):
|
||||||
shared.state.job = ""
|
shared.state.job = ""
|
||||||
shared.state.job_count = 0
|
shared.state.job_count = 0
|
||||||
|
|
||||||
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
if extra_outputs_array is None:
|
||||||
|
extra_outputs_array = [None, '']
|
||||||
|
|
||||||
|
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||||
|
|
||||||
elapsed = time.perf_counter() - t
|
elapsed = time.perf_counter() - t
|
||||||
|
|
||||||
|
@ -179,6 +184,7 @@ def wrap_gradio_call(func):
|
||||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
||||||
|
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
shared.state.job_count = 0
|
||||||
|
|
||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
|
@ -187,7 +193,7 @@ def wrap_gradio_call(func):
|
||||||
|
|
||||||
def check_progress_call(id_part):
|
def check_progress_call(id_part):
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return "", gr_show(False), gr_show(False)
|
return "", gr_show(False), gr_show(False), gr_show(False)
|
||||||
|
|
||||||
progress = 0
|
progress = 0
|
||||||
|
|
||||||
|
@ -219,13 +225,19 @@ def check_progress_call(id_part):
|
||||||
else:
|
else:
|
||||||
preview_visibility = gr_show(True)
|
preview_visibility = gr_show(True)
|
||||||
|
|
||||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
|
if shared.state.textinfo is not None:
|
||||||
|
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
|
||||||
|
else:
|
||||||
|
textinfo_result = gr_show(False)
|
||||||
|
|
||||||
|
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
|
||||||
|
|
||||||
|
|
||||||
def check_progress_call_initial(id_part):
|
def check_progress_call_initial(id_part):
|
||||||
shared.state.job_count = -1
|
shared.state.job_count = -1
|
||||||
shared.state.current_latent = None
|
shared.state.current_latent = None
|
||||||
shared.state.current_image = None
|
shared.state.current_image = None
|
||||||
|
shared.state.textinfo = None
|
||||||
|
|
||||||
return check_progress_call(id_part)
|
return check_progress_call(id_part)
|
||||||
|
|
||||||
|
@ -399,13 +411,16 @@ def create_toprow(is_img2img):
|
||||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
|
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(progressbar, preview, id_part):
|
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||||
|
if textinfo is None:
|
||||||
|
textinfo = gr.HTML(visible=False)
|
||||||
|
|
||||||
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
||||||
check_progress.click(
|
check_progress.click(
|
||||||
fn=lambda: check_progress_call(id_part),
|
fn=lambda: check_progress_call(id_part),
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[progressbar, preview, preview],
|
outputs=[progressbar, preview, preview, textinfo],
|
||||||
)
|
)
|
||||||
|
|
||||||
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
||||||
|
@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part):
|
||||||
fn=lambda: check_progress_call_initial(id_part),
|
fn=lambda: check_progress_call_initial(id_part),
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[progressbar, preview, preview],
|
outputs=[progressbar, preview, preview, textinfo],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
def create_ui(wrap_gradio_gpu_call):
|
||||||
|
import modules.img2img
|
||||||
|
import modules.txt2img
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
|
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=txt2img,
|
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
||||||
_js="submit",
|
_js="submit",
|
||||||
inputs=[
|
inputs=[
|
||||||
txt2img_prompt,
|
txt2img_prompt,
|
||||||
|
@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
)
|
)
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=img2img,
|
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
|
||||||
_js="submit_img2img",
|
_js="submit_img2img",
|
||||||
inputs=[
|
inputs=[
|
||||||
dummy_component,
|
dummy_component,
|
||||||
|
@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||||
|
|
||||||
submit.click(
|
submit.click(
|
||||||
fn=run_extras,
|
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||||
_js="get_extras_tab_index",
|
_js="get_extras_tab_index",
|
||||||
inputs=[
|
inputs=[
|
||||||
dummy_component,
|
dummy_component,
|
||||||
|
@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||||
|
|
||||||
image.change(
|
image.change(
|
||||||
fn=wrap_gradio_call(run_pnginfo),
|
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||||
inputs=[image],
|
inputs=[image],
|
||||||
outputs=[html, generation_info, html2],
|
outputs=[html, generation_info, html2],
|
||||||
)
|
)
|
||||||
|
@ -900,6 +918,92 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
with gr.Blocks() as textual_inversion_interface:
|
||||||
|
with gr.Row().style(equal_height=False):
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
||||||
|
|
||||||
|
new_embedding_name = gr.Textbox(label="Name")
|
||||||
|
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
create_embedding = gr.Button(value="Create", variant='primary')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
||||||
|
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
|
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||||
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
|
||||||
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
interrupt_training = gr.Button(value="Interrupt")
|
||||||
|
train_embedding = gr.Button(value="Train", variant='primary')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||||
|
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||||
|
|
||||||
|
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||||
|
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
|
||||||
|
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||||
|
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||||
|
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
|
||||||
|
|
||||||
|
create_embedding.click(
|
||||||
|
fn=modules.textual_inversion.ui.create_embedding,
|
||||||
|
inputs=[
|
||||||
|
new_embedding_name,
|
||||||
|
nvpt,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
train_embedding_name,
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
train_embedding.click(
|
||||||
|
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
train_embedding_name,
|
||||||
|
learn_rate,
|
||||||
|
dataset_directory,
|
||||||
|
log_directory,
|
||||||
|
steps,
|
||||||
|
create_image_every,
|
||||||
|
save_embedding_every,
|
||||||
|
template_file,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
interrupt_training.click(
|
||||||
|
fn=lambda: shared.state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
def create_setting_component(key):
|
def create_setting_component(key):
|
||||||
def fun():
|
def fun():
|
||||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||||
|
@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
(extras_interface, "Extras", "extras"),
|
(extras_interface, "Extras", "extras"),
|
||||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||||
|
(textual_inversion_interface, "Textual inversion", "ti"),
|
||||||
(settings_interface, "Settings", "settings"),
|
(settings_interface, "Settings", "settings"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
||||||
|
|
||||||
def modelmerger(*args):
|
def modelmerger(*args):
|
||||||
try:
|
try:
|
||||||
results = run_modelmerger(*args)
|
results = modules.extras.run_modelmerger(*args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error loading/saving model file:", file=sys.stderr)
|
print("Error loading/saving model file:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
modules.sd_models.list_models() #To remove the potentially missing models from the list
|
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||||
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
10
style.css
10
style.css
|
@ -157,7 +157,7 @@ button{
|
||||||
max-width: 10em;
|
max-width: 10em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_preview, #img2img_preview{
|
#txt2img_preview, #img2img_preview, #ti_preview{
|
||||||
position: absolute;
|
position: absolute;
|
||||||
width: 320px;
|
width: 320px;
|
||||||
left: 0;
|
left: 0;
|
||||||
|
@ -172,18 +172,18 @@ button{
|
||||||
}
|
}
|
||||||
|
|
||||||
@media screen and (min-width: 768px) {
|
@media screen and (min-width: 768px) {
|
||||||
#txt2img_preview, #img2img_preview {
|
#txt2img_preview, #img2img_preview, #ti_preview {
|
||||||
position: absolute;
|
position: absolute;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@media screen and (max-width: 767px) {
|
@media screen and (max-width: 767px) {
|
||||||
#txt2img_preview, #img2img_preview {
|
#txt2img_preview, #img2img_preview, #ti_preview {
|
||||||
position: relative;
|
position: relative;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
|
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,7 +247,7 @@ input[type="range"]{
|
||||||
#txt2img_negative_prompt, #img2img_negative_prompt{
|
#txt2img_negative_prompt, #img2img_negative_prompt{
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_progressbar, #img2img_progressbar{
|
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
|
||||||
position: absolute;
|
position: absolute;
|
||||||
z-index: 1000;
|
z-index: 1000;
|
||||||
right: 0;
|
right: 0;
|
||||||
|
|
19
textual_inversion_templates/style.txt
Normal file
19
textual_inversion_templates/style.txt
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
a painting, art by [name]
|
||||||
|
a rendering, art by [name]
|
||||||
|
a cropped painting, art by [name]
|
||||||
|
the painting, art by [name]
|
||||||
|
a clean painting, art by [name]
|
||||||
|
a dirty painting, art by [name]
|
||||||
|
a dark painting, art by [name]
|
||||||
|
a picture, art by [name]
|
||||||
|
a cool painting, art by [name]
|
||||||
|
a close-up painting, art by [name]
|
||||||
|
a bright painting, art by [name]
|
||||||
|
a cropped painting, art by [name]
|
||||||
|
a good painting, art by [name]
|
||||||
|
a close-up painting, art by [name]
|
||||||
|
a rendition, art by [name]
|
||||||
|
a nice painting, art by [name]
|
||||||
|
a small painting, art by [name]
|
||||||
|
a weird painting, art by [name]
|
||||||
|
a large painting, art by [name]
|
19
textual_inversion_templates/style_filewords.txt
Normal file
19
textual_inversion_templates/style_filewords.txt
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
a painting of [filewords], art by [name]
|
||||||
|
a rendering of [filewords], art by [name]
|
||||||
|
a cropped painting of [filewords], art by [name]
|
||||||
|
the painting of [filewords], art by [name]
|
||||||
|
a clean painting of [filewords], art by [name]
|
||||||
|
a dirty painting of [filewords], art by [name]
|
||||||
|
a dark painting of [filewords], art by [name]
|
||||||
|
a picture of [filewords], art by [name]
|
||||||
|
a cool painting of [filewords], art by [name]
|
||||||
|
a close-up painting of [filewords], art by [name]
|
||||||
|
a bright painting of [filewords], art by [name]
|
||||||
|
a cropped painting of [filewords], art by [name]
|
||||||
|
a good painting of [filewords], art by [name]
|
||||||
|
a close-up painting of [filewords], art by [name]
|
||||||
|
a rendition of [filewords], art by [name]
|
||||||
|
a nice painting of [filewords], art by [name]
|
||||||
|
a small painting of [filewords], art by [name]
|
||||||
|
a weird painting of [filewords], art by [name]
|
||||||
|
a large painting of [filewords], art by [name]
|
27
textual_inversion_templates/subject.txt
Normal file
27
textual_inversion_templates/subject.txt
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
a photo of a [name]
|
||||||
|
a rendering of a [name]
|
||||||
|
a cropped photo of the [name]
|
||||||
|
the photo of a [name]
|
||||||
|
a photo of a clean [name]
|
||||||
|
a photo of a dirty [name]
|
||||||
|
a dark photo of the [name]
|
||||||
|
a photo of my [name]
|
||||||
|
a photo of the cool [name]
|
||||||
|
a close-up photo of a [name]
|
||||||
|
a bright photo of the [name]
|
||||||
|
a cropped photo of a [name]
|
||||||
|
a photo of the [name]
|
||||||
|
a good photo of the [name]
|
||||||
|
a photo of one [name]
|
||||||
|
a close-up photo of the [name]
|
||||||
|
a rendition of the [name]
|
||||||
|
a photo of the clean [name]
|
||||||
|
a rendition of a [name]
|
||||||
|
a photo of a nice [name]
|
||||||
|
a good photo of a [name]
|
||||||
|
a photo of the nice [name]
|
||||||
|
a photo of the small [name]
|
||||||
|
a photo of the weird [name]
|
||||||
|
a photo of the large [name]
|
||||||
|
a photo of a cool [name]
|
||||||
|
a photo of a small [name]
|
27
textual_inversion_templates/subject_filewords.txt
Normal file
27
textual_inversion_templates/subject_filewords.txt
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
a photo of a [name], [filewords]
|
||||||
|
a rendering of a [name], [filewords]
|
||||||
|
a cropped photo of the [name], [filewords]
|
||||||
|
the photo of a [name], [filewords]
|
||||||
|
a photo of a clean [name], [filewords]
|
||||||
|
a photo of a dirty [name], [filewords]
|
||||||
|
a dark photo of the [name], [filewords]
|
||||||
|
a photo of my [name], [filewords]
|
||||||
|
a photo of the cool [name], [filewords]
|
||||||
|
a close-up photo of a [name], [filewords]
|
||||||
|
a bright photo of the [name], [filewords]
|
||||||
|
a cropped photo of a [name], [filewords]
|
||||||
|
a photo of the [name], [filewords]
|
||||||
|
a good photo of the [name], [filewords]
|
||||||
|
a photo of one [name], [filewords]
|
||||||
|
a close-up photo of the [name], [filewords]
|
||||||
|
a rendition of the [name], [filewords]
|
||||||
|
a photo of the clean [name], [filewords]
|
||||||
|
a rendition of a [name], [filewords]
|
||||||
|
a photo of a nice [name], [filewords]
|
||||||
|
a good photo of a [name], [filewords]
|
||||||
|
a photo of the nice [name], [filewords]
|
||||||
|
a photo of the small [name], [filewords]
|
||||||
|
a photo of the weird [name], [filewords]
|
||||||
|
a photo of the large [name], [filewords]
|
||||||
|
a photo of a cool [name], [filewords]
|
||||||
|
a photo of a small [name], [filewords]
|
15
webui.py
15
webui.py
|
@ -12,7 +12,6 @@ import modules.bsrgan_model as bsrgan
|
||||||
import modules.extras
|
import modules.extras
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
import modules.gfpgan_model as gfpgan
|
import modules.gfpgan_model as gfpgan
|
||||||
import modules.img2img
|
|
||||||
import modules.ldsr_model as ldsr
|
import modules.ldsr_model as ldsr
|
||||||
import modules.lowvram
|
import modules.lowvram
|
||||||
import modules.realesrgan_model as realesrgan
|
import modules.realesrgan_model as realesrgan
|
||||||
|
@ -21,7 +20,6 @@ import modules.sd_hijack
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.swinir_model as swinir
|
import modules.swinir_model as swinir
|
||||||
import modules.txt2img
|
|
||||||
import modules.ui
|
import modules.ui
|
||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
@ -46,7 +44,7 @@ def wrap_queued_call(func):
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_gpu_call(func):
|
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
@ -58,6 +56,7 @@ def wrap_gradio_gpu_call(func):
|
||||||
shared.state.current_image = None
|
shared.state.current_image = None
|
||||||
shared.state.current_image_sampling_step = 0
|
shared.state.current_image_sampling_step = 0
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
shared.state.textinfo = None
|
||||||
|
|
||||||
with queue_lock:
|
with queue_lock:
|
||||||
res = func(*args, **kwargs)
|
res = func(*args, **kwargs)
|
||||||
|
@ -69,7 +68,7 @@ def wrap_gradio_gpu_call(func):
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return modules.ui.wrap_gradio_call(f)
|
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
||||||
|
|
||||||
|
|
||||||
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
||||||
|
@ -86,13 +85,7 @@ def webui():
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, sigint_handler)
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
demo = modules.ui.create_ui(
|
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
||||||
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
|
||||||
img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
|
|
||||||
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
|
|
||||||
run_pnginfo=modules.extras.run_pnginfo,
|
|
||||||
run_modelmerger=modules.extras.run_modelmerger
|
|
||||||
)
|
|
||||||
|
|
||||||
demo.launch(
|
demo.launch(
|
||||||
share=cmd_opts.share,
|
share=cmd_opts.share,
|
||||||
|
|
Loading…
Reference in a new issue