add dropout
This commit is contained in:
parent
6a02841fff
commit
0e8ca8e7af
3 changed files with 72 additions and 53 deletions
|
@ -1,47 +1,60 @@
|
||||||
|
import csv
|
||||||
import datetime
|
import datetime
|
||||||
import glob
|
import glob
|
||||||
import html
|
import html
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import tqdm
|
|
||||||
import csv
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from ldm.util import default
|
|
||||||
from modules import devices, shared, processing, sd_models
|
|
||||||
import torch
|
|
||||||
from torch import einsum
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
import modules.textual_inversion.dataset
|
import modules.textual_inversion.dataset
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from ldm.util import default
|
||||||
|
from modules import devices, processing, sd_models, shared
|
||||||
from modules.textual_inversion import textual_inversion
|
from modules.textual_inversion import textual_inversion
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
|
||||||
class HypernetworkModule(torch.nn.Module):
|
class HypernetworkModule(torch.nn.Module):
|
||||||
multiplier = 1.0
|
multiplier = 1.0
|
||||||
activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
|
activation_dict = {
|
||||||
"swish": torch.nn.Hardswish}
|
"relu": torch.nn.ReLU,
|
||||||
|
"leakyrelu": torch.nn.LeakyReLU,
|
||||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
|
"elu": torch.nn.ELU,
|
||||||
|
"swish": torch.nn.Hardswish,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
assert layer_structure is not None, "layer_structure must not be None"
|
assert layer_structure is not None, "layer_structure must not be None"
|
||||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||||
|
assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
|
||||||
|
|
||||||
linears = []
|
linears = []
|
||||||
for i in range(len(layer_structure) - 1):
|
for i in range(len(layer_structure) - 1):
|
||||||
|
|
||||||
|
# Add a fully-connected layer
|
||||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||||
# if skip_first_layer because first parameters potentially contain negative values
|
|
||||||
# if i < 1: continue
|
# Add an activation func
|
||||||
if activation_func in HypernetworkModule.activation_dict:
|
if activation_func == "linear":
|
||||||
linears.append(HypernetworkModule.activation_dict[activation_func]())
|
pass
|
||||||
|
elif activation_func in self.activation_dict:
|
||||||
|
linears.append(self.activation_dict[activation_func]())
|
||||||
else:
|
else:
|
||||||
print("Invalid key {} encountered as activation function!".format(activation_func))
|
raise NotImplementedError(
|
||||||
# if use_dropout:
|
"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
|
||||||
# linears.append(torch.nn.Dropout(p=0.3))
|
)
|
||||||
|
|
||||||
|
# Add dropout
|
||||||
|
if use_dropout:
|
||||||
|
linears.append(torch.nn.Dropout(p=0.3))
|
||||||
|
|
||||||
|
# Add layer normalization
|
||||||
if add_layer_norm:
|
if add_layer_norm:
|
||||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
|
@ -93,7 +106,7 @@ class Hypernetwork:
|
||||||
filename = None
|
filename = None
|
||||||
name = None
|
name = None
|
||||||
|
|
||||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
|
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||||
self.filename = None
|
self.filename = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.layers = {}
|
self.layers = {}
|
||||||
|
@ -101,13 +114,14 @@ class Hypernetwork:
|
||||||
self.sd_checkpoint = None
|
self.sd_checkpoint = None
|
||||||
self.sd_checkpoint_name = None
|
self.sd_checkpoint_name = None
|
||||||
self.layer_structure = layer_structure
|
self.layer_structure = layer_structure
|
||||||
self.add_layer_norm = add_layer_norm
|
|
||||||
self.activation_func = activation_func
|
self.activation_func = activation_func
|
||||||
|
self.add_layer_norm = add_layer_norm
|
||||||
|
self.use_dropout = use_dropout
|
||||||
|
|
||||||
for size in enable_sizes or []:
|
for size in enable_sizes or []:
|
||||||
self.layers[size] = (
|
self.layers[size] = (
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
|
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
|
||||||
)
|
)
|
||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
|
@ -129,8 +143,9 @@ class Hypernetwork:
|
||||||
state_dict['step'] = self.step
|
state_dict['step'] = self.step
|
||||||
state_dict['name'] = self.name
|
state_dict['name'] = self.name
|
||||||
state_dict['layer_structure'] = self.layer_structure
|
state_dict['layer_structure'] = self.layer_structure
|
||||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
|
||||||
state_dict['activation_func'] = self.activation_func
|
state_dict['activation_func'] = self.activation_func
|
||||||
|
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||||
|
state_dict['use_dropout'] = self.use_dropout
|
||||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||||
|
|
||||||
|
@ -144,8 +159,9 @@ class Hypernetwork:
|
||||||
state_dict = torch.load(filename, map_location='cpu')
|
state_dict = torch.load(filename, map_location='cpu')
|
||||||
|
|
||||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
|
||||||
self.activation_func = state_dict.get('activation_func', None)
|
self.activation_func = state_dict.get('activation_func', None)
|
||||||
|
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||||
|
self.use_dropout = state_dict.get('use_dropout', False)
|
||||||
|
|
||||||
for size, sd in state_dict.items():
|
for size, sd in state_dict.items():
|
||||||
if type(size) == int:
|
if type(size) == int:
|
||||||
|
|
|
@ -3,14 +3,13 @@ import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
|
||||||
import modules.textual_inversion.preprocess
|
import modules.textual_inversion.preprocess
|
||||||
from modules import sd_hijack, shared, devices
|
import modules.textual_inversion.textual_inversion
|
||||||
|
from modules import devices, sd_hijack, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
|
def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
|
@ -21,8 +20,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
|
||||||
name=name,
|
name=name,
|
||||||
enable_sizes=[int(x) for x in enable_sizes],
|
enable_sizes=[int(x) for x in enable_sizes],
|
||||||
layer_structure=layer_structure,
|
layer_structure=layer_structure,
|
||||||
add_layer_norm=add_layer_norm,
|
|
||||||
activation_func=activation_func,
|
activation_func=activation_func,
|
||||||
|
add_layer_norm=add_layer_norm,
|
||||||
|
use_dropout=use_dropout,
|
||||||
)
|
)
|
||||||
hypernet.save(fn)
|
hypernet.save(fn)
|
||||||
|
|
||||||
|
|
|
@ -5,43 +5,44 @@ import json
|
||||||
import math
|
import math
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import random
|
import random
|
||||||
|
import subprocess as sp
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import platform
|
|
||||||
import subprocess as sp
|
|
||||||
from functools import partial, reduce
|
from functools import partial, reduce
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import gradio.routes
|
||||||
|
import gradio.utils
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import piexif
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
import piexif
|
|
||||||
|
|
||||||
import gradio as gr
|
from modules import localization, sd_hijack, sd_models
|
||||||
import gradio.utils
|
|
||||||
import gradio.routes
|
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models, localization
|
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import opts, cmd_opts, restricted_opts
|
from modules.shared import cmd_opts, opts, restricted_opts
|
||||||
|
|
||||||
if cmd_opts.deepdanbooru:
|
if cmd_opts.deepdanbooru:
|
||||||
from modules.deepbooru import get_deepbooru_tags
|
from modules.deepbooru import get_deepbooru_tags
|
||||||
import modules.shared as shared
|
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
|
||||||
from modules.sd_hijack import model_hijack
|
|
||||||
import modules.ldsr_model
|
|
||||||
import modules.scripts
|
|
||||||
import modules.gfpgan_model
|
|
||||||
import modules.codeformer_model
|
import modules.codeformer_model
|
||||||
import modules.styles
|
|
||||||
import modules.generation_parameters_copypaste
|
import modules.generation_parameters_copypaste
|
||||||
from modules import prompt_parser
|
import modules.gfpgan_model
|
||||||
from modules.images import save_image
|
|
||||||
import modules.textual_inversion.ui
|
|
||||||
import modules.hypernetworks.ui
|
import modules.hypernetworks.ui
|
||||||
import modules.images_history as img_his
|
import modules.images_history as img_his
|
||||||
|
import modules.ldsr_model
|
||||||
|
import modules.scripts
|
||||||
|
import modules.shared as shared
|
||||||
|
import modules.styles
|
||||||
|
import modules.textual_inversion.ui
|
||||||
|
from modules import prompt_parser
|
||||||
|
from modules.images import save_image
|
||||||
|
from modules.sd_hijack import model_hijack
|
||||||
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
|
@ -1223,8 +1224,9 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
new_hypernetwork_name = gr.Textbox(label="Name")
|
new_hypernetwork_name = gr.Textbox(label="Name")
|
||||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
|
||||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
|
||||||
|
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
|
||||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
|
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
|
@ -1308,8 +1310,9 @@ def create_ui(wrap_gradio_gpu_call):
|
||||||
new_hypernetwork_name,
|
new_hypernetwork_name,
|
||||||
new_hypernetwork_sizes,
|
new_hypernetwork_sizes,
|
||||||
new_hypernetwork_layer_structure,
|
new_hypernetwork_layer_structure,
|
||||||
new_hypernetwork_add_layer_norm,
|
|
||||||
new_hypernetwork_activation_func,
|
new_hypernetwork_activation_func,
|
||||||
|
new_hypernetwork_add_layer_norm,
|
||||||
|
new_hypernetwork_use_dropout
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
|
|
Loading…
Reference in a new issue