Merge pull request #6253 from Shondoit/ti-optim
Save optimizer state next to the textual inversion (TI) embedding
This commit is contained in:
commit
7bbd984dda
2 changed files with 33 additions and 9 deletions
|
@ -356,7 +356,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
|
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
|
|
|
@ -28,6 +28,7 @@ class Embedding:
|
||||||
self.cached_checksum = None
|
self.cached_checksum = None
|
||||||
self.sd_checkpoint = None
|
self.sd_checkpoint = None
|
||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
|
self.optimizer_state_dict = None
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
embedding_data = {
|
embedding_data = {
|
||||||
|
@ -41,6 +42,13 @@ class Embedding:
|
||||||
|
|
||||||
torch.save(embedding_data, filename)
|
torch.save(embedding_data, filename)
|
||||||
|
|
||||||
|
if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
|
||||||
|
optimizer_saved_dict = {
|
||||||
|
'hash': self.checksum(),
|
||||||
|
'optimizer_state_dict': self.optimizer_state_dict,
|
||||||
|
}
|
||||||
|
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||||
|
|
||||||
def checksum(self):
|
def checksum(self):
|
||||||
if self.cached_checksum is not None:
|
if self.cached_checksum is not None:
|
||||||
return self.cached_checksum
|
return self.cached_checksum
|
||||||
|
@ -95,9 +103,10 @@ class EmbeddingDatabase:
|
||||||
self.expected_shape = self.get_expected_shape()
|
self.expected_shape = self.get_expected_shape()
|
||||||
|
|
||||||
def process_file(path, filename):
|
def process_file(path, filename):
|
||||||
name = os.path.splitext(filename)[0]
|
name, ext = os.path.splitext(filename)
|
||||||
|
ext = ext.upper()
|
||||||
|
|
||||||
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||||
embed_image = Image.open(path)
|
embed_image = Image.open(path)
|
||||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||||
|
@ -105,8 +114,10 @@ class EmbeddingDatabase:
|
||||||
else:
|
else:
|
||||||
data = extract_image_data_embed(embed_image)
|
data = extract_image_data_embed(embed_image)
|
||||||
name = data.get('name', name)
|
name = data.get('name', name)
|
||||||
else:
|
elif ext in ['.BIN', '.PT']:
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
# textual inversion embeddings
|
# textual inversion embeddings
|
||||||
if 'string_to_param' in data:
|
if 'string_to_param' in data:
|
||||||
|
@ -301,6 +312,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
|
|
||||||
embedding.vec.requires_grad = True
|
embedding.vec.requires_grad = True
|
||||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
|
||||||
|
if shared.opts.save_optimizer_state:
|
||||||
|
optimizer_state_dict = None
|
||||||
|
if os.path.exists(filename + '.optim'):
|
||||||
|
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
|
||||||
|
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
|
||||||
|
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||||
|
|
||||||
|
if optimizer_state_dict is not None:
|
||||||
|
optimizer.load_state_dict(optimizer_state_dict)
|
||||||
|
print("Loaded existing optimizer from checkpoint")
|
||||||
|
else:
|
||||||
|
print("No saved optimizer exists in checkpoint")
|
||||||
|
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
batch_size = ds.batch_size
|
batch_size = ds.batch_size
|
||||||
|
@ -367,9 +392,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
# Before saving, change name to match current checkpoint.
|
# Before saving, change name to match current checkpoint.
|
||||||
embedding_name_every = f'{embedding_name}-{steps_done}'
|
embedding_name_every = f'{embedding_name}-{steps_done}'
|
||||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
|
||||||
#if shared.opts.save_optimizer_state:
|
save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
||||||
#embedding.optimizer_state_dict = optimizer.state_dict()
|
|
||||||
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
|
|
||||||
embedding_yet_to_be_embedded = True
|
embedding_yet_to_be_embedded = True
|
||||||
|
|
||||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
|
||||||
|
@ -459,7 +482,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
||||||
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
pass
|
pass
|
||||||
|
@ -471,7 +494,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
|
|
||||||
return embedding, filename
|
return embedding, filename
|
||||||
|
|
||||||
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
|
||||||
old_embedding_name = embedding.name
|
old_embedding_name = embedding.name
|
||||||
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
|
||||||
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
|
||||||
|
@ -482,6 +505,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
|
||||||
if remove_cached_checksum:
|
if remove_cached_checksum:
|
||||||
embedding.cached_checksum = None
|
embedding.cached_checksum = None
|
||||||
embedding.name = embedding_name
|
embedding.name = embedding_name
|
||||||
|
embedding.optimizer_state_dict = optimizer.state_dict()
|
||||||
embedding.save(filename)
|
embedding.save(filename)
|
||||||
except:
|
except:
|
||||||
embedding.sd_checkpoint = old_sd_checkpoint
|
embedding.sd_checkpoint = old_sd_checkpoint
|
||||||
|
|
Loading…
Reference in a new issue