Weight initialization and More activation func
add weight init add weight init option in create_hypernetwork fstringify hypernet info save weight initialization info for further debugging fill bias with zero for He/Xavier initialize LayerNorm with Normal fix loading weight_init
This commit is contained in:
parent
3e15f8e0f5
commit
de096d0ce7
3 changed files with 44 additions and 11 deletions
|
@ -5,6 +5,7 @@ import html
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
import inspect
|
||||||
|
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
import torch
|
import torch
|
||||||
|
@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
|
||||||
from modules.textual_inversion import textual_inversion
|
from modules.textual_inversion import textual_inversion
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||||
|
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from statistics import stdev, mean
|
from statistics import stdev, mean
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
activation_dict = {
|
activation_dict = {
|
||||||
|
@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
"leakyrelu": torch.nn.LeakyReLU,
|
"leakyrelu": torch.nn.LeakyReLU,
|
||||||
"elu": torch.nn.ELU,
|
"elu": torch.nn.ELU,
|
||||||
"swish": torch.nn.Hardswish,
|
"swish": torch.nn.Hardswish,
|
||||||
|
"tanh": torch.nn.Tanh,
|
||||||
|
"sigmoid": torch.nn.Sigmoid,
|
||||||
}
|
}
|
||||||
|
activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||||
|
|
||||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert layer_structure is not None, "layer_structure must not be None"
|
assert layer_structure is not None, "layer_structure must not be None"
|
||||||
|
@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
w, b = layer.weight.data, layer.bias.data
|
||||||
layer.bias.data.zero_()
|
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
||||||
|
normal_(w, mean=0.0, std=0.01)
|
||||||
|
normal_(b, mean=0.0, std=0.005)
|
||||||
|
elif weight_init == 'XavierUniform':
|
||||||
|
xavier_uniform_(w)
|
||||||
|
zeros_(b)
|
||||||
|
elif weight_init == 'XavierNormal':
|
||||||
|
xavier_normal_(w)
|
||||||
|
zeros_(b)
|
||||||
|
elif weight_init == 'KaimingUniform':
|
||||||
|
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||||
|
zeros_(b)
|
||||||
|
elif weight_init == 'KaimingNormal':
|
||||||
|
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||||
|
zeros_(b)
|
||||||
|
else:
|
||||||
|
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
||||||
self.to(devices.device)
|
self.to(devices.device)
|
||||||
|
|
||||||
def fix_old_state_dict(self, state_dict):
|
def fix_old_state_dict(self, state_dict):
|
||||||
|
@ -105,7 +126,7 @@ class Hypernetwork:
|
||||||
filename = None
|
filename = None
|
||||||
name = None
|
name = None
|
||||||
|
|
||||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||||
self.filename = None
|
self.filename = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.layers = {}
|
self.layers = {}
|
||||||
|
@ -114,13 +135,14 @@ class Hypernetwork:
|
||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
self.layer_structure = layer_structure
|
self.layer_structure = layer_structure
|
||||||
self.activation_func = activation_func
|
self.activation_func = activation_func
|
||||||
|
self.weight_init = weight_init
|
||||||
self.add_layer_norm = add_layer_norm
|
self.add_layer_norm = add_layer_norm
|
||||||
self.use_dropout = use_dropout
|
self.use_dropout = use_dropout
|
||||||
|
|
||||||
for size in enable_sizes or []:
|
for size in enable_sizes or []:
|
||||||
self.layers[size] = (
|
self.layers[size] = (
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
|
@ -144,6 +166,7 @@ class Hypernetwork:
|
||||||
state_dict['layer_structure'] = self.layer_structure
|
state_dict['layer_structure'] = self.layer_structure
|
||||||
state_dict['activation_func'] = self.activation_func
|
state_dict['activation_func'] = self.activation_func
|
||||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||||
|
state_dict['weight_initialization'] = self.weight_init
|
||||||
state_dict['use_dropout'] = self.use_dropout
|
state_dict['use_dropout'] = self.use_dropout
|
||||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||||
|
@ -158,15 +181,21 @@ class Hypernetwork:
|
||||||
state_dict = torch.load(filename, map_location='cpu')
|
state_dict = torch.load(filename, map_location='cpu')
|
||||||
|
|
||||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||||
|
print(self.layer_structure)
|
||||||
self.activation_func = state_dict.get('activation_func', None)
|
self.activation_func = state_dict.get('activation_func', None)
|
||||||
|
print(f"Activation function is {self.activation_func}")
|
||||||
|
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||||
|
print(f"Weight initialization is {self.weight_init}")
|
||||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||||
|
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||||
self.use_dropout = state_dict.get('use_dropout', False)
|
self.use_dropout = state_dict.get('use_dropout', False)
|
||||||
|
print(f"Dropout usage is set to {self.use_dropout}" )
|
||||||
|
|
||||||
for size, sd in state_dict.items():
|
for size, sd in state_dict.items():
|
||||||
if type(size) == int:
|
if type(size) == int:
|
||||||
self.layers[size] = (
|
self.layers[size] = (
|
||||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
||||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.name = state_dict.get('name', self.name)
|
self.name = state_dict.get('name', self.name)
|
||||||
|
|
|
@ -8,8 +8,9 @@ import modules.textual_inversion.textual_inversion
|
||||||
from modules import devices, sd_hijack, shared
|
from modules import devices, sd_hijack, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
|
||||||
# Remove illegal characters from name.
|
# Remove illegal characters from name.
|
||||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
|
|
||||||
|
@ -25,6 +26,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||||
enable_sizes=[int(x) for x in enable_sizes],
|
enable_sizes=[int(x) for x in enable_sizes],
|
||||||
layer_structure=layer_structure,
|
layer_structure=layer_structure,
|
||||||
activation_func=activation_func,
|
activation_func=activation_func,
|
||||||
|
weight_init=weight_init,
|
||||||
add_layer_norm=add_layer_norm,
|
add_layer_norm=add_layer_norm,
|
||||||
use_dropout=use_dropout,
|
use_dropout=use_dropout,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1238,7 +1238,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
|
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys)
|
||||||
|
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
|
||||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
||||||
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||||
|
@ -1342,6 +1343,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
overwrite_old_hypernetwork,
|
overwrite_old_hypernetwork,
|
||||||
new_hypernetwork_layer_structure,
|
new_hypernetwork_layer_structure,
|
||||||
new_hypernetwork_activation_func,
|
new_hypernetwork_activation_func,
|
||||||
|
new_hypernetwork_initialization_option,
|
||||||
new_hypernetwork_add_layer_norm,
|
new_hypernetwork_add_layer_norm,
|
||||||
new_hypernetwork_use_dropout
|
new_hypernetwork_use_dropout
|
||||||
],
|
],
|
||||||
|
|
Loading…
Reference in a new issue