resolve conflict - first revert

This commit is contained in:
aria1th 2022-11-04 15:47:19 +09:00
parent 1764ac3c8b
commit 0abb39f461

View file

@ -21,7 +21,6 @@ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_norm
from collections import defaultdict, deque from collections import defaultdict, deque
from statistics import stdev, mean from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
@ -34,9 +33,12 @@ class HypernetworkModule(torch.nn.Module):
"tanh": torch.nn.Tanh, "tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid, "sigmoid": torch.nn.Sigmoid,
} }
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) activation_dict.update(
{cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False): def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
@ -47,7 +49,7 @@ class HypernetworkModule(torch.nn.Module):
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
# Add a fully-connected layer # Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
# Add an activation func # Add an activation func
if activation_func == "linear" or activation_func is None: if activation_func == "linear" or activation_func is None:
@ -59,7 +61,7 @@ class HypernetworkModule(torch.nn.Module):
# Add layer normalization # Add layer normalization
if add_layer_norm: if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1])))
# Add dropout expect last layer # Add dropout expect last layer
if use_dropout and i < len(layer_structure) - 3: if use_dropout and i < len(layer_structure) - 3:
@ -128,7 +130,8 @@ class Hypernetwork:
filename = None filename = None
name = None name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None,
add_layer_norm=False, use_dropout=False):
self.filename = None self.filename = None
self.name = name self.name = name
self.layers = {} self.layers = {}
@ -140,13 +143,13 @@ class Hypernetwork:
self.weight_init = weight_init self.weight_init = weight_init
self.add_layer_norm = add_layer_norm self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout self.use_dropout = use_dropout
self.optimizer_name = None
self.optimizer_state_dict = None
for size in enable_sizes or []: for size in enable_sizes or []:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout),
) )
def weights(self): def weights(self):
@ -161,7 +164,6 @@ class Hypernetwork:
def save(self, filename): def save(self, filename):
state_dict = {} state_dict = {}
optimizer_saved_dict = {}
for k, v in self.layers.items(): for k, v in self.layers.items():
state_dict[k] = (v[0].state_dict(), v[1].state_dict()) state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@ -175,14 +177,8 @@ class Hypernetwork:
state_dict['use_dropout'] = self.use_dropout state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
if self.optimizer_name is not None:
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
torch.save(state_dict, filename) torch.save(state_dict, filename)
if self.optimizer_state_dict:
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
def load(self, filename): def load(self, filename):
self.filename = filename self.filename = filename
@ -202,23 +198,13 @@ class Hypernetwork:
self.use_dropout = state_dict.get('use_dropout', False) self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}") print(f"Dropout usage is set to {self.use_dropout}")
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print(f"Optimizer name is {self.optimizer_name}")
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
if self.optimizer_state_dict:
print("Loaded existing optimizer from checkpoint")
else:
print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout), self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout),
) )
self.name = state_dict.get('name', self.name) self.name = state_dict.get('name', self.name)
@ -233,7 +219,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0] name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed. # Prevent a hypothetical "None.pt" from being listed.
if name != "None": if name != "None":
res[name + f"({sd_models.model_hash(filename)})"] = filename res[name] = filename
return res return res
@ -330,7 +316,7 @@ def statistics(data):
std = 0 std = 0
else: else:
std = stdev(data) std = stdev(data)
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})" total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
recent_data = data[-32:] recent_data = data[-32:]
if len(recent_data) < 2: if len(recent_data) < 2:
std = 0 std = 0
@ -340,7 +326,7 @@ def statistics(data):
return total_information, recent_information return total_information, recent_information
def report_statistics(loss_info:dict): def report_statistics(loss_info: dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys: for key in keys:
try: try:
@ -352,14 +338,18 @@ def report_statistics(loss_info:dict):
print(e) print(e)
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): training_height, steps, create_image_every, save_hypernetwork_every, template_file,
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0 save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps,
save_hypernetwork_every, create_image_every, log_directory,
name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork() shared.loaded_hypernetwork = Hypernetwork()
@ -379,7 +369,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else: else:
hypernetwork_dir = None hypernetwork_dir = None
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
if create_image_every > 0: if create_image_every > 0:
images_dir = os.path.join(log_directory, "images") images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True) os.makedirs(images_dir, exist_ok=True)
@ -395,39 +384,34 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
height=training_height,
repeats=shared.opts.training_image_repeats_per_epoch,
placeholder_token=hypernetwork_name,
model=shared.sd_model, device=devices.device,
template_file=template_file, include_cond=True,
batch_size=batch_size)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes) size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024)) loss_dict = defaultdict(lambda: deque(maxlen=1024))
losses = torch.zeros((size,)) losses = torch.zeros((size,))
previous_mean_losses = [0] previous_mean_losses = [0]
previous_mean_loss = 0 previous_mean_loss = 0
print("Mean loss of {} elements".format(size)) print("Mean loss of {} elements".format(size))
weights = hypernetwork.weights() weights = hypernetwork.weights()
for weight in weights: for weight in weights:
weight.requires_grad = True weight.requires_grad = True
# Here we use optimizer from saved HN, or we can specify as UI option. # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict: optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
else:
print(f"Optimizer type {optimizer_name} is not defined!")
optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
optimizer_name = 'AdamW'
if hypernetwork.optimizer_state_dict: # This line must be changed if Optimizer type can be different from saved optimizer.
try:
optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
except RuntimeError as e:
print("Cannot resume from saved optimizer!")
print(e)
steps_without_grad = 0 steps_without_grad = 0
@ -441,7 +425,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if len(loss_dict) > 0: if len(loss_dict) > 0:
previous_mean_losses = [i[-1] for i in loss_dict.values()] previous_mean_losses = [i[-1] for i in loss_dict.values()]
previous_mean_loss = mean(previous_mean_losses) previous_mean_loss = mean(previous_mean_losses)
scheduler.apply(optimizer, hypernetwork.step) scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished: if scheduler.finished:
break break
@ -460,7 +444,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
losses[hypernetwork.step % losses.shape[0]] = loss.item() losses[hypernetwork.step % losses.shape[0]] = loss.item()
for entry in entries: for entry in entries:
loss_dict[entry.filename].append(loss.item()) loss_dict[entry.filename].append(loss.item())
optimizer.zero_grad() optimizer.zero_grad()
weights[0].grad = None weights[0].grad = None
loss.backward() loss.backward()
@ -475,9 +459,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
steps_done = hypernetwork.step + 1 steps_done = hypernetwork.step + 1
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.") raise RuntimeError("Loss diverged.")
if len(previous_mean_losses) > 1: if len(previous_mean_losses) > 1:
std = stdev(previous_mean_losses) std = stdev(previous_mean_losses)
else: else:
@ -489,11 +473,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}", "loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": scheduler.learn_rate
@ -529,7 +510,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
preview_text = p.prompt preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None image = processed.images[0] if len(processed.images) > 0 else None
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
@ -537,7 +518,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None: if image is not None:
shared.state.current_image = image shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt,
shared.opts.samples_format, processed.infotexts[0],
p=p, forced_filename=forced_filename,
save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}" last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step shared.state.job_no = hypernetwork.step
@ -551,15 +535,12 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
report_statistics(loss_dict) report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
return hypernetwork, filename return hypernetwork, filename
@ -576,4 +557,4 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
hypernetwork.sd_checkpoint = old_sd_checkpoint hypernetwork.sd_checkpoint = old_sd_checkpoint
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
hypernetwork.name = old_hypernetwork_name hypernetwork.name = old_hypernetwork_name
raise raise