"""Runtime patches for a loaded Stable Diffusion model.

Provides textual-inversion embedding loading, prompt emphasis via
``()``/``[]`` tokens, and an optional lower-memory cross-attention forward.
"""
import os
import sys
import traceback
import torch
import numpy as np
from torch import einsum

from modules.shared import opts, device, cmd_opts

from ldm.util import default
from einops import rearrange
import ldm.modules.attention


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward(self, x, context=None, mask=None):
    """Replacement for ``CrossAttention.forward`` that attends in small slices.

    The attention matrix is computed for two entries of the merged
    ``(batch * heads)`` dimension at a time, so only a small part of it is
    alive at any moment.

    Args:
        self: the attention module (reads ``heads``, ``scale``, ``to_q``,
            ``to_k``, ``to_v`` and ``to_out``).
        x: query input.
        context: key/value input; ``x`` is used when it is None.
        mask: accepted for signature compatibility; not used here.

    Returns:
        The result of ``self.to_out`` applied to the attended values.
    """
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    # Allocate the accumulator in the same dtype as the projections: the
    # default (float32) doubles its size for half-precision models and makes
    # ``to_out`` fail on a dtype mismatch when not running under autocast.
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1  # release the scores before the next allocation

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


class StableDiffusionModelHijack:
    """Holds loaded textual-inversion embeddings and installs the model patches."""

    # first token id of an embedding name -> list of (all token ids, name)
    ids_lookup = {}
    # embedding name -> embedding tensor
    word_embeddings = {}
    # embedding name -> 4 hex digit checksum shown to the user
    word_embeddings_checksums = {}
    # per-prompt lists of (offset, name), handed from the text encoder wrapper
    # to EmbeddingsWithFixes for a single forward pass
    fixes = None
    # messages collected while encoding the latest batch of prompts
    comments = []
    # mtime of the embeddings directory at the time of the last load
    dir_mtime = None

    def load_textual_inversion_embeddings(self, dirname, model):
        """(Re)load every embedding file found in ``dirname``.

        Nothing is done when the directory's mtime is not newer than the one
        seen by the previous load. Files that fail to load are reported on
        stderr and skipped.

        Args:
            dirname: directory containing the embedding files.
            model: model whose ``cond_stage_model.tokenizer`` is used to
                tokenize embedding names.
        """
        mt = os.path.getmtime(dirname)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        # Also drop checksums so entries for removed files do not linger.
        self.word_embeddings_checksums.clear()

        tokenizer = model.cond_stage_model.tokenizer

        def const_hash(a):
            """Fold the values of ``a`` into a 32-bit integer."""
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        def process_file(path, filename):
            """Load one embedding file and register it under its base name."""
            name = os.path.splitext(filename)[0]

            # NOTE(review): torch.load unpickles the file; only load
            # embeddings from sources you trust.
            data = torch.load(path)
            param_dict = data['string_to_param']
            if hasattr(param_dict, '_parameters'):
                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
            self.word_embeddings[name] = emb.detach()
            self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'

            ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]

            # Index by first token so prompt scanning needs a single dict lookup.
            first_id = ids[0]
            if first_id not in self.ids_lookup:
                self.ids_lookup[first_id] = []
            self.ids_lookup[first_id].append((ids, name))

        for fn in os.listdir(dirname):
            try:
                process_file(os.path.join(dirname, fn), fn)
            except Exception:
                # Best effort: one broken file must not prevent the others from loading.
                print(f"Error loading emedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")

    def hijack(self, m):
        """Wrap the text encoder of ``m`` and optionally patch cross-attention."""
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings

        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        if cmd_opts.opt_split_attention:
            ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward


class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    """Text encoder wrapper adding custom embeddings and ``()``/``[]`` emphasis."""

    def __init__(self, wrapped, hijack):
        """
        Args:
            wrapped: the original text encoder (provides ``tokenizer``,
                ``max_length`` and ``transformer``).
            hijack: the StableDiffusionModelHijack holding the embeddings.
        """
        super().__init__()
        self.wrapped = wrapped
        self.hijack = hijack
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length
        # token id -> multiplier, for every vocab token containing brackets
        self.token_mults = {}

        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            # Tokens whose brackets cancel out (e.g. "()") carry no emphasis.
            if mult != 1.0:
                self.token_mults[ident] = mult

    def forward(self, text):
        """Encode a batch of prompts.

        Custom embedding names are replaced by placeholder tokens (the real
        vectors are injected later by EmbeddingsWithFixes), bracket tokens
        are removed and turned into per-token multipliers when
        ``opts.enable_emphasis`` is set, and each prompt is padded or
        truncated to a fixed length.

        Args:
            text: batch of prompt strings passed to the tokenizer.

        Returns:
            ``last_hidden_state`` of the wrapped transformer, scaled by the
            emphasis multipliers and rescaled to its original mean.
        """
        self.hijack.fixes = []
        self.hijack.comments = []
        remade_batch_tokens = []
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        # Number of prompt tokens that fit between BOS and EOS. This assumes
        # ``max_length`` is the full sequence length including both special
        # tokens; the code used to subtract the 2 a second time at every use,
        # truncating prompts early and producing too-short sequences.
        maxlen = self.wrapped.max_length - 2
        used_custom_terms = []

        # Identical prompts within the batch are processed only once.
        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
                multipliers = []
                mult = 1.0

                i = 0
                while i < len(tokens):
                    token = tokens[i]

                    possible_matches = self.hijack.ids_lookup.get(token, None)

                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                    if mult_change is not None:
                        # Bracket token: adjust the running multiplier, emit nothing.
                        mult *= mult_change
                    elif possible_matches is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                    else:
                        found = False
                        for ids, word in possible_matches:
                            if tokens[i:i+len(ids)] == ids:
                                emb_len = int(self.hijack.word_embeddings[word].shape[0])
                                # Offset is relative to the prompt tokens, before BOS is prepended.
                                fixes.append((len(remade_tokens), word))
                                remade_tokens += [0] * emb_len
                                multipliers += [mult] * emb_len
                                i += len(ids) - 1
                                found = True
                                used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
                                break

                        if not found:
                            remade_tokens.append(token)
                            multipliers.append(mult)

                    i += 1

                if len(remade_tokens) > maxlen:
                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                    ovf = remade_tokens[maxlen:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    # convert_tokens_to_string takes the list of tokens, not a joined string.
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(overflowing_words)

                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                remade_tokens = remade_tokens + [id_end] * (maxlen - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen] + [id_end]
                # Multipliers are cached unpadded; padding below builds new lists.
                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

            multipliers = multipliers + [1.0] * (maxlen - len(multipliers))
            multipliers = [1.0] + multipliers[0:maxlen] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            self.hijack.fixes.append(fixes)
            batch_multipliers.append(multipliers)

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z


class EmbeddingsWithFixes(torch.nn.Module):
    """Token-embedding wrapper that writes custom embedding vectors over placeholders."""

    def __init__(self, wrapped, embeddings):
        """
        Args:
            wrapped: the original token embedding module.
            embeddings: object exposing ``fixes`` and ``word_embeddings``
                (the StableDiffusionModelHijack instance).
        """
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        """Embed ``input_ids`` and apply the pending fixes, if any.

        Pending fixes are consumed: ``embeddings.fixes`` is reset to None so
        they are applied to exactly one forward pass.
        """
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is not None:
            for fixes, tensor in zip(batch_fixes, inputs_embeds):
                for offset, word in fixes:
                    emb = self.embeddings.word_embeddings[word]
                    # Offsets were recorded before the BOS token was prepended,
                    # so the placeholders sit one position further in.
                    start = offset + 1
                    emb_len = min(tensor.shape[0] - start, emb.shape[0])
                    if emb_len <= 0:
                        # Placeholder was truncated away entirely; a negative
                        # length would otherwise produce mismatched slices.
                        continue
                    tensor[start:start + emb_len] = emb[0:emb_len]

        return inputs_embeds


model_hijack = StableDiffusionModelHijack()