add BREAK keyword to end current text chunk and start the next
This commit is contained in:
parent
205991df78
commit
8e2aeee4a1
2 changed files with 19 additions and 5 deletions
|
@ -274,6 +274,7 @@ re_attention = re.compile(r"""
|
||||||
:
|
:
|
||||||
""", re.X)
|
""", re.X)
|
||||||
|
|
||||||
|
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
|
||||||
|
|
||||||
def parse_prompt_attention(text):
|
def parse_prompt_attention(text):
|
||||||
"""
|
"""
|
||||||
|
@ -339,7 +340,11 @@ def parse_prompt_attention(text):
|
||||||
elif text == ']' and len(square_brackets) > 0:
|
elif text == ']' and len(square_brackets) > 0:
|
||||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||||
else:
|
else:
|
||||||
res.append([text, 1.0])
|
parts = re.split(re_break, text)
|
||||||
|
for i, part in enumerate(parts):
|
||||||
|
if i > 0:
|
||||||
|
res.append(["BREAK", -1])
|
||||||
|
res.append([part, 1.0])
|
||||||
|
|
||||||
for pos in round_brackets:
|
for pos in round_brackets:
|
||||||
multiply_range(pos, round_bracket_multiplier)
|
multiply_range(pos, round_bracket_multiplier)
|
||||||
|
|
|
@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
token_count = 0
|
token_count = 0
|
||||||
last_comma = -1
|
last_comma = -1
|
||||||
|
|
||||||
def next_chunk():
|
def next_chunk(is_last=False):
|
||||||
"""puts current chunk into the list of results and produces the next one - empty"""
|
"""puts current chunk into the list of results and produces the next one - empty;
|
||||||
|
if is_last is true, <end-of-text> tokens at the end won't be added to token_count"""
|
||||||
nonlocal token_count
|
nonlocal token_count
|
||||||
nonlocal last_comma
|
nonlocal last_comma
|
||||||
nonlocal chunk
|
nonlocal chunk
|
||||||
|
|
||||||
|
if is_last:
|
||||||
token_count += len(chunk.tokens)
|
token_count += len(chunk.tokens)
|
||||||
|
else:
|
||||||
|
token_count += self.chunk_length
|
||||||
|
|
||||||
to_add = self.chunk_length - len(chunk.tokens)
|
to_add = self.chunk_length - len(chunk.tokens)
|
||||||
if to_add > 0:
|
if to_add > 0:
|
||||||
chunk.tokens += [self.id_end] * to_add
|
chunk.tokens += [self.id_end] * to_add
|
||||||
|
@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
chunk = PromptChunk()
|
chunk = PromptChunk()
|
||||||
|
|
||||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||||
|
if text == 'BREAK' and weight == -1:
|
||||||
|
next_chunk()
|
||||||
|
continue
|
||||||
|
|
||||||
position = 0
|
position = 0
|
||||||
while position < len(tokens):
|
while position < len(tokens):
|
||||||
token = tokens[position]
|
token = tokens[position]
|
||||||
|
@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
position += embedding_length_in_tokens
|
position += embedding_length_in_tokens
|
||||||
|
|
||||||
if len(chunk.tokens) > 0 or len(chunks) == 0:
|
if len(chunk.tokens) > 0 or len(chunks) == 0:
|
||||||
next_chunk()
|
next_chunk(is_last=True)
|
||||||
|
|
||||||
return chunks, token_count
|
return chunks, token_count
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue