Merge pull request #3414 from discus0434/master
[Hypernetworks] Add a feature to use dropout / more activation functions
This commit is contained in:
commit
ffea9b1509
3 changed files with 63 additions and 40 deletions
|
@ -1,28 +1,32 @@
|
|||
import csv
|
||||
import datetime
|
||||
import glob
|
||||
import html
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import tqdm
|
||||
import csv
|
||||
|
||||
import torch
|
||||
|
||||
from ldm.util import default
|
||||
from modules import devices, shared, processing, sd_models
|
||||
import torch
|
||||
from torch import einsum
|
||||
from einops import rearrange, repeat
|
||||
import modules.textual_inversion.dataset
|
||||
import torch
|
||||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
"relu": torch.nn.ReLU,
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
}
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
|
@ -31,20 +35,26 @@ class HypernetworkModule(torch.nn.Module):
|
|||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
|
||||
# Add a fully-connected layer
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
||||
if activation_func == "relu":
|
||||
linears.append(torch.nn.ReLU())
|
||||
elif activation_func == "leakyrelu":
|
||||
linears.append(torch.nn.LeakyReLU())
|
||||
elif activation_func == 'linear' or activation_func is None:
|
||||
# Add an activation func
|
||||
if activation_func == "linear" or activation_func is None:
|
||||
pass
|
||||
elif activation_func in self.activation_dict:
|
||||
linears.append(self.activation_dict[activation_func]())
|
||||
else:
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
# Add layer normalization
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
# Add dropout expect last layer
|
||||
if use_dropout and i < len(layer_structure) - 3:
|
||||
linears.append(torch.nn.Dropout(p=0.3))
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
if state_dict is not None:
|
||||
|
@ -93,7 +103,7 @@ class Hypernetwork:
|
|||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
|
@ -101,13 +111,14 @@ class Hypernetwork:
|
|||
self.sd_checkpoint = None
|
||||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.activation_func = activation_func
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.use_dropout = use_dropout
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
)
|
||||
|
||||
def weights(self):
|
||||
|
@ -129,8 +140,9 @@ class Hypernetwork:
|
|||
state_dict['step'] = self.step
|
||||
state_dict['name'] = self.name
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
|
||||
|
@ -144,14 +156,15 @@ class Hypernetwork:
|
|||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
self.use_dropout = state_dict.get('use_dropout', False)
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
|
@ -308,6 +321,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||
return hypernetwork, filename
|
||||
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
|
||||
steps_without_grad = 0
|
||||
|
|
|
@ -3,14 +3,13 @@ import os
|
|||
import re
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
import modules.textual_inversion.preprocess
|
||||
from modules import sd_hijack, shared, devices
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack, shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||
# Remove illegal characters from name.
|
||||
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||
|
||||
|
@ -25,8 +24,9 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||
name=name,
|
||||
enable_sizes=[int(x) for x in enable_sizes],
|
||||
layer_structure=layer_structure,
|
||||
add_layer_norm=add_layer_norm,
|
||||
activation_func=activation_func,
|
||||
add_layer_norm=add_layer_norm,
|
||||
use_dropout=use_dropout,
|
||||
)
|
||||
hypernet.save(fn)
|
||||
|
||||
|
|
|
@ -5,19 +5,22 @@ import json
|
|||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import subprocess as sp
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import platform
|
||||
import subprocess as sp
|
||||
from functools import partial, reduce
|
||||
|
||||
import gradio as gr
|
||||
import gradio.routes
|
||||
import gradio.utils
|
||||
import numpy as np
|
||||
import piexif
|
||||
import torch
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
|
||||
import gradio as gr
|
||||
import gradio.utils
|
||||
|
@ -30,17 +33,21 @@ from modules.shared import opts, cmd_opts, restricted_opts
|
|||
|
||||
if cmd_opts.deepdanbooru:
|
||||
from modules.deepbooru import get_deepbooru_tags
|
||||
import modules.shared as shared
|
||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
from modules.sd_hijack import model_hijack
|
||||
|
||||
import modules.codeformer_model
|
||||
import modules.generation_parameters_copypaste
|
||||
import modules.gfpgan_model
|
||||
import modules.hypernetworks.ui
|
||||
import modules.images_history as img_his
|
||||
import modules.ldsr_model
|
||||
import modules.scripts
|
||||
import modules.gfpgan_model
|
||||
import modules.codeformer_model
|
||||
import modules.shared as shared
|
||||
import modules.styles
|
||||
import modules.generation_parameters_copypaste
|
||||
import modules.textual_inversion.ui
|
||||
from modules import prompt_parser
|
||||
from modules.images import save_image
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||
import modules.textual_inversion.ui
|
||||
import modules.hypernetworks.ui
|
||||
|
||||
|
@ -1233,9 +1240,10 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
|
||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
||||
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
|
@ -1285,7 +1293,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
with gr.Row():
|
||||
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
|
||||
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||
|
||||
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||
|
@ -1335,8 +1343,9 @@ def create_ui(wrap_gradio_gpu_call):
|
|||
new_hypernetwork_sizes,
|
||||
overwrite_old_hypernetwork,
|
||||
new_hypernetwork_layer_structure,
|
||||
new_hypernetwork_add_layer_norm,
|
||||
new_hypernetwork_activation_func,
|
||||
new_hypernetwork_add_layer_norm,
|
||||
new_hypernetwork_use_dropout
|
||||
],
|
||||
outputs=[
|
||||
train_hypernetwork_name,
|
||||
|
|
Loading…
Reference in a new issue