implementation for attention using [] and ()

This commit is contained in:
AUTOMATIC 2022-08-27 11:17:55 +03:00
parent a51bedfb5a
commit 9597b265ec
3 changed files with 62 additions and 23 deletions


@@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt.
A tab with settings, allowing you to use UI to edit more than half of parameters that previously
were commandline. Settings are saved to config.js file. Settings that remain as commandline
options are ones that are required at startup.
+
+### Attention
+Using `()` in prompt increases model's attention to enclosed words, and `[]` decreases it. You can combine
+multiple modifiers:
+
+![](images/attention-3.jpg)
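(For context: per the multiplier loop added to FrozenCLIPEmbedderWithCustomWords further down, each `(` scales attention for the enclosed words by 1.1x and each `[` by 1/1.1x, so a hypothetical prompt like `a ((cottage)) in [fog]` would weight `cottage` by roughly 1.1 × 1.1 ≈ 1.21 and `fog` by about 0.91.)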

images/attention-3.jpg (new binary file, 944 KiB, not shown)


@@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
        print(traceback.format_exc(), file=sys.stderr)

-class TextInversionEmbeddings:
+class StableDiffuionModelHijack:
    ids_lookup = {}
    word_embeddings = {}
    word_embeddings_checksums = {}
-    fixes = []
+    fixes = None
    used_custom_terms = []
    dir_mtime = None

-    def load(self, dir, model):
+    def load_textual_inversion_embeddings(self, dir, model):
        mt = os.path.getmtime(dir)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return
@@ -469,6 +469,7 @@ class TextInversionEmbeddings:
        self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
        ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
        first_id = ids[0]
        if first_id not in self.ids_lookup:
            self.ids_lookup[first_id] = []
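For reference, the registration above together with the matching loop later in the diff implies `ids_lookup` maps the first token id of each embedding's name to candidate `(ids, name)` pairs. A minimal sketch of that structure, with made-up token ids:

    # hypothetical token ids for an embedding named "usada pekora"
    ids = [320, 9929]                  # tokenizer([name], add_special_tokens=False)['input_ids'][0]
    ids_lookup = {}
    ids_lookup.setdefault(ids[0], []).append((ids, "usada pekora"))
    # later, while scanning prompt tokens:
    #   possible_matches = ids_lookup.get(token)   -> [([320, 9929], "usada pekora")]
    #   a match is confirmed when tokens[i:i+len(ids)] == ids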
@@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
        self.embeddings = embeddings
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length
+        self.token_mults = {}
+
+        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        for text, ident in tokens_with_parens:
+            mult = 1.0
+            for c in text:
+                if c == '[':
+                    mult /= 1.1
+                if c == ']':
+                    mult *= 1.1
+                if c == '(':
+                    mult *= 1.1
+                if c == ')':
+                    mult /= 1.1
+
+            if mult != 1.0:
+                self.token_mults[ident] = mult

    def forward(self, text):
        self.embeddings.fixes = []
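A minimal standalone sketch of the multiplier rule the `__init__` loop above precomputes for every vocabulary token containing brackets (the helper name is invented; the constants mirror the loop):

    def paren_multiplier(token_text):
        # '(' and ']' raise the running weight by 1.1x; ')' and '[' lower it by 1/1.1
        mult = 1.0
        for c in token_text:
            if c == '(' or c == ']':
                mult *= 1.1
            if c == ')' or c == '[':
                mult /= 1.1
        return mult

    # e.g. a vocab token "((" maps to 1.1 * 1.1 = 1.21, "[" to about 0.909, "()" back to 1.0
    assert abs(paren_multiplier("((") - 1.21) < 1e-9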
@@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
+        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
-                remade_tokens, fixes = cache[tuple_tokens]
+                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
+                multipliers = []
+                mult = 1.0

                i = 0
                while i < len(tokens):
@@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                    possible_matches = self.embeddings.ids_lookup.get(token, None)

-                    if possible_matches is None:
+                    mult_change = self.token_mults.get(token)
+                    if mult_change is not None:
+                        mult *= mult_change
+                    elif possible_matches is None:
                        remade_tokens.append(token)
+                        multipliers.append(mult)
                    else:
                        found = False
                        for ids, word in possible_matches:
                            if tokens[i:i+len(ids)] == ids:
                                fixes.append((len(remade_tokens), word))
                                remade_tokens.append(777)
+                                multipliers.append(mult)
                                i += len(ids) - 1
                                found = True
                                self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
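A toy walk-through of the multiplier tracking added above, with invented token ids and the textual-inversion matching omitted: bracket tokens only adjust `mult` and are dropped from `remade_tokens`, while every surviving token records the weight in force when it was emitted.

    # invented ids: 10 -> "(", 11 -> ")", 300 -> "a", 301 -> "cat"
    token_mults = {10: 1.1, 11: 1.0 / 1.1}
    tokens = [300, 10, 301, 11]

    remade_tokens, multipliers, mult = [], [], 1.0
    for token in tokens:
        mult_change = token_mults.get(token)
        if mult_change is not None:
            mult *= mult_change          # bracket token: adjust weight, emit nothing
        else:
            remade_tokens.append(token)  # ordinary token keeps its current weight
            multipliers.append(mult)

    # remade_tokens == [300, 301], multipliers == [1.0, 1.1]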
@@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                        if not found:
                            remade_tokens.append(token)
+                            multipliers.append(mult)

                    i += 1

                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
-                cache[tuple_tokens] = (remade_tokens, fixes)
+                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            self.embeddings.fixes.append(fixes)
+            batch_multipliers.append(multipliers)

        tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
+
+        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
+        original_mean = z.mean()
+        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+        new_mean = z.mean()
+        z *= original_mean / new_mean
+
        return z
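The mean-restoration step above can be sketched in isolation: scale each token's hidden state by its multiplier, then rescale the whole tensor so its mean matches the pre-scaling mean. Shapes below are made up for illustration.

    import torch

    z = torch.randn(1, 77, 768)               # (batch, tokens, channels), CLIP-like placeholder shape
    batch_multipliers = torch.ones(1, 77)
    batch_multipliers[0, 5] = 1.21             # pretend token 5 sat inside "(( ))"

    original_mean = z.mean()
    z = z * batch_multipliers.unsqueeze(-1)    # broadcast each token's weight over the channel dim
    z = z * (original_mean / z.mean())         # restore the original mean to limit artifacts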
@@ -562,24 +601,19 @@ class EmbeddingsWithFixes(nn.Module):
    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
-        self.embeddings.fixes = []
+        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

-        for fixes, tensor in zip(batch_fixes, inputs_embeds):
-            for offset, word in fixes:
-                tensor[offset] = self.embeddings.word_embeddings[word]
+        if batch_fixes is not None:
+            for fixes, tensor in zip(batch_fixes, inputs_embeds):
+                for offset, word in fixes:
+                    tensor[offset] = self.embeddings.word_embeddings[word]

        return inputs_embeds

-def get_learned_conditioning_with_embeddings(model, prompts):
-    if os.path.exists(cmd_opts.embeddings_dir):
-        text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
-
-    return model.get_learned_conditioning(prompts)
-
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
        return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
-        text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
+        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)

    output_images = []
    with torch.no_grad(), autocast("cuda"), model.ema_scope():
@@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
            uc = model.get_learned_conditioning(len(prompts) * [""])
            c = model.get_learned_conditioning(prompts)

-            if len(text_inversion_embeddings.used_custom_terms) > 0:
-                comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
+            if len(model_hijack.used_custom_terms) > 0:
+                comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
@@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device)

-text_inversion_embeddings = TextInversionEmbeddings()
-if os.path.exists(cmd_opts.embeddings_dir):
-    text_inversion_embeddings.hijack(model)
+model_hijack = StableDiffuionModelHijack()
+model_hijack.hijack(model)

demo = gr.TabbedInterface(
    interface_list=[x[0] for x in interfaces],
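The body of `hijack()` is not part of this diff; a guess at the wiring, inferred from the wrapper classes above, is that it swaps the model's CLIP text encoder for FrozenCLIPEmbedderWithCustomWords and its token embedding layer for EmbeddingsWithFixes. The attribute paths and constructor signatures below are assumptions, not taken from this commit.

    def hijack_sketch(self, m):
        # assumed ldm/transformers layout; only a sketch of the likely wiring
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)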