Merge branch 'clip_hijack_rework'
This commit is contained in:
commit
c4a221c405
5 changed files with 259 additions and 174 deletions
|
@ -150,10 +150,10 @@ class StableDiffusionModelHijack:
|
||||||
def clear_comments(self):
|
def clear_comments(self):
|
||||||
self.comments = []
|
self.comments = []
|
||||||
|
|
||||||
def tokenize(self, text):
|
def get_prompt_lengths(self, text):
|
||||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
_, token_count = self.clip.process_texts([text])
|
||||||
|
|
||||||
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsWithFixes(torch.nn.Module):
|
class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
|
|
|
@ -1,30 +1,89 @@
|
||||||
import math
|
import math
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import prompt_parser, devices
|
from modules import prompt_parser, devices, sd_hijack
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
def get_target_prompt_token_count(token_count):
|
|
||||||
return math.ceil(max(token_count, 1) / 75) * 75
|
class PromptChunk:
|
||||||
|
"""
|
||||||
|
This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
|
||||||
|
If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
|
||||||
|
Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
|
||||||
|
so just 75 tokens from prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.tokens = []
|
||||||
|
self.multipliers = []
|
||||||
|
self.fixes = []
|
||||||
|
|
||||||
|
|
||||||
|
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||||
|
"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
|
||||||
|
chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
|
||||||
|
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
|
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
|
||||||
|
have unlimited prompt length and assign weights to tokens in prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.hijack = hijack
|
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
|
||||||
|
depending on model."""
|
||||||
|
|
||||||
|
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
||||||
|
self.chunk_length = 75
|
||||||
|
|
||||||
|
def empty_chunk(self):
|
||||||
|
"""creates an empty PromptChunk and returns it"""
|
||||||
|
|
||||||
|
chunk = PromptChunk()
|
||||||
|
chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
|
||||||
|
chunk.multipliers = [1.0] * (self.chunk_length + 2)
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
def get_target_prompt_token_count(self, token_count):
|
||||||
|
"""returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
|
||||||
|
|
||||||
|
return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
|
||||||
|
|
||||||
def tokenize(self, texts):
|
def tokenize(self, texts):
|
||||||
|
"""Converts a batch of texts into a batch of token ids"""
|
||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def encode_with_transformers(self, tokens):
|
def encode_with_transformers(self, tokens):
|
||||||
|
"""
|
||||||
|
converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
|
||||||
|
All python lists with tokens are assumed to have same length, usually 77.
|
||||||
|
if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
|
||||||
|
model - can be 768 and 1024.
|
||||||
|
Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
|
||||||
|
"""
|
||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def encode_embedding_init_text(self, init_text, nvpt):
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
"""Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
|
||||||
|
transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
|
||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
def tokenize_line(self, line):
|
||||||
|
"""
|
||||||
|
this transforms a single prompt into a list of PromptChunk objects - as many as needed to
|
||||||
|
represent the prompt.
|
||||||
|
Returns the list and the total number of tokens in the prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
if opts.enable_emphasis:
|
if opts.enable_emphasis:
|
||||||
parsed = prompt_parser.parse_prompt_attention(line)
|
parsed = prompt_parser.parse_prompt_attention(line)
|
||||||
else:
|
else:
|
||||||
|
@ -32,205 +91,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
|
|
||||||
tokenized = self.tokenize([text for text, _ in parsed])
|
tokenized = self.tokenize([text for text, _ in parsed])
|
||||||
|
|
||||||
fixes = []
|
chunks = []
|
||||||
remade_tokens = []
|
chunk = PromptChunk()
|
||||||
multipliers = []
|
token_count = 0
|
||||||
last_comma = -1
|
last_comma = -1
|
||||||
|
|
||||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
def next_chunk():
|
||||||
i = 0
|
"""puts current chunk into the list of results and produces the next one - empty"""
|
||||||
while i < len(tokens):
|
nonlocal token_count
|
||||||
token = tokens[i]
|
nonlocal last_comma
|
||||||
|
nonlocal chunk
|
||||||
|
|
||||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
token_count += len(chunk.tokens)
|
||||||
|
to_add = self.chunk_length - len(chunk.tokens)
|
||||||
|
if to_add > 0:
|
||||||
|
chunk.tokens += [self.id_end] * to_add
|
||||||
|
chunk.multipliers += [1.0] * to_add
|
||||||
|
|
||||||
|
chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
|
||||||
|
chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
|
||||||
|
|
||||||
|
last_comma = -1
|
||||||
|
chunks.append(chunk)
|
||||||
|
chunk = PromptChunk()
|
||||||
|
|
||||||
|
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||||
|
position = 0
|
||||||
|
while position < len(tokens):
|
||||||
|
token = tokens[position]
|
||||||
|
|
||||||
if token == self.comma_token:
|
if token == self.comma_token:
|
||||||
last_comma = len(remade_tokens)
|
last_comma = len(chunk.tokens)
|
||||||
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
|
||||||
last_comma += 1
|
|
||||||
reloc_tokens = remade_tokens[last_comma:]
|
|
||||||
reloc_mults = multipliers[last_comma:]
|
|
||||||
|
|
||||||
remade_tokens = remade_tokens[:last_comma]
|
# this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
||||||
length = len(remade_tokens)
|
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
|
||||||
|
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||||
|
break_location = last_comma + 1
|
||||||
|
|
||||||
rem = int(math.ceil(length / 75)) * 75 - length
|
reloc_tokens = chunk.tokens[break_location:]
|
||||||
remade_tokens += [self.id_end] * rem + reloc_tokens
|
reloc_mults = chunk.multipliers[break_location:]
|
||||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
|
||||||
|
|
||||||
|
chunk.tokens = chunk.tokens[:break_location]
|
||||||
|
chunk.multipliers = chunk.multipliers[:break_location]
|
||||||
|
|
||||||
|
next_chunk()
|
||||||
|
chunk.tokens = reloc_tokens
|
||||||
|
chunk.multipliers = reloc_mults
|
||||||
|
|
||||||
|
if len(chunk.tokens) == self.chunk_length:
|
||||||
|
next_chunk()
|
||||||
|
|
||||||
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
|
||||||
if embedding is None:
|
if embedding is None:
|
||||||
remade_tokens.append(token)
|
chunk.tokens.append(token)
|
||||||
multipliers.append(weight)
|
chunk.multipliers.append(weight)
|
||||||
i += 1
|
position += 1
|
||||||
else:
|
continue
|
||||||
emb_len = int(embedding.vec.shape[0])
|
|
||||||
iteration = len(remade_tokens) // 75
|
|
||||||
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
|
||||||
rem = (75 * (iteration + 1) - len(remade_tokens))
|
|
||||||
remade_tokens += [self.id_end] * rem
|
|
||||||
multipliers += [1.0] * rem
|
|
||||||
iteration += 1
|
|
||||||
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
|
||||||
remade_tokens += [0] * emb_len
|
|
||||||
multipliers += [weight] * emb_len
|
|
||||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
|
||||||
i += embedding_length_in_tokens
|
|
||||||
|
|
||||||
token_count = len(remade_tokens)
|
emb_len = int(embedding.vec.shape[0])
|
||||||
prompt_target_length = get_target_prompt_token_count(token_count)
|
if len(chunk.tokens) + emb_len > self.chunk_length:
|
||||||
tokens_to_add = prompt_target_length - len(remade_tokens)
|
next_chunk()
|
||||||
|
|
||||||
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
|
chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
|
||||||
multipliers = multipliers + [1.0] * tokens_to_add
|
|
||||||
|
|
||||||
return remade_tokens, fixes, multipliers, token_count
|
chunk.tokens += [0] * emb_len
|
||||||
|
chunk.multipliers += [weight] * emb_len
|
||||||
|
position += embedding_length_in_tokens
|
||||||
|
|
||||||
|
if len(chunk.tokens) > 0 or len(chunks) == 0:
|
||||||
|
next_chunk()
|
||||||
|
|
||||||
|
return chunks, token_count
|
||||||
|
|
||||||
|
def process_texts(self, texts):
|
||||||
|
"""
|
||||||
|
Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
|
||||||
|
length, in tokens, of all texts.
|
||||||
|
"""
|
||||||
|
|
||||||
def process_text(self, texts):
|
|
||||||
used_custom_terms = []
|
|
||||||
remade_batch_tokens = []
|
|
||||||
hijack_comments = []
|
|
||||||
hijack_fixes = []
|
|
||||||
token_count = 0
|
token_count = 0
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
batch_multipliers = []
|
batch_chunks = []
|
||||||
for line in texts:
|
for line in texts:
|
||||||
if line in cache:
|
if line in cache:
|
||||||
remade_tokens, fixes, multipliers = cache[line]
|
chunks = cache[line]
|
||||||
else:
|
else:
|
||||||
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
chunks, current_token_count = self.tokenize_line(line)
|
||||||
token_count = max(current_token_count, token_count)
|
token_count = max(current_token_count, token_count)
|
||||||
|
|
||||||
cache[line] = (remade_tokens, fixes, multipliers)
|
cache[line] = chunks
|
||||||
|
|
||||||
remade_batch_tokens.append(remade_tokens)
|
batch_chunks.append(chunks)
|
||||||
hijack_fixes.append(fixes)
|
|
||||||
batch_multipliers.append(multipliers)
|
|
||||||
|
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
return batch_chunks, token_count
|
||||||
|
|
||||||
def process_text_old(self, texts):
|
def forward(self, texts):
|
||||||
id_start = self.id_start
|
"""
|
||||||
id_end = self.id_end
|
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
|
||||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
|
||||||
used_custom_terms = []
|
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
|
||||||
remade_batch_tokens = []
|
An example shape returned by this function can be: (2, 77, 768).
|
||||||
hijack_comments = []
|
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
|
||||||
hijack_fixes = []
|
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||||
token_count = 0
|
"""
|
||||||
|
|
||||||
cache = {}
|
if opts.use_old_emphasis_implementation:
|
||||||
batch_tokens = self.tokenize(texts)
|
import modules.sd_hijack_clip_old
|
||||||
batch_multipliers = []
|
return modules.sd_hijack_clip_old.forward_old(self, texts)
|
||||||
for tokens in batch_tokens:
|
|
||||||
tuple_tokens = tuple(tokens)
|
|
||||||
|
|
||||||
if tuple_tokens in cache:
|
batch_chunks, token_count = self.process_texts(texts)
|
||||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
|
||||||
else:
|
|
||||||
fixes = []
|
|
||||||
remade_tokens = []
|
|
||||||
multipliers = []
|
|
||||||
mult = 1.0
|
|
||||||
|
|
||||||
i = 0
|
used_embeddings = {}
|
||||||
while i < len(tokens):
|
chunk_count = max([len(x) for x in batch_chunks])
|
||||||
token = tokens[i]
|
|
||||||
|
|
||||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
zs = []
|
||||||
|
for i in range(chunk_count):
|
||||||
|
batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
|
||||||
|
|
||||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
tokens = [x.tokens for x in batch_chunk]
|
||||||
if mult_change is not None:
|
multipliers = [x.multipliers for x in batch_chunk]
|
||||||
mult *= mult_change
|
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||||
i += 1
|
|
||||||
elif embedding is None:
|
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(mult)
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
emb_len = int(embedding.vec.shape[0])
|
|
||||||
fixes.append((len(remade_tokens), embedding))
|
|
||||||
remade_tokens += [0] * emb_len
|
|
||||||
multipliers += [mult] * emb_len
|
|
||||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
|
||||||
i += embedding_length_in_tokens
|
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
for fixes in self.hijack.fixes:
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
for position, embedding in fixes:
|
||||||
ovf = remade_tokens[maxlen - 2:]
|
used_embeddings[embedding.name] = embedding
|
||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
|
||||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
|
||||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
|
||||||
|
|
||||||
token_count = len(remade_tokens)
|
z = self.process_tokens(tokens, multipliers)
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
zs.append(z)
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
|
||||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
|
||||||
|
|
||||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
if len(used_embeddings) > 0:
|
||||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
|
||||||
|
self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
|
||||||
|
|
||||||
remade_batch_tokens.append(remade_tokens)
|
return torch.hstack(zs)
|
||||||
hijack_fixes.append(fixes)
|
|
||||||
batch_multipliers.append(multipliers)
|
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
use_old = opts.use_old_emphasis_implementation
|
|
||||||
if use_old:
|
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
|
||||||
else:
|
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
|
||||||
|
|
||||||
self.hijack.comments += hijack_comments
|
|
||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
|
||||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
|
||||||
|
|
||||||
if use_old:
|
|
||||||
self.hijack.fixes = hijack_fixes
|
|
||||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
|
||||||
|
|
||||||
z = None
|
|
||||||
i = 0
|
|
||||||
while max(map(len, remade_batch_tokens)) != 0:
|
|
||||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
|
||||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
|
||||||
|
|
||||||
self.hijack.fixes = []
|
|
||||||
for unfiltered in hijack_fixes:
|
|
||||||
fixes = []
|
|
||||||
for fix in unfiltered:
|
|
||||||
if fix[0] == i:
|
|
||||||
fixes.append(fix[1])
|
|
||||||
self.hijack.fixes.append(fixes)
|
|
||||||
|
|
||||||
tokens = []
|
|
||||||
multipliers = []
|
|
||||||
for j in range(len(remade_batch_tokens)):
|
|
||||||
if len(remade_batch_tokens[j]) > 0:
|
|
||||||
tokens.append(remade_batch_tokens[j][:75])
|
|
||||||
multipliers.append(batch_multipliers[j][:75])
|
|
||||||
else:
|
|
||||||
tokens.append([self.id_end] * 75)
|
|
||||||
multipliers.append([1.0] * 75)
|
|
||||||
|
|
||||||
z1 = self.process_tokens(tokens, multipliers)
|
|
||||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
|
||||||
|
|
||||||
remade_batch_tokens = rem_tokens
|
|
||||||
batch_multipliers = rem_multipliers
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return z
|
|
||||||
|
|
||||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||||
if not opts.use_old_emphasis_implementation:
|
"""
|
||||||
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
|
sends one single prompt chunk to be encoded by transformers neural network.
|
||||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
|
||||||
|
there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
|
||||||
|
Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
|
||||||
|
corresponds to one token.
|
||||||
|
"""
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
|
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
|
||||||
|
|
||||||
|
# this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
|
||||||
if self.id_end != self.id_pad:
|
if self.id_end != self.id_pad:
|
||||||
for batch_pos in range(len(remade_batch_tokens)):
|
for batch_pos in range(len(remade_batch_tokens)):
|
||||||
index = remade_batch_tokens[batch_pos].index(self.id_end)
|
index = remade_batch_tokens[batch_pos].index(self.id_end)
|
||||||
|
@ -239,8 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
z = self.encode_with_transformers(tokens)
|
z = self.encode_with_transformers(tokens)
|
||||||
|
|
||||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
|
|
||||||
original_mean = z.mean()
|
original_mean = z.mean()
|
||||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||||
new_mean = z.mean()
|
new_mean = z.mean()
|
||||||
|
|
81
modules/sd_hijack_clip_old.py
Normal file
81
modules/sd_hijack_clip_old.py
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
from modules import sd_hijack_clip
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
|
||||||
|
def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
|
||||||
|
id_start = self.id_start
|
||||||
|
id_end = self.id_end
|
||||||
|
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||||
|
used_custom_terms = []
|
||||||
|
remade_batch_tokens = []
|
||||||
|
hijack_comments = []
|
||||||
|
hijack_fixes = []
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
batch_tokens = self.tokenize(texts)
|
||||||
|
batch_multipliers = []
|
||||||
|
for tokens in batch_tokens:
|
||||||
|
tuple_tokens = tuple(tokens)
|
||||||
|
|
||||||
|
if tuple_tokens in cache:
|
||||||
|
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||||
|
else:
|
||||||
|
fixes = []
|
||||||
|
remade_tokens = []
|
||||||
|
multipliers = []
|
||||||
|
mult = 1.0
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
token = tokens[i]
|
||||||
|
|
||||||
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
|
mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
|
||||||
|
if mult_change is not None:
|
||||||
|
mult *= mult_change
|
||||||
|
i += 1
|
||||||
|
elif embedding is None:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(mult)
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
emb_len = int(embedding.vec.shape[0])
|
||||||
|
fixes.append((len(remade_tokens), embedding))
|
||||||
|
remade_tokens += [0] * emb_len
|
||||||
|
multipliers += [mult] * emb_len
|
||||||
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
|
i += embedding_length_in_tokens
|
||||||
|
|
||||||
|
if len(remade_tokens) > maxlen - 2:
|
||||||
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
ovf = remade_tokens[maxlen - 2:]
|
||||||
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
|
token_count = len(remade_tokens)
|
||||||
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
|
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||||
|
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||||
|
|
||||||
|
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||||
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||||
|
|
||||||
|
remade_batch_tokens.append(remade_tokens)
|
||||||
|
hijack_fixes.append(fixes)
|
||||||
|
batch_multipliers.append(multipliers)
|
||||||
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
|
|
||||||
|
def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
|
||||||
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
|
||||||
|
|
||||||
|
self.hijack.comments += hijack_comments
|
||||||
|
|
||||||
|
if len(used_custom_terms) > 0:
|
||||||
|
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
|
||||||
|
self.hijack.fixes = hijack_fixes
|
||||||
|
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
|
@ -79,7 +79,6 @@ class EmbeddingDatabase:
|
||||||
|
|
||||||
self.word_embeddings[embedding.name] = embedding
|
self.word_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
# TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
|
|
||||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
||||||
|
|
||||||
first_id = ids[0]
|
first_id = ids[0]
|
||||||
|
|
|
@ -368,7 +368,7 @@ def update_token_counter(text, steps):
|
||||||
|
|
||||||
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
|
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
|
||||||
prompts = [prompt_text for step, prompt_text in flat_prompts]
|
prompts = [prompt_text for step, prompt_text in flat_prompts]
|
||||||
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
|
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
|
||||||
style_class = ' class="red"' if (token_count > max_length) else ""
|
style_class = ' class="red"' if (token_count > max_length) else ""
|
||||||
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue