more comments
parent 08066676a4
commit 1740c33547
1 changed file with 16 additions and 5 deletions
@@ -3,7 +3,7 @@ from collections import namedtuple
 
 import torch
 
-from modules import prompt_parser, devices
+from modules import prompt_parser, devices, sd_hijack
 from modules.shared import opts
 
 
@@ -22,14 +22,24 @@ class PromptChunk:
 
 
 PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
-"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk"""
+"""An object of this type is a marker showing that a textual inversion embedding's vectors have to be placed at offset in the prompt
+chunk. These objects are found in PromptChunk.fixes, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
+are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
 
 
 class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+    """A PyTorch module that wraps the FrozenCLIPEmbedder module. It enhances FrozenCLIPEmbedder, making it possible to
+    have unlimited prompt length and to assign weights to tokens in the prompt.
+    """
+
     def __init__(self, wrapped, hijack):
         super().__init__()
+
         self.wrapped = wrapped
-        self.hijack = hijack
+        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+        depending on the model."""
+
+        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
         self.chunk_length = 75
 
     def empty_chunk(self):
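For orientation, the following is a minimal sketch of how a PromptChunkFix could be consumed, assuming a made-up FakeEmbedding holder and an apply_fixes_sketch helper; it only illustrates the offset/embedding idea from the new docstring and is not the actual sd_hijack.EmbeddingsWithFixes.forward code.

import torch
from collections import namedtuple

PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])

# FakeEmbedding is a stand-in for a textual inversion embedding: just a bag of vectors.
FakeEmbedding = namedtuple('FakeEmbedding', ['vec'])

def apply_fixes_sketch(token_embeddings, fixes):
    # Copy each embedding's vectors into one chunk's token embeddings at the stored offset.
    # token_embeddings: (T, C) tensor for a single prompt chunk; fixes: list of PromptChunkFix.
    for offset, embedding in fixes:
        vec = embedding.vec                                   # (num_vectors, C)
        token_embeddings[offset:offset + vec.shape[0]] = vec
    return token_embeddings

emb = FakeEmbedding(vec=torch.ones(2, 768))                   # a two-vector embedding
chunk_emb = torch.zeros(77, 768)                              # embeddings for one 77-token chunk
apply_fixes_sketch(chunk_emb, [PromptChunkFix(offset=5, embedding=emb)])
print(chunk_emb[5:7].sum())                                   # tensor(1536.) -- vectors written at offset 5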
@@ -55,7 +65,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
         All python lists with tokens are assumed to have the same length, usually 77.
         if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
-        model - can be 768 and 1024
+        model - can be 768 or 1024.
+        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
         """
 
         raise NotImplementedError
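To make the (B, T, C) contract above concrete, here is a throwaway sketch that swaps in a plain embedding table for the real transformer; the vocabulary size and token ids are assumptions chosen for the example, and encode_with_transformers_sketch is a hypothetical name, not part of the module.

import torch

def encode_with_transformers_sketch(tokens, embedding_dim=768):
    # Shape-contract illustration only: B python lists of T token ids in, (B, T, C) tensor out.
    ids = torch.tensor(tokens, dtype=torch.long)              # (B, T), here T == 77
    lookup = torch.nn.Embedding(49408, embedding_dim)         # placeholder for the wrapped transformer
    return lookup(ids)                                        # (B, T, C)

batch = [[49406] + [320] * 75 + [49407]] * 2                  # two 77-token chunks of assumed ids
print(encode_with_transformers_sketch(batch).shape)           # torch.Size([2, 77, 768])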
@@ -113,7 +124,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
                     last_comma = len(chunk.tokens)
 
                 # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
-                # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next.
+                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
                 elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
                     break_location = last_comma + 1
 
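The comma backtrack behaviour described in those comments can be restated in isolation. The helper below is an assumed simplification (comma_break_point is a hypothetical name, and the real tokenize_line tracks comma positions while it builds chunks), shown only to illustrate the rule.

def comma_break_point(token_texts, chunk_length=75, comma_padding_backtrack=20):
    # Return how many tokens to keep in the current chunk: if the chunk is full and a comma sits
    # within comma_padding_backtrack tokens of the end, break just after that comma so the
    # trailing phrase moves whole into the next chunk.
    last_comma = -1
    for i, text in enumerate(token_texts[:chunk_length]):
        if text == ',':
            last_comma = i
    if comma_padding_backtrack != 0 and last_comma != -1 and chunk_length - last_comma <= comma_padding_backtrack:
        return last_comma + 1
    return chunk_length

tokens = ['a'] * 70 + [','] + ['b'] * 10                      # comma 5 tokens before the 75-token limit
print(comma_break_point(tokens))                              # 71 -- the 'b' tokens defer to the next chunk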