validate textual inversion embeddings
This commit is contained in:
parent
f34c734172
commit
f55ac33d44
3 changed files with 41 additions and 7 deletions
|
@ -325,6 +325,9 @@ def load_model(checkpoint_info=None):
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
print("Model loaded.")
|
print("Model loaded.")
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,8 @@ class Embedding:
|
||||||
self.vec = vec
|
self.vec = vec
|
||||||
self.name = name
|
self.name = name
|
||||||
self.step = step
|
self.step = step
|
||||||
|
self.shape = None
|
||||||
|
self.vectors = 0
|
||||||
self.cached_checksum = None
|
self.cached_checksum = None
|
||||||
self.sd_checkpoint = None
|
self.sd_checkpoint = None
|
||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
|
@ -57,8 +59,10 @@ class EmbeddingDatabase:
|
||||||
def __init__(self, embeddings_dir):
|
def __init__(self, embeddings_dir):
|
||||||
self.ids_lookup = {}
|
self.ids_lookup = {}
|
||||||
self.word_embeddings = {}
|
self.word_embeddings = {}
|
||||||
|
self.skipped_embeddings = []
|
||||||
self.dir_mtime = None
|
self.dir_mtime = None
|
||||||
self.embeddings_dir = embeddings_dir
|
self.embeddings_dir = embeddings_dir
|
||||||
|
self.expected_shape = -1
|
||||||
|
|
||||||
def register_embedding(self, embedding, model):
|
def register_embedding(self, embedding, model):
|
||||||
|
|
||||||
|
@ -75,14 +79,35 @@ class EmbeddingDatabase:
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self):
|
def get_expected_shape(self):
|
||||||
|
expected_shape = -1 # initialize with unknown
|
||||||
|
idx = torch.tensor(0).to(shared.device)
|
||||||
|
if expected_shape == -1:
|
||||||
|
try: # matches sd15 signature
|
||||||
|
first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
|
||||||
|
expected_shape = first_embedding.shape[0]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if expected_shape == -1:
|
||||||
|
try: # matches sd20 signature
|
||||||
|
first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
|
||||||
|
expected_shape = first_embedding.shape[0]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if expected_shape == -1:
|
||||||
|
print('Could not determine expected embeddings shape from model')
|
||||||
|
return expected_shape
|
||||||
|
|
||||||
|
def load_textual_inversion_embeddings(self, force_reload = False):
|
||||||
mt = os.path.getmtime(self.embeddings_dir)
|
mt = os.path.getmtime(self.embeddings_dir)
|
||||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.dir_mtime = mt
|
self.dir_mtime = mt
|
||||||
self.ids_lookup.clear()
|
self.ids_lookup.clear()
|
||||||
self.word_embeddings.clear()
|
self.word_embeddings.clear()
|
||||||
|
self.skipped_embeddings = []
|
||||||
|
self.expected_shape = self.get_expected_shape()
|
||||||
|
|
||||||
def process_file(path, filename):
|
def process_file(path, filename):
|
||||||
name = os.path.splitext(filename)[0]
|
name = os.path.splitext(filename)[0]
|
||||||
|
@ -122,7 +147,14 @@ class EmbeddingDatabase:
|
||||||
embedding.step = data.get('step', None)
|
embedding.step = data.get('step', None)
|
||||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
embedding.vectors = vec.shape[0]
|
||||||
|
embedding.shape = vec.shape[-1]
|
||||||
|
|
||||||
|
if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
|
||||||
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
|
else:
|
||||||
|
self.skipped_embeddings.append(name)
|
||||||
|
# print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
|
||||||
|
|
||||||
for fn in os.listdir(self.embeddings_dir):
|
for fn in os.listdir(self.embeddings_dir):
|
||||||
try:
|
try:
|
||||||
|
@ -137,8 +169,9 @@ class EmbeddingDatabase:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
|
||||||
print("Embeddings:", ', '.join(self.word_embeddings.keys()))
|
if (len(self.skipped_embeddings) > 0):
|
||||||
|
print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
|
||||||
|
|
||||||
def find_embedding_at_position(self, tokens, offset):
|
def find_embedding_at_position(self, tokens, offset):
|
||||||
token = tokens[offset]
|
token = tokens[offset]
|
||||||
|
|
|
@ -1157,8 +1157,6 @@ def create_ui():
|
||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as train_interface:
|
with gr.Blocks(analytics_enabled=False) as train_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||||
|
|
Loading…
Reference in a new issue