resolve conflicts

This commit is contained in:
aria1th 2022-10-30 20:39:04 +09:00
parent 20194fd975
commit 9d96d7d0a0

View file

@ -21,6 +21,7 @@ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_norm
from collections import defaultdict, deque from collections import defaultdict, deque
from statistics import stdev, mean from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
@ -139,6 +140,8 @@ class Hypernetwork:
self.weight_init = weight_init self.weight_init = weight_init
self.add_layer_norm = add_layer_norm self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout self.use_dropout = use_dropout
self.optimizer_name = None
self.optimizer_state_dict = None
for size in enable_sizes or []: for size in enable_sizes or []:
self.layers[size] = ( self.layers[size] = (
@ -171,6 +174,10 @@ class Hypernetwork:
state_dict['use_dropout'] = self.use_dropout state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
if self.optimizer_name is not None:
state_dict['optimizer_name'] = self.optimizer_name
if self.optimizer_state_dict:
state_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(state_dict, filename) torch.save(state_dict, filename)
@ -190,7 +197,14 @@ class Hypernetwork:
self.add_layer_norm = state_dict.get('is_layer_norm', False) self.add_layer_norm = state_dict.get('is_layer_norm', False)
print(f"Layer norm is set to {self.add_layer_norm}") print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False) self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}" ) print(f"Dropout usage is set to {self.use_dropout}")
self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
print(f"Optimizer name is {self.optimizer_name}")
self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
if self.optimizer_state_dict:
print("Loaded existing optimizer from checkpoint")
else:
print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:
@ -392,8 +406,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
weights = hypernetwork.weights() weights = hypernetwork.weights()
for weight in weights: for weight in weights:
weight.requires_grad = True weight.requires_grad = True
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... # Here we use optimizer from saved HN, or we can specify as UI option.
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
else:
print(f"Optimizer type {optimizer_name} is not defined!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
optimizer_name = 'AdamW'
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
steps_without_grad = 0 steps_without_grad = 0
@ -455,8 +480,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}", "loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": scheduler.learn_rate
@ -514,14 +542,18 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
report_statistics(loss_dict) report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
return hypernetwork, filename return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_hypernetwork_name = hypernetwork.name old_hypernetwork_name = hypernetwork.name
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None