use hash to check valid optim
This commit is contained in:
parent
0b143c1163
commit
1764ac3c8b
1 changed files with 10 additions and 5 deletions
|
@ -177,11 +177,12 @@ class Hypernetwork:
|
||||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||||
if self.optimizer_name is not None:
|
if self.optimizer_name is not None:
|
||||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
||||||
if self.optimizer_state_dict:
|
|
||||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
|
||||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
|
||||||
|
|
||||||
torch.save(state_dict, filename)
|
torch.save(state_dict, filename)
|
||||||
|
if self.optimizer_state_dict:
|
||||||
|
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
||||||
|
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
||||||
|
torch.save(optimizer_saved_dict, filename + '.optim')
|
||||||
|
|
||||||
def load(self, filename):
|
def load(self, filename):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
|
@ -204,7 +205,10 @@ class Hypernetwork:
|
||||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
||||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
||||||
print(f"Optimizer name is {self.optimizer_name}")
|
print(f"Optimizer name is {self.optimizer_name}")
|
||||||
|
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
||||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
||||||
|
else:
|
||||||
|
self.optimizer_state_dict = None
|
||||||
if self.optimizer_state_dict:
|
if self.optimizer_state_dict:
|
||||||
print("Loaded existing optimizer from checkpoint")
|
print("Loaded existing optimizer from checkpoint")
|
||||||
else:
|
else:
|
||||||
|
@ -229,7 +233,7 @@ def list_hypernetworks(path):
|
||||||
name = os.path.splitext(os.path.basename(filename))[0]
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
# Prevent a hypothetical "None.pt" from being listed.
|
# Prevent a hypothetical "None.pt" from being listed.
|
||||||
if name != "None":
|
if name != "None":
|
||||||
res[name] = filename
|
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@ -375,6 +379,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
else:
|
else:
|
||||||
hypernetwork_dir = None
|
hypernetwork_dir = None
|
||||||
|
|
||||||
|
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
||||||
if create_image_every > 0:
|
if create_image_every > 0:
|
||||||
images_dir = os.path.join(log_directory, "images")
|
images_dir = os.path.join(log_directory, "images")
|
||||||
os.makedirs(images_dir, exist_ok=True)
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
|
Loading…
Reference in a new issue