Merge pull request #3199 from discus0434/master

Add the ability to insert activation functions into hypernetworks

Commit a26fc2834c · 3 changed files with 23 additions and 11 deletions
modules/hypernetworks/hypernetwork.py

@@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
 
-    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
+    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
         super().__init__()
 
-        assert layer_structure is not None, "layer_structure mut not be None"
+        assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
 
         linears = []
         for i in range(len(layer_structure) - 1):
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+            if activation_func == "relu":
+                linears.append(torch.nn.ReLU())
+            if activation_func == "leakyrelu":
+                linears.append(torch.nn.LeakyReLU())
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
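With this hunk, each Linear in the stack can now be followed by the chosen activation (and, independently, a LayerNorm). A minimal standalone sketch of what the loop builds for layer_structure=[1, 2, 1] and activation_func="relu"; dim=768 is just an example size, and wrapping the list in nn.Sequential is an assumption about the surrounding, unchanged code:

    import torch

    dim = 768                      # example; the UI offers 320/640/768/1280
    layer_structure = [1, 2, 1]
    activation_func = "relu"

    linears = []
    for i in range(len(layer_structure) - 1):
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1])))
        if activation_func == "relu":
            linears.append(torch.nn.ReLU())
        if activation_func == "leakyrelu":
            linears.append(torch.nn.LeakyReLU())

    stack = torch.nn.Sequential(*linears)
    # Linear(768, 1536) -> ReLU -> Linear(1536, 768) -> ReLU
    # Note the activation also lands after the final Linear, since the
    # loop appends it on every iteration, including the last one.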
@@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module):
             self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                layer.weight.data.normal_(mean=0.0, std=0.01)
-                layer.bias.data.zero_()
+                if not "ReLU" in layer.__str__():
+                    layer.weight.data.normal_(mean=0.0, std=0.01)
+                    layer.bias.data.zero_()
 
         self.to(devices.device)
 
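The "ReLU" in layer.__str__() test works because the reprs of both torch.nn.ReLU and torch.nn.LeakyReLU contain "ReLU"; those modules have no weight or bias, so initializing them would raise an AttributeError. A sketch of an equivalent, more explicit check (an alternative, not what the PR ships):

    import torch

    def init_weights(linear: torch.nn.Sequential) -> None:
        # Skip the parameter-free activation modules; everything else
        # (Linear and, if enabled, LayerNorm) gets initialized.
        for layer in linear:
            if not isinstance(layer, (torch.nn.ReLU, torch.nn.LeakyReLU)):
                layer.weight.data.normal_(mean=0.0, std=0.01)
                layer.bias.data.zero_()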
@@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            layer_structure += [layer.weight, layer.bias]
+            if not "ReLU" in layer.__str__():
+                layer_structure += [layer.weight, layer.bias]
         return layer_structure
 
 
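The same string test keeps the activation modules out of the optimizer's parameter list; since ReLU and LeakyReLU carry no parameters, accessing layer.weight on them would fail. A hedged equivalent via Module.parameters(), assuming self.linear is an nn.Sequential of these layers (the wrapper itself is outside this hunk):

    import torch

    def trainables(linear: torch.nn.Sequential) -> list:
        # Activation modules contribute no parameters, so this returns
        # the same tensors as the filtered loop in the hunk above.
        return list(linear.parameters())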
@@ -81,7 +87,7 @@ class Hypernetwork:
     filename = None
     name = None
 
-    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
+    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
         self.filename = None
         self.name = name
         self.layers = {}
@@ -90,11 +96,12 @@ class Hypernetwork:
         self.sd_checkpoint_name = None
         self.layer_structure = layer_structure
         self.add_layer_norm = add_layer_norm
+        self.activation_func = activation_func
 
         for size in enable_sizes or []:
             self.layers[size] = (
-                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
-                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
+                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
+                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
             )
 
     def weights(self):
@@ -117,6 +124,7 @@ class Hypernetwork:
         state_dict['name'] = self.name
         state_dict['layer_structure'] = self.layer_structure
         state_dict['is_layer_norm'] = self.add_layer_norm
+        state_dict['activation_func'] = self.activation_func
         state_dict['sd_checkpoint'] = self.sd_checkpoint
         state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
 
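The new 'activation_func' key is persisted next to 'is_layer_norm', so the choice survives a save/load round trip. A quick way to inspect a saved file (a sketch; "my_hypernet.pt" is a hypothetical path, and this assumes save() writes the dict with torch.save):

    import torch

    sd = torch.load("my_hypernet.pt", map_location="cpu")
    print(sd.get('activation_func'))   # "relu", "leakyrelu", or None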
@@ -131,12 +139,13 @@ class Hypernetwork:
 
         self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
         self.add_layer_norm = state_dict.get('is_layer_norm', False)
+        self.activation_func = state_dict.get('activation_func', None)
 
         for size, sd in state_dict.items():
             if type(size) == int:
                 self.layers[size] = (
-                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
-                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
+                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
+                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
                 )
 
         self.name = state_dict.get('name', self.name)
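Because loading falls back to state_dict.get('activation_func', None), hypernetworks saved before this change still load and behave as pure linear stacks. Constructing a network with the new argument directly (a sketch inside this codebase; the name and layer structure are illustrative):

    from modules.hypernetworks.hypernetwork import Hypernetwork

    hypernet = Hypernetwork(
        name="my_hypernet",                 # hypothetical name
        enable_sizes=[320, 640, 768, 1280],
        layer_structure=[1, 2, 1],
        add_layer_norm=False,
        activation_func="leakyrelu",
    )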
modules/hypernetworks/ui.py

@@ -10,7 +10,7 @@ from modules import sd_hijack, shared, devices
 from modules.hypernetworks import hypernetwork
 
 
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
+def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
     assert not os.path.exists(fn), f"file {fn} already exists"
 
@@ -22,6 +22,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
         enable_sizes=[int(x) for x in enable_sizes],
         layer_structure=layer_structure,
         add_layer_norm=add_layer_norm,
+        activation_func=activation_func,
     )
     hypernet.save(fn)
 
 
modules/ui.py

@@ -1224,6 +1224,7 @@ def create_ui(wrap_gradio_gpu_call):
                 new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
                 new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
                 new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+                new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
 
                 with gr.Row():
                     with gr.Column(scale=3):
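Selecting "linear" in the dropdown matches neither if-branch in HypernetworkModule, so no activation is appended and the previous behavior is preserved. If more activations were ever added, a table-driven dispatch would avoid growing the if-chain (a sketch of an alternative, not the PR's code):

    import torch

    # Hypothetical mapping from the dropdown value to a module factory.
    ACTIVATIONS = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU}

    def maybe_append_activation(linears: list, activation_func: str) -> None:
        # "linear" (or any unknown value) maps to no activation at all,
        # matching the behavior of the merged code.
        if activation_func in ACTIVATIONS:
            linears.append(ACTIVATIONS[activation_func]())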
@@ -1308,6 +1309,7 @@ def create_ui(wrap_gradio_gpu_call):
                     new_hypernetwork_sizes,
                     new_hypernetwork_layer_structure,
                     new_hypernetwork_add_layer_norm,
+                    new_hypernetwork_activation_func,
                 ],
                 outputs=[
                     train_hypernetwork_name,