add BREAK keyword to end current text chunk and start the next

This commit is contained in:
AUTOMATIC 2023-01-15 22:29:53 +03:00
parent 205991df78
commit 8e2aeee4a1
2 changed files with 19 additions and 5 deletions

View file

@@ -274,6 +274,7 @@ re_attention = re.compile(r"""
: :
""", re.X) """, re.X)
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
def parse_prompt_attention(text): def parse_prompt_attention(text):
""" """
@@ -339,7 +340,11 @@ def parse_prompt_attention(text):
elif text == ']' and len(square_brackets) > 0: elif text == ']' and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier) multiply_range(square_brackets.pop(), square_bracket_multiplier)
else: else:
res.append([text, 1.0]) parts = re.split(re_break, text)
for i, part in enumerate(parts):
if i > 0:
res.append(["BREAK", -1])
res.append([part, 1.0])
for pos in round_brackets: for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier) multiply_range(pos, round_bracket_multiplier)

View file

@@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
token_count = 0 token_count = 0
last_comma = -1 last_comma = -1
def next_chunk(): def next_chunk(is_last=False):
"""puts current chunk into the list of results and produces the next one - empty""" """puts current chunk into the list of results and produces the next one - empty;
if is_last is true, <end-of-text> tokens at the end won't add to token_count"""
nonlocal token_count nonlocal token_count
nonlocal last_comma nonlocal last_comma
nonlocal chunk nonlocal chunk
token_count += len(chunk.tokens) if is_last:
token_count += len(chunk.tokens)
else:
token_count += self.chunk_length
to_add = self.chunk_length - len(chunk.tokens) to_add = self.chunk_length - len(chunk.tokens)
if to_add > 0: if to_add > 0:
chunk.tokens += [self.id_end] * to_add chunk.tokens += [self.id_end] * to_add
@@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk = PromptChunk() chunk = PromptChunk()
for tokens, (text, weight) in zip(tokenized, parsed): for tokens, (text, weight) in zip(tokenized, parsed):
if text == 'BREAK' and weight == -1:
next_chunk()
continue
position = 0 position = 0
while position < len(tokens): while position < len(tokens):
token = tokens[position] token = tokens[position]
@@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += embedding_length_in_tokens position += embedding_length_in_tokens
if len(chunk.tokens) > 0 or len(chunks) == 0: if len(chunk.tokens) > 0 or len(chunks) == 0:
next_chunk() next_chunk(is_last=True)
return chunks, token_count return chunks, token_count