make it possible for extensions/scripts to add their own embedding directories
This commit is contained in:
parent
a0c87f1fdf
commit
085427de0e
2 changed files with 106 additions and 67 deletions
|
@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
|
||||||
clip = None
|
clip = None
|
||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
|
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||||
|
@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
|
||||||
self.layers = flatten(m)
|
self.layers = flatten(m)
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
|
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
|
|
|
@ -66,17 +66,41 @@ class Embedding:
|
||||||
return self.cached_checksum
|
return self.cached_checksum
|
||||||
|
|
||||||
|
|
||||||
|
class DirWithTextualInversionEmbeddings:
|
||||||
|
def __init__(self, path):
|
||||||
|
self.path = path
|
||||||
|
self.mtime = None
|
||||||
|
|
||||||
|
def has_changed(self):
|
||||||
|
if not os.path.isdir(self.path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
mt = os.path.getmtime(self.path)
|
||||||
|
if self.mtime is None or mt > self.mtime:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
if not os.path.isdir(self.path):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.mtime = os.path.getmtime(self.path)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingDatabase:
|
class EmbeddingDatabase:
|
||||||
def __init__(self, embeddings_dir):
|
def __init__(self):
|
||||||
self.ids_lookup = {}
|
self.ids_lookup = {}
|
||||||
self.word_embeddings = {}
|
self.word_embeddings = {}
|
||||||
self.skipped_embeddings = {}
|
self.skipped_embeddings = {}
|
||||||
self.dir_mtime = None
|
|
||||||
self.embeddings_dir = embeddings_dir
|
|
||||||
self.expected_shape = -1
|
self.expected_shape = -1
|
||||||
|
self.embedding_dirs = {}
|
||||||
|
|
||||||
|
def add_embedding_dir(self, path):
|
||||||
|
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||||
|
|
||||||
|
def clear_embedding_dirs(self):
|
||||||
|
self.embedding_dirs.clear()
|
||||||
|
|
||||||
def register_embedding(self, embedding, model):
|
def register_embedding(self, embedding, model):
|
||||||
|
|
||||||
self.word_embeddings[embedding.name] = embedding
|
self.word_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
||||||
|
@ -93,18 +117,7 @@ class EmbeddingDatabase:
|
||||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||||
return vec.shape[1]
|
return vec.shape[1]
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, force_reload = False):
|
def load_from_file(self, path, filename):
|
||||||
mt = os.path.getmtime(self.embeddings_dir)
|
|
||||||
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.dir_mtime = mt
|
|
||||||
self.ids_lookup.clear()
|
|
||||||
self.word_embeddings.clear()
|
|
||||||
self.skipped_embeddings.clear()
|
|
||||||
self.expected_shape = self.get_expected_shape()
|
|
||||||
|
|
||||||
def process_file(path, filename):
|
|
||||||
name, ext = os.path.splitext(filename)
|
name, ext = os.path.splitext(filename)
|
||||||
ext = ext.upper()
|
ext = ext.upper()
|
||||||
|
|
||||||
|
@ -155,7 +168,11 @@ class EmbeddingDatabase:
|
||||||
else:
|
else:
|
||||||
self.skipped_embeddings[name] = embedding
|
self.skipped_embeddings[name] = embedding
|
||||||
|
|
||||||
for root, dirs, fns in os.walk(self.embeddings_dir):
|
def load_from_dir(self, embdir):
|
||||||
|
if not os.path.isdir(embdir.path):
|
||||||
|
return
|
||||||
|
|
||||||
|
for root, dirs, fns in os.walk(embdir.path):
|
||||||
for fn in fns:
|
for fn in fns:
|
||||||
try:
|
try:
|
||||||
fullfn = os.path.join(root, fn)
|
fullfn = os.path.join(root, fn)
|
||||||
|
@ -163,12 +180,32 @@ class EmbeddingDatabase:
|
||||||
if os.stat(fullfn).st_size == 0:
|
if os.stat(fullfn).st_size == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
process_file(fullfn, fn)
|
self.load_from_file(fullfn, fn)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
def load_textual_inversion_embeddings(self, force_reload=False):
|
||||||
|
if not force_reload:
|
||||||
|
need_reload = False
|
||||||
|
for path, embdir in self.embedding_dirs.items():
|
||||||
|
if embdir.has_changed():
|
||||||
|
need_reload = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not need_reload:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.ids_lookup.clear()
|
||||||
|
self.word_embeddings.clear()
|
||||||
|
self.skipped_embeddings.clear()
|
||||||
|
self.expected_shape = self.get_expected_shape()
|
||||||
|
|
||||||
|
for path, embdir in self.embedding_dirs.items():
|
||||||
|
self.load_from_dir(embdir)
|
||||||
|
embdir.update()
|
||||||
|
|
||||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||||
if len(self.skipped_embeddings) > 0:
|
if len(self.skipped_embeddings) > 0:
|
||||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||||
|
@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||||
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
||||||
assert steps, "Max steps is empty or 0"
|
assert steps, "Max steps is empty or 0"
|
||||||
assert isinstance(steps, int), "Max steps must be integer"
|
assert isinstance(steps, int), "Max steps must be integer"
|
||||||
assert steps > 0 , "Max steps must be positive"
|
assert steps > 0, "Max steps must be positive"
|
||||||
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
||||||
assert save_model_every >= 0 , "Save {name} must be positive or 0"
|
assert save_model_every >= 0, "Save {name} must be positive or 0"
|
||||||
assert isinstance(create_image_every, int), "Create image must be integer"
|
assert isinstance(create_image_every, int), "Create image must be integer"
|
||||||
assert create_image_every >= 0 , "Create image must be positive or 0"
|
assert create_image_every >= 0, "Create image must be positive or 0"
|
||||||
if save_model_every or create_image_every:
|
if save_model_every or create_image_every:
|
||||||
assert log_directory, "Log directory is empty"
|
assert log_directory, "Log directory is empty"
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
save_embedding_every = save_embedding_every or 0
|
save_embedding_every = save_embedding_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
|
|
Loading…
Reference in a new issue