I blame code autocomplete

This commit is contained in:
AngelBottomless 2022-11-04 15:50:54 +09:00 committed by GitHub
parent 0abb39f461
commit 0d07cbfa15
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -33,12 +33,9 @@ class HypernetworkModule(torch.nn.Module):
"tanh": torch.nn.Tanh, "tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid, "sigmoid": torch.nn.Sigmoid,
} }
activation_dict.update( activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
{cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if
inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
add_layer_norm=False, use_dropout=False):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
@ -49,7 +46,7 @@ class HypernetworkModule(torch.nn.Module):
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
# Add a fully-connected layer # Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func # Add an activation func
if activation_func == "linear" or activation_func is None: if activation_func == "linear" or activation_func is None:
@ -61,7 +58,7 @@ class HypernetworkModule(torch.nn.Module):
# Add layer normalization # Add layer normalization
if add_layer_norm: if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i + 1]))) linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout expect last layer # Add dropout expect last layer
if use_dropout and i < len(layer_structure) - 3: if use_dropout and i < len(layer_structure) - 3:
@ -130,8 +127,7 @@ class Hypernetwork:
filename = None filename = None
name = None name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
add_layer_norm=False, use_dropout=False):
self.filename = None self.filename = None
self.name = name self.name = name
self.layers = {} self.layers = {}
@ -146,10 +142,8 @@ class Hypernetwork:
for size in enable_sizes or []: for size in enable_sizes or []:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
self.add_layer_norm, self.use_dropout), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout),
) )
def weights(self): def weights(self):
@ -196,15 +190,13 @@ class Hypernetwork:
self.add_layer_norm = state_dict.get('is_layer_norm', False) self.add_layer_norm = state_dict.get('is_layer_norm', False)
print(f"Layer norm is set to {self.add_layer_norm}") print(f"Layer norm is set to {self.add_layer_norm}")
self.use_dropout = state_dict.get('use_dropout', False) self.use_dropout = state_dict.get('use_dropout', False)
print(f"Dropout usage is set to {self.use_dropout}") print(f"Dropout usage is set to {self.use_dropout}" )
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
self.add_layer_norm, self.use_dropout), HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout),
) )
self.name = state_dict.get('name', self.name) self.name = state_dict.get('name', self.name)
@ -316,7 +308,7 @@ def statistics(data):
std = 0 std = 0
else: else:
std = stdev(data) std = stdev(data)
total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})" total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
recent_data = data[-32:] recent_data = data[-32:]
if len(recent_data) < 2: if len(recent_data) < 2:
std = 0 std = 0
@ -326,7 +318,7 @@ def statistics(data):
return total_information, recent_information return total_information, recent_information
def report_statistics(loss_info: dict): def report_statistics(loss_info:dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys: for key in keys:
try: try:
@ -338,18 +330,14 @@ def report_statistics(loss_info: dict):
print(e) print(e)
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width,
training_height, steps, create_image_every, save_hypernetwork_every, template_file, def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps,
preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0 save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
save_hypernetwork_every, create_image_every, log_directory,
name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork() shared.loaded_hypernetwork = Hypernetwork()
@ -388,20 +376,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
height=training_height,
repeats=shared.opts.training_image_repeats_per_epoch,
placeholder_token=hypernetwork_name,
model=shared.sd_model, device=devices.device,
template_file=template_file, include_cond=True,
batch_size=batch_size)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes) size = len(ds.indexes)
loss_dict = defaultdict(lambda: deque(maxlen=1024)) loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,)) losses = torch.zeros((size,))
previous_mean_losses = [0] previous_mean_losses = [0]
previous_mean_loss = 0 previous_mean_loss = 0
@ -510,7 +492,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
preview_text = p.prompt preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None image = processed.images[0] if len(processed.images)>0 else None
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
@ -518,10 +500,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None: if image is not None:
shared.state.current_image = image shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
shared.opts.samples_format, processed.infotexts[0],
p=p, forced_filename=forced_filename,
save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}" last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step shared.state.job_no = hypernetwork.step
@ -543,7 +522,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return hypernetwork, filename return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_hypernetwork_name = hypernetwork.name old_hypernetwork_name = hypernetwork.name
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None