implementation for attention using [] and ()
This commit is contained in:
parent
a51bedfb5a
commit
9597b265ec
3 changed files with 62 additions and 23 deletions
|
@ -188,3 +188,9 @@ and put it into `embeddings` dir and use Usada Pekora in prompt.
|
|||
A tab with settings, allowing you to use UI to edit more than half of parameters that previously
|
||||
were commandline. Settings are saved to config.js file. Settings that remain as commandline
|
||||
options are ones that are required at startup.
|
||||
|
||||
### Attention
|
||||
Using `()` in prompt decreases model's attention to enclosed words, and `[]` increases it. You can combine
|
||||
multiple modifiers:
|
||||
|
||||
![](images/attention-3.jpg)
|
||||
|
|
BIN
images/attention-3.jpg
Normal file
BIN
images/attention-3.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 944 KiB |
79
webui.py
79
webui.py
|
@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
|
|||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
class TextInversionEmbeddings:
|
||||
class StableDiffuionModelHijack:
|
||||
ids_lookup = {}
|
||||
word_embeddings = {}
|
||||
word_embeddings_checksums = {}
|
||||
fixes = []
|
||||
fixes = None
|
||||
used_custom_terms = []
|
||||
dir_mtime = None
|
||||
|
||||
def load(self, dir, model):
|
||||
def load_textual_inversion_embeddings(self, dir, model):
|
||||
mt = os.path.getmtime(dir)
|
||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||
return
|
||||
|
@ -469,6 +469,7 @@ class TextInversionEmbeddings:
|
|||
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
|
||||
|
||||
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
||||
|
||||
first_id = ids[0]
|
||||
if first_id not in self.ids_lookup:
|
||||
self.ids_lookup[first_id] = []
|
||||
|
@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
self.embeddings = embeddings
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.max_length = wrapped.max_length
|
||||
self.token_mults = {}
|
||||
|
||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
if c == '[':
|
||||
mult /= 1.1
|
||||
if c == ']':
|
||||
mult *= 1.1
|
||||
if c == '(':
|
||||
mult *= 1.1
|
||||
if c == ')':
|
||||
mult /= 1.1
|
||||
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
def forward(self, text):
|
||||
self.embeddings.fixes = []
|
||||
|
@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
cache = {}
|
||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
batch_multipliers = []
|
||||
for tokens in batch_tokens:
|
||||
tuple_tokens = tuple(tokens)
|
||||
|
||||
if tuple_tokens in cache:
|
||||
remade_tokens, fixes = cache[tuple_tokens]
|
||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
else:
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
mult = 1.0
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
|
@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
possible_matches = self.embeddings.ids_lookup.get(token, None)
|
||||
|
||||
if possible_matches is None:
|
||||
mult_change = self.token_mults.get(token)
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
elif possible_matches is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
else:
|
||||
found = False
|
||||
for ids, word in possible_matches:
|
||||
if tokens[i:i+len(ids)] == ids:
|
||||
fixes.append((len(remade_tokens), word))
|
||||
remade_tokens.append(777)
|
||||
multipliers.append(mult)
|
||||
i += len(ids) - 1
|
||||
found = True
|
||||
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
|
||||
|
@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||
|
||||
if not found:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
|
||||
i += 1
|
||||
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes)
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
self.embeddings.fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
|
||||
outputs = self.wrapped.transformer(input_ids=tokens)
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
|
||||
original_mean = z.mean()
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z *= original_mean / new_mean
|
||||
|
||||
return z
|
||||
|
||||
|
||||
|
@ -562,24 +601,19 @@ class EmbeddingsWithFixes(nn.Module):
|
|||
|
||||
def forward(self, input_ids):
|
||||
batch_fixes = self.embeddings.fixes
|
||||
self.embeddings.fixes = []
|
||||
self.embeddings.fixes = None
|
||||
|
||||
inputs_embeds = self.wrapped(input_ids)
|
||||
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, word in fixes:
|
||||
tensor[offset] = self.embeddings.word_embeddings[word]
|
||||
if batch_fixes is not None:
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, word in fixes:
|
||||
tensor[offset] = self.embeddings.word_embeddings[word]
|
||||
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def get_learned_conditioning_with_embeddings(model, prompts):
|
||||
if os.path.exists(cmd_opts.embeddings_dir):
|
||||
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
|
||||
|
||||
return model.get_learned_conditioning(prompts)
|
||||
|
||||
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
|
@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
|||
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir):
|
||||
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
|
||||
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
|
||||
|
||||
output_images = []
|
||||
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
||||
|
@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
|||
uc = model.get_learned_conditioning(len(prompts) * [""])
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
|
||||
if len(text_inversion_embeddings.used_custom_terms) > 0:
|
||||
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
|
||||
if len(model_hijack.used_custom_terms) > 0:
|
||||
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
|
||||
|
||||
# we manually generate all input noises because each one should have a specific seed
|
||||
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
||||
|
@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
|
|||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
model = (model if cmd_opts.no_half else model.half()).to(device)
|
||||
text_inversion_embeddings = TextInversionEmbeddings()
|
||||
|
||||
if os.path.exists(cmd_opts.embeddings_dir):
|
||||
text_inversion_embeddings.hijack(model)
|
||||
model_hijack = StableDiffuionModelHijack()
|
||||
model_hijack.hijack(model)
|
||||
|
||||
demo = gr.TabbedInterface(
|
||||
interface_list=[x[0] for x in interfaces],
|
||||
|
|
Loading…
Reference in a new issue