new implementation for attention/emphasis
This commit is contained in:
parent
29ce8a687d
commit
c1c27dad3b
3 changed files with 186 additions and 6 deletions
|
@ -126,5 +126,90 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
re_attention = re.compile(r"""
|
||||||
|
\\\(|
|
||||||
|
\\\)|
|
||||||
|
\\\[|
|
||||||
|
\\]|
|
||||||
|
\\\\|
|
||||||
|
\\|
|
||||||
|
\(|
|
||||||
|
\[|
|
||||||
|
:([+-]?[.\d]+)\)|
|
||||||
|
\)|
|
||||||
|
]|
|
||||||
|
[^\\()\[\]:]+|
|
||||||
|
:
|
||||||
|
""", re.X)
|
||||||
|
|
||||||
#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
|
|
||||||
|
def parse_prompt_attention(text):
|
||||||
|
"""
|
||||||
|
Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
|
||||||
|
Accepted tokens are:
|
||||||
|
(abc) - increases attention to abc by a multiplier of 1.1
|
||||||
|
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||||
|
[abc] - decreases attention to abc by a multiplier of 1.1
|
||||||
|
\( - literal character '('
|
||||||
|
\[ - literal character '['
|
||||||
|
\) - literal character ')'
|
||||||
|
\] - literal character ']'
|
||||||
|
\\ - literal character '\'
|
||||||
|
anything else - just text
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
|
||||||
|
|
||||||
|
produces:
|
||||||
|
|
||||||
|
[
|
||||||
|
['a ', 1.0],
|
||||||
|
['house', 1.5730000000000004],
|
||||||
|
[' ', 1.1],
|
||||||
|
['on', 1.0],
|
||||||
|
[' a ', 1.1],
|
||||||
|
['hill', 0.55],
|
||||||
|
[', sun, ', 1.1],
|
||||||
|
['sky', 1.4641000000000006],
|
||||||
|
['.', 1.1]
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = []
|
||||||
|
round_brackets = []
|
||||||
|
square_brackets = []
|
||||||
|
|
||||||
|
round_bracket_multiplier = 1.1
|
||||||
|
square_bracket_multiplier = 1 / 1.1
|
||||||
|
|
||||||
|
def multiply_range(start_position, multiplier):
|
||||||
|
for p in range(start_position, len(res)):
|
||||||
|
res[p][1] *= multiplier
|
||||||
|
|
||||||
|
for m in re_attention.finditer(text):
|
||||||
|
text = m.group(0)
|
||||||
|
weight = m.group(1)
|
||||||
|
|
||||||
|
if text.startswith('\\'):
|
||||||
|
res.append([text[1:], 1.0])
|
||||||
|
elif text == '(':
|
||||||
|
round_brackets.append(len(res))
|
||||||
|
elif text == '[':
|
||||||
|
square_brackets.append(len(res))
|
||||||
|
elif weight is not None and len(round_brackets) > 0:
|
||||||
|
multiply_range(round_brackets.pop(), float(weight))
|
||||||
|
elif text == ')' and len(round_brackets) > 0:
|
||||||
|
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||||
|
elif text == ']' and len(square_brackets) > 0:
|
||||||
|
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||||
|
else:
|
||||||
|
res.append([text, 1.0])
|
||||||
|
|
||||||
|
for pos in round_brackets:
|
||||||
|
multiply_range(pos, round_bracket_multiplier)
|
||||||
|
|
||||||
|
for pos in square_brackets:
|
||||||
|
multiply_range(pos, square_bracket_multiplier)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
|
@ -6,6 +6,7 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
|
from modules import prompt_parser
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
|
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
|
@ -211,6 +212,7 @@ class StableDiffusionModelHijack:
|
||||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
emb = next(iter(param_dict.items()))[1]
|
emb = next(iter(param_dict.items()))[1]
|
||||||
|
# diffuser concepts
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||||
|
|
||||||
|
@ -236,7 +238,7 @@ class StableDiffusionModelHijack:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
|
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
|
@ -275,6 +277,7 @@ class StableDiffusionModelHijack:
|
||||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||||
return remade_batch_tokens[0], token_count, max_length
|
return remade_batch_tokens[0], token_count, max_length
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -300,7 +303,92 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
if mult != 1.0:
|
if mult != 1.0:
|
||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
def process_text(self, text):
|
|
||||||
|
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||||
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
|
maxlen = self.wrapped.max_length
|
||||||
|
|
||||||
|
if opts.enable_emphasis:
|
||||||
|
parsed = prompt_parser.parse_prompt_attention(line)
|
||||||
|
else:
|
||||||
|
parsed = [[line, 1.0]]
|
||||||
|
|
||||||
|
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
|
||||||
|
fixes = []
|
||||||
|
remade_tokens = []
|
||||||
|
multipliers = []
|
||||||
|
|
||||||
|
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
token = tokens[i]
|
||||||
|
|
||||||
|
possible_matches = self.hijack.ids_lookup.get(token, None)
|
||||||
|
|
||||||
|
if possible_matches is None:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(weight)
|
||||||
|
else:
|
||||||
|
found = False
|
||||||
|
for ids, word in possible_matches:
|
||||||
|
if tokens[i:i + len(ids)] == ids:
|
||||||
|
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
||||||
|
fixes.append((len(remade_tokens), word))
|
||||||
|
remade_tokens += [0] * emb_len
|
||||||
|
multipliers += [weight] * emb_len
|
||||||
|
i += len(ids) - 1
|
||||||
|
found = True
|
||||||
|
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(weight)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if len(remade_tokens) > maxlen - 2:
|
||||||
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
ovf = remade_tokens[maxlen - 2:]
|
||||||
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
|
token_count = len(remade_tokens)
|
||||||
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
|
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||||
|
|
||||||
|
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||||
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||||
|
|
||||||
|
return remade_tokens, fixes, multipliers, token_count
|
||||||
|
|
||||||
|
def process_text(self, texts):
|
||||||
|
used_custom_terms = []
|
||||||
|
remade_batch_tokens = []
|
||||||
|
hijack_comments = []
|
||||||
|
hijack_fixes = []
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
batch_multipliers = []
|
||||||
|
for line in texts:
|
||||||
|
if line in cache:
|
||||||
|
remade_tokens, fixes, multipliers = cache[line]
|
||||||
|
else:
|
||||||
|
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||||
|
|
||||||
|
cache[line] = (remade_tokens, fixes, multipliers)
|
||||||
|
|
||||||
|
remade_batch_tokens.append(remade_tokens)
|
||||||
|
hijack_fixes.append(fixes)
|
||||||
|
batch_multipliers.append(multipliers)
|
||||||
|
|
||||||
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
|
|
||||||
|
def process_text_old(self, text):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
maxlen = self.wrapped.max_length
|
maxlen = self.wrapped.max_length
|
||||||
|
@ -376,12 +464,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
|
||||||
|
if opts.use_old_emphasis_implementation:
|
||||||
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||||
|
else:
|
||||||
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||||
|
|
||||||
|
|
||||||
self.hijack.fixes = hijack_fixes
|
self.hijack.fixes = hijack_fixes
|
||||||
self.hijack.comments = hijack_comments
|
self.hijack.comments = hijack_comments
|
||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
if len(used_custom_terms) > 0:
|
||||||
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens)
|
outputs = self.wrapped.transformer(input_ids=tokens)
|
||||||
|
|
|
@ -195,7 +195,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||||
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||||
|
|
Loading…
Reference in a new issue