Implement PR #3309 but for embeddings.
parent c2dc9bfa89
commit 4875a6c217

1 changed file with 8 additions and 1 deletion
@@ -167,6 +167,8 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     for i in range(num_vectors_per_token):
         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
 
+    # Remove illegal characters from name.
+    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
     if not overwrite_old:
         assert not os.path.exists(fn), f"file {fn} already exists"
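The two lines added to create_embedding keep only alphanumerics plus ".", "_", "-", and spaces, so the embedding name can no longer produce an invalid or path-escaping filename. A minimal standalone sketch of the same filter (sanitize_name is a hypothetical helper, not part of the patch):

    def sanitize_name(name: str) -> str:
        # Keep only characters that are safe in a filename: letters, digits, ".", "_", "-", and spaces.
        return "".join(x for x in name if (x.isalnum() or x in "._- "))

    print(sanitize_name("my/embedding: v2?"))  # -> 'myembedding v2'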
@@ -287,7 +289,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
 
         if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
-            last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+            # Before saving, change name to match current checkpoint.
+            embedding.name = f'{embedding_name}-{embedding.step}'
+            last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
             embedding.save(last_saved_file)
             embedding_yet_to_be_embedded = True
 
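In train_embedding, the periodic save now renames the embedding itself to the step-suffixed checkpoint name before writing it, so the name stored inside each intermediate .pt matches its filename. A rough sketch of the naming convention this introduces (all values are illustrative, not taken from the patch):

    import os

    embedding_dir = "textual_inversion/2022-10-26/my-style/embeddings"  # illustrative per-run directory
    embedding_name = "my-style"
    step = 500

    # Rename to the step-suffixed form, then build the checkpoint path from that name.
    checkpoint_name = f"{embedding_name}-{step}"                    # 'my-style-500'
    last_saved_file = os.path.join(embedding_dir, f"{checkpoint_name}.pt")
    print(last_saved_file)  # .../my-style-500.pt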
@@ -374,6 +378,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     embedding.sd_checkpoint = checkpoint.hash
     embedding.sd_checkpoint_name = checkpoint.model_name
     embedding.cached_checksum = None
+    # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
+    embedding.name = embedding_name
+    filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
     embedding.save(filename)
 
     return embedding, filename
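The last hunk is the counterpart for the final save: once training finishes, the name is reset from the step-suffixed form back to the base name, so the finished embedding is written as '{embedding_name}.pt' with a matching internal name. A small sketch contrasting the two naming forms (directories are illustrative assumptions, not the webui's actual paths):

    import os

    embedding_name = "my-style"
    checkpoint_dir = "textual_inversion/2022-10-26/my-style/embeddings"  # illustrative per-run directory
    output_dir = "embeddings"                                            # illustrative final output directory

    periodic = os.path.join(checkpoint_dir, f"{embedding_name}-500.pt")  # intermediate, step-suffixed
    final = os.path.join(output_dir, f"{embedding_name}.pt")             # final, base name restored
    print(periodic)
    print(final)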