Merge pull request #6253 from Shondoit/ti-optim

Save Optimizer next to TI embedding
AUTOMATIC1111 2023-01-04 14:09:13 +03:00 committed by GitHub
commit 7bbd984dda
2 changed files with 33 additions and 9 deletions
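In short: with the save_optimizer_state option enabled, training now writes the AdamW optimizer's state_dict to a sibling *.optim file tagged with the embedding's checksum, so an interrupted run can resume with its moment estimates intact instead of starting from a freshly initialized optimizer. A minimal standalone sketch of the pattern, using illustrative helper names and paths that are not part of this PR:

import os

import torch


def save_with_optimizer(embedding_data, optimizer, filename, checksum):
    # Save the tensor payload, then the optimizer state next to it,
    # keyed by a checksum so a stale .optim file can be detected later.
    torch.save(embedding_data, filename)
    torch.save({'hash': checksum, 'optimizer_state_dict': optimizer.state_dict()},
               filename + '.optim')


def try_resume_optimizer(optimizer, filename, checksum):
    # Return True only if a matching .optim file was found and loaded.
    if not os.path.exists(filename + '.optim'):
        return False
    saved = torch.load(filename + '.optim', map_location='cpu')
    if saved.get('hash', None) != checksum:
        return False  # embedding changed since the optim file was written
    optimizer.load_state_dict(saved['optimizer_state_dict'])
    return True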

modules/shared.py

@@ -356,7 +356,7 @@ options_templates.update(options_section(('system', "System"), {
 options_templates.update(options_section(('training', "Training"), {
     "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
     "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
-    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
+    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),

modules/textual_inversion/textual_inversion.py

@@ -28,6 +28,7 @@ class Embedding:
         self.cached_checksum = None
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
+        self.optimizer_state_dict = None

     def save(self, filename):
         embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:
         torch.save(embedding_data, filename)

+        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+            optimizer_saved_dict = {
+                'hash': self.checksum(),
+                'optimizer_state_dict': self.optimizer_state_dict,
+            }
+            torch.save(optimizer_saved_dict, filename + '.optim')
+
     def checksum(self):
         if self.cached_checksum is not None:
             return self.cached_checksum
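The sibling file is an ordinary torch pickle containing only the hash and the optimizer state, so it can be inspected directly. A quick check, assuming an illustrative path:

import torch

saved = torch.load('embeddings/my-embedding.pt.optim', map_location='cpu')
print(saved['hash'])                         # checksum of the paired embedding
print(saved['optimizer_state_dict'].keys())  # AdamW state: 'state', 'param_groups'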
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
         self.expected_shape = self.get_expected_shape()

         def process_file(path, filename):
-            name = os.path.splitext(filename)[0]
+            name, ext = os.path.splitext(filename)
+            ext = ext.upper()

-            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+            if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                 embed_image = Image.open(path)
                 if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                     data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
                 else:
                     data = extract_image_data_embed(embed_image)
                     name = data.get('name', name)
-            else:
+            elif ext in ['.BIN', '.PT']:
                 data = torch.load(path, map_location="cpu")
+            else:
+                return

             # textual inversion embeddings
             if 'string_to_param' in data:
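This also tightens loader behavior for unrecognized files: previously anything that was not an image fell through to torch.load, so a stray file in the embeddings folder would typically raise an error during the scan; now only .bin and .pt are loaded and everything else is skipped. The dispatch rests on os.path.splitext, which splits off the final extension:

import os

name, ext = os.path.splitext('my-embedding.pt')
print(name, ext.upper())  # -> my-embedding .PT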
@@ -301,6 +312,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
     embedding.vec.requires_grad = True
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)

+    if shared.opts.save_optimizer_state:
+        optimizer_state_dict = None
+        if os.path.exists(filename + '.optim'):
+            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+        if optimizer_state_dict is not None:
+            optimizer.load_state_dict(optimizer_state_dict)
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
+
     scaler = torch.cuda.amp.GradScaler()

     batch_size = ds.batch_size
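Two details worth noting in the resume path: loading only happens when embedding.checksum() matches the 'hash' recorded at save time, which prevents pairing a stale .optim file with an embedding that was retrained or edited in between; and restoring the state_dict is what makes resumption smooth, since AdamW keeps per-parameter moment estimates (exp_avg, exp_avg_sq) and step counts, and restarting without them resets the adaptive step sizes and can cause a loss spike right after resume.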
@@ -367,9 +392,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                     # Before saving, change name to match current checkpoint.
                     embedding_name_every = f'{embedding_name}-{steps_done}'
                     last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
-                    #if shared.opts.save_optimizer_state:
-                    #embedding.optimizer_state_dict = optimizer.state_dict()
-                    save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                     embedding_yet_to_be_embedded = True

                 write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -459,7 +482,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
         filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
-        save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
     except Exception:
         print(traceback.format_exc(), file=sys.stderr)
         pass
@@ -471,7 +494,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     return embedding, filename


-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
     old_embedding_name = embedding.name
     old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
     old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -482,6 +505,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
         if remove_cached_checksum:
             embedding.cached_checksum = None
         embedding.name = embedding_name
+        embedding.optimizer_state_dict = optimizer.state_dict()
         embedding.save(filename)
     except:
         embedding.sd_checkpoint = old_sd_checkpoint