This commit is contained in:
discus0434 2022-10-20 00:10:45 +00:00
parent 634acdd954
commit 6f98e89486
3 changed files with 45 additions and 32 deletions

View file

@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure mut not be None" assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = [] linears = []
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if activation_func == "relu":
linears.append(torch.nn.ReLU())
if activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
if add_layer_norm: if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict) self.load_state_dict(state_dict)
else: else:
for layer in self.linear: for layer in self.linear:
layer.weight.data.normal_(mean=0.0, std=0.01) if not "ReLU" in layer.__str__():
layer.bias.data.zero_() layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device) self.to(devices.device)
@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self): def trainables(self):
layer_structure = [] layer_structure = []
for layer in self.linear: for layer in self.linear:
layer_structure += [layer.weight, layer.bias] if not "ReLU" in layer.__str__():
layer_structure += [layer.weight, layer.bias]
return layer_structure return layer_structure
@ -81,7 +87,7 @@ class Hypernetwork:
filename = None filename = None
name = None name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False): def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
self.filename = None self.filename = None
self.name = name self.name = name
self.layers = {} self.layers = {}
@ -90,11 +96,12 @@ class Hypernetwork:
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.layer_structure = layer_structure self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
for size in enable_sizes or []: for size in enable_sizes or []:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
) )
def weights(self): def weights(self):
@ -117,6 +124,7 @@ class Hypernetwork:
state_dict['name'] = self.name state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@ -131,12 +139,13 @@ class Hypernetwork:
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False) self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm), HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm), HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
) )
self.name = state_dict.get('name', self.name) self.name = state_dict.get('name', self.name)

View file

@ -10,7 +10,7 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False): def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists" assert not os.path.exists(fn), f"file {fn} already exists"
@ -22,6 +22,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
enable_sizes=[int(x) for x in enable_sizes], enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure, layer_structure=layer_structure,
add_layer_norm=add_layer_norm, add_layer_norm=add_layer_norm,
activation_func=activation_func,
) )
hypernet.save(fn) hypernet.save(fn)

View file

@ -5,43 +5,44 @@ import json
import math import math
import mimetypes import mimetypes
import os import os
import platform
import random import random
import subprocess as sp
import sys import sys
import tempfile import tempfile
import time import time
import traceback import traceback
import platform
import subprocess as sp
from functools import partial, reduce from functools import partial, reduce
import gradio as gr
import gradio.routes
import gradio.utils
import numpy as np import numpy as np
import piexif
import torch import torch
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
import piexif
import gradio as gr from modules import localization, sd_hijack, sd_models
import gradio.utils
import gradio.routes
from modules import sd_hijack, sd_models, localization
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import cmd_opts, opts, restricted_opts
if cmd_opts.deepdanbooru: if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.sd_hijack import model_hijack
import modules.ldsr_model
import modules.scripts
import modules.gfpgan_model
import modules.codeformer_model import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste import modules.generation_parameters_copypaste
from modules import prompt_parser import modules.gfpgan_model
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui import modules.hypernetworks.ui
import modules.images_history as img_his import modules.images_history as img_his
import modules.ldsr_model
import modules.scripts
import modules.shared as shared
import modules.styles
import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init() mimetypes.init()
@ -1219,6 +1220,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["relu", "leakyrelu"])
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
@ -1303,6 +1305,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes, new_hypernetwork_sizes,
new_hypernetwork_layer_structure, new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm, new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func,
], ],
outputs=[ outputs=[
train_hypernetwork_name, train_hypernetwork_name,