resolve conflict - first revert
This commit is contained in:
parent
1764ac3c8b
commit
0abb39f461
1 changed files with 52 additions and 71 deletions
|
@ -21,7 +21,6 @@ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_norm
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from statistics import stdev, mean
|
from statistics import stdev, mean
|
||||||
|
|
||||||
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
|
@ -34,9 +33,12 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
"tanh": torch.nn.Tanh,
|
"tanh": torch.nn.Tanh,
|
||||||
"sigmoid": torch.nn.Sigmoid,
|
"sigmoid": torch.nn.Sigmoid,
|
||||||
}
|
}
|
||||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
activation_dict.update(
|
||||||
|
{cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
|
||||||
|
inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||||
|
|
||||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
|
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
|
||||||
|
add_layer_norm=False, use_dropout=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert layer_structure is not None, "layer_structure must not be None"
|
assert layer_structure is not None, "layer_structure must not be None"
|
||||||
|
@ -128,7 +130,8 @@ class Hypernetwork:
|
||||||
filename = None
|
filename = None
|
||||||
name = None
|
name = None
|
||||||
|
|
||||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
|
||||||
|
add_layer_norm=False, use_dropout=False):
|
||||||
self.filename = None
|
self.filename = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.layers = {}
|
self.layers = {}
|
||||||
|
@ -140,13 +143,13 @@ class Hypernetwork:
|
||||||
self.weight_init = weight_init
|
self.weight_init = weight_init
|
||||||
self.add_layer_norm = add_layer_norm
|
self.add_layer_norm = add_layer_norm
|
||||||
self.use_dropout = use_dropout
|
self.use_dropout = use_dropout
|
||||||
self.optimizer_name = None
|
|
||||||
self.optimizer_state_dict = None
|
|
||||||
|
|
||||||
for size in enable_sizes or []:
|
for size in enable_sizes or []:
|
||||||
self.layers[size] = (
|
self.layers[size] = (
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
self.add_layer_norm, self.use_dropout),
|
||||||
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
|
||||||
|
self.add_layer_norm, self.use_dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
|
@ -161,7 +164,6 @@ class Hypernetwork:
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
optimizer_saved_dict = {}
|
|
||||||
|
|
||||||
for k, v in self.layers.items():
|
for k, v in self.layers.items():
|
||||||
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
|
||||||
|
@ -175,14 +177,8 @@ class Hypernetwork:
|
||||||
state_dict['use_dropout'] = self.use_dropout
|
state_dict['use_dropout'] = self.use_dropout
|
||||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||||
if self.optimizer_name is not None:
|
|
||||||
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
|
|
||||||
|
|
||||||
torch.save(state_dict, filename)
|
torch.save(state_dict, filename)
|
||||||
if self.optimizer_state_dict:
|
|
||||||
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
|
|
||||||
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
|
|
||||||
torch.save(optimizer_saved_dict, filename + '.optim')
|
|
||||||
|
|
||||||
def load(self, filename):
|
def load(self, filename):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
|
@ -202,23 +198,13 @@ class Hypernetwork:
|
||||||
self.use_dropout = state_dict.get('use_dropout', False)
|
self.use_dropout = state_dict.get('use_dropout', False)
|
||||||
print(f"Dropout usage is set to {self.use_dropout}")
|
print(f"Dropout usage is set to {self.use_dropout}")
|
||||||
|
|
||||||
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
|
|
||||||
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
|
|
||||||
print(f"Optimizer name is {self.optimizer_name}")
|
|
||||||
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
|
|
||||||
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
|
|
||||||
else:
|
|
||||||
self.optimizer_state_dict = None
|
|
||||||
if self.optimizer_state_dict:
|
|
||||||
print("Loaded existing optimizer from checkpoint")
|
|
||||||
else:
|
|
||||||
print("No saved optimizer exists in checkpoint")
|
|
||||||
|
|
||||||
for size, sd in state_dict.items():
|
for size, sd in state_dict.items():
|
||||||
if type(size) == int:
|
if type(size) == int:
|
||||||
self.layers[size] = (
|
self.layers[size] = (
|
||||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
|
||||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
self.add_layer_norm, self.use_dropout),
|
||||||
|
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
|
||||||
|
self.add_layer_norm, self.use_dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.name = state_dict.get('name', self.name)
|
self.name = state_dict.get('name', self.name)
|
||||||
|
@ -233,7 +219,7 @@ def list_hypernetworks(path):
|
||||||
name = os.path.splitext(os.path.basename(filename))[0]
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
# Prevent a hypothetical "None.pt" from being listed.
|
# Prevent a hypothetical "None.pt" from being listed.
|
||||||
if name != "None":
|
if name != "None":
|
||||||
res[name + f"({sd_models.model_hash(filename)})"] = filename
|
res[name] = filename
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@ -352,14 +338,18 @@ def report_statistics(loss_info:dict):
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
training_height, steps, create_image_every, save_hypernetwork_every, template_file,
|
||||||
|
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
|
||||||
|
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
from modules import images
|
from modules import images
|
||||||
|
|
||||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
|
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
|
||||||
|
save_hypernetwork_every, create_image_every, log_directory,
|
||||||
|
name="hypernetwork")
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
shared.loaded_hypernetwork = Hypernetwork()
|
shared.loaded_hypernetwork = Hypernetwork()
|
||||||
|
@ -379,7 +369,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
else:
|
else:
|
||||||
hypernetwork_dir = None
|
hypernetwork_dir = None
|
||||||
|
|
||||||
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
|
|
||||||
if create_image_every > 0:
|
if create_image_every > 0:
|
||||||
images_dir = os.path.join(log_directory, "images")
|
images_dir = os.path.join(log_directory, "images")
|
||||||
os.makedirs(images_dir, exist_ok=True)
|
os.makedirs(images_dir, exist_ok=True)
|
||||||
|
@ -399,7 +388,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
# dataset loading may take a while, so input validations and early returns should be done before this
|
# dataset loading may take a while, so input validations and early returns should be done before this
|
||||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||||
with torch.autocast("cuda"):
|
with torch.autocast("cuda"):
|
||||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
|
||||||
|
height=training_height,
|
||||||
|
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||||
|
placeholder_token=hypernetwork_name,
|
||||||
|
model=shared.sd_model, device=devices.device,
|
||||||
|
template_file=template_file, include_cond=True,
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
if unload:
|
if unload:
|
||||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||||
|
@ -415,19 +410,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
weights = hypernetwork.weights()
|
weights = hypernetwork.weights()
|
||||||
for weight in weights:
|
for weight in weights:
|
||||||
weight.requires_grad = True
|
weight.requires_grad = True
|
||||||
# Here we use optimizer from saved HN, or we can specify as UI option.
|
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||||
if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
|
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||||
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
|
|
||||||
else:
|
|
||||||
print(f"Optimizer type {optimizer_name} is not defined!")
|
|
||||||
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
|
|
||||||
optimizer_name = 'AdamW'
|
|
||||||
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
|
|
||||||
try:
|
|
||||||
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
|
|
||||||
except RuntimeError as e:
|
|
||||||
print("Cannot resume from saved optimizer!")
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
steps_without_grad = 0
|
steps_without_grad = 0
|
||||||
|
|
||||||
|
@ -489,11 +473,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
# Before saving, change name to match current checkpoint.
|
# Before saving, change name to match current checkpoint.
|
||||||
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
|
||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
|
||||||
hypernetwork.optimizer_name = optimizer_name
|
|
||||||
if shared.opts.save_optimizer_state:
|
|
||||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
|
||||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
|
||||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||||
"loss": f"{previous_mean_loss:.7f}",
|
"loss": f"{previous_mean_loss:.7f}",
|
||||||
"learn_rate": scheduler.learn_rate
|
"learn_rate": scheduler.learn_rate
|
||||||
|
@ -537,7 +518,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
|
||||||
|
shared.opts.samples_format, processed.infotexts[0],
|
||||||
|
p=p, forced_filename=forced_filename,
|
||||||
|
save_to_dirs=False)
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
shared.state.job_no = hypernetwork.step
|
shared.state.job_no = hypernetwork.step
|
||||||
|
@ -551,15 +535,12 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
report_statistics(loss_dict)
|
report_statistics(loss_dict)
|
||||||
|
|
||||||
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
|
||||||
hypernetwork.optimizer_name = optimizer_name
|
|
||||||
if shared.opts.save_optimizer_state:
|
|
||||||
hypernetwork.optimizer_state_dict = optimizer.state_dict()
|
|
||||||
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
|
||||||
del optimizer
|
|
||||||
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
|
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue