add dropout

This commit is contained in:
discus0434 2022-10-22 11:07:00 +00:00
parent 6a02841fff
commit 0e8ca8e7af
3 changed files with 72 additions and 53 deletions

View file

@ -1,47 +1,60 @@
import csv
import datetime import datetime
import glob import glob
import html import html
import os import os
import sys import sys
import traceback import traceback
import tqdm
import csv
import torch
from ldm.util import default
from modules import devices, shared, processing, sd_models
import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, activation_dict = {
"swish": torch.nn.Hardswish} "relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
}
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
linears = [] linears = []
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
# if i < 1: continue # Add an activation func
if activation_func in HypernetworkModule.activation_dict: if activation_func == "linear":
linears.append(HypernetworkModule.activation_dict[activation_func]()) pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else: else:
print("Invalid key {} encountered as activation function!".format(activation_func)) raise NotImplementedError(
# if use_dropout: "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
# linears.append(torch.nn.Dropout(p=0.3)) )
# Add dropout
if use_dropout:
linears.append(torch.nn.Dropout(p=0.3))
# Add layer normalization
if add_layer_norm: if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@ -93,7 +106,7 @@ class Hypernetwork:
filename = None filename = None
name = None name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None): def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
self.filename = None self.filename = None
self.name = name self.name = name
self.layers = {} self.layers = {}
@ -101,13 +114,14 @@ class Hypernetwork:
self.sd_checkpoint = None self.sd_checkpoint = None
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.layer_structure = layer_structure self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm
self.activation_func = activation_func self.activation_func = activation_func
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
for size in enable_sizes or []: for size in enable_sizes or []:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
) )
def weights(self): def weights(self):
@ -129,8 +143,9 @@ class Hypernetwork:
state_dict['step'] = self.step state_dict['step'] = self.step
state_dict['name'] = self.name state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@ -144,8 +159,9 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu') state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None) self.activation_func = state_dict.get('activation_func', None)
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.use_dropout = state_dict.get('use_dropout', False)
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:

View file

@ -3,14 +3,13 @@ import os
import re import re
import gradio as gr import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess import modules.textual_inversion.preprocess
from modules import sd_hijack, shared, devices import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None): def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists" assert not os.path.exists(fn), f"file {fn} already exists"
@ -21,8 +20,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
name=name, name=name,
enable_sizes=[int(x) for x in enable_sizes], enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure, layer_structure=layer_structure,
add_layer_norm=add_layer_norm,
activation_func=activation_func, activation_func=activation_func,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
) )
hypernet.save(fn) hypernet.save(fn)

View file

@ -5,43 +5,44 @@ import json
import math import math
import mimetypes import mimetypes
import os import os
import platform
import random import random
import subprocess as sp
import sys import sys
import tempfile import tempfile
import time import time
import traceback import traceback
import platform
import subprocess as sp
from functools import partial, reduce from functools import partial, reduce
import gradio as gr
import gradio.routes
import gradio.utils
import numpy as np import numpy as np
import piexif
import torch import torch
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
import piexif
import gradio as gr from modules import localization, sd_hijack, sd_models
import gradio.utils
import gradio.routes
from modules import sd_hijack, sd_models, localization
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import cmd_opts, opts, restricted_opts
if cmd_opts.deepdanbooru: if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.sd_hijack import model_hijack
import modules.ldsr_model
import modules.scripts
import modules.gfpgan_model
import modules.codeformer_model import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste import modules.generation_parameters_copypaste
from modules import prompt_parser import modules.gfpgan_model
from modules.images import save_image
import modules.textual_inversion.ui
import modules.hypernetworks.ui import modules.hypernetworks.ui
import modules.images_history as img_his import modules.images_history as img_his
import modules.ldsr_model
import modules.scripts
import modules.shared as shared
import modules.styles
import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init() mimetypes.init()
@ -1223,8 +1224,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
@ -1308,8 +1310,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name, new_hypernetwork_name,
new_hypernetwork_sizes, new_hypernetwork_sizes,
new_hypernetwork_layer_structure, new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func, new_hypernetwork_activation_func,
new_hypernetwork_add_layer_norm,
new_hypernetwork_use_dropout
], ],
outputs=[ outputs=[
train_hypernetwork_name, train_hypernetwork_name,