fixed all lines PyCharm was nagging me about
fixed input verification not working properly with long textual inversion tokens in some cases (plus it will prevent incorrect outputs for forks that use the :::: prompt weighing method) changed process_images to object class with same fields as args it was previously accepting changed options system to make it possible to explicitly specify gradio objects with args
This commit is contained in:
parent
4e0fdca2f4
commit
c30aee2f4b
1 changed files with 288 additions and 275 deletions
563
webui.py
563
webui.py
|
@ -1,14 +1,13 @@
|
||||||
import argparse, os, sys, glob
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||||
from itertools import islice
|
|
||||||
from einops import rearrange, repeat
|
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import random
|
import random
|
||||||
|
@ -22,14 +21,13 @@ import k_diffusion.sampling
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
import ldm.modules.encoders.modules
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
|
|
||||||
from transformers import logging
|
from transformers import logging
|
||||||
logging.set_verbosity_error()
|
logging.set_verbosity_error()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
||||||
|
@ -41,13 +39,13 @@ opt_C = 4
|
||||||
opt_f = 8
|
opt_f = 8
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
invalid_filename_chars = '<>:"/\|?*\n'
|
invalid_filename_chars = '<>:"/\\|?*\n'
|
||||||
config_filename = "config.json"
|
config_filename = "config.json"
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
|
@ -64,7 +62,7 @@ css_hide_progressbar = """
|
||||||
|
|
||||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
|
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
|
||||||
samplers = [
|
samplers = [
|
||||||
*[SamplerData(x[0], lambda m, funcname=x[1]: KDiffusionSampler(m, funcname)) for x in [
|
*[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
|
||||||
('LMS', 'sample_lms'),
|
('LMS', 'sample_lms'),
|
||||||
('Heun', 'sample_heun'),
|
('Heun', 'sample_heun'),
|
||||||
('Euler', 'sample_euler'),
|
('Euler', 'sample_euler'),
|
||||||
|
@ -72,8 +70,8 @@ samplers = [
|
||||||
('DPM 2', 'sample_dpm_2'),
|
('DPM 2', 'sample_dpm_2'),
|
||||||
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
|
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
|
||||||
] if hasattr(k_diffusion.sampling, x[1])],
|
] if hasattr(k_diffusion.sampling, x[1])],
|
||||||
SamplerData('DDIM', lambda m: DDIMSampler(model)),
|
SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
|
||||||
SamplerData('PLMS', lambda m: PLMSSampler(model)),
|
SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
|
||||||
]
|
]
|
||||||
samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
|
samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
|
||||||
|
|
||||||
|
@ -102,7 +100,7 @@ try:
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
have_realesrgan = True
|
have_realesrgan = True
|
||||||
except:
|
except Exception:
|
||||||
print("Error loading Real-ESRGAN:", file=sys.stderr)
|
print("Error loading Real-ESRGAN:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
@ -111,24 +109,30 @@ except:
|
||||||
|
|
||||||
|
|
||||||
class Options:
|
class Options:
|
||||||
|
class OptionInfo:
|
||||||
|
def __init__(self, default=None, label="", component=None, component_args=None):
|
||||||
|
self.default = default
|
||||||
|
self.label = label
|
||||||
|
self.component = component
|
||||||
|
self.component_args = component_args
|
||||||
|
|
||||||
data = None
|
data = None
|
||||||
data_labels = {
|
data_labels = {
|
||||||
"outdir": ("", "Output dictectory; if empty, defaults to 'outputs/*'"),
|
"outdir": OptionInfo("", "Output dictectory; if empty, defaults to 'outputs/*'"),
|
||||||
"samples_save": (True, "Save indiviual samples"),
|
"samples_save": OptionInfo(True, "Save indiviual samples"),
|
||||||
"samples_format": ('png', 'File format for indiviual samples'),
|
"samples_format": OptionInfo('png', 'File format for indiviual samples'),
|
||||||
"grid_save": (True, "Save image grids"),
|
"grid_save": OptionInfo(True, "Save image grids"),
|
||||||
"grid_format": ('png', 'File format for grids'),
|
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||||
"grid_extended_filename": (False, "Add extended info (seed, prompt) to filename when saving grid"),
|
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||||
"n_rows": (-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", -1, 16),
|
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||||
"jpeg_quality": (80, "Quality for saved jpeg images", 1, 100),
|
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||||
"verify_input": (True, "Check input, and produce warning if it's too long"),
|
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||||
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
|
"prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
|
||||||
"prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
|
"sd_upscale_overlap": OptionInfo(64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", gr.Slider, {"minimum": 0, "maximum": 256, "step": 16}),
|
||||||
"sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.data = {k: v[0] for k, v in self.data_labels.items()}
|
self.data = {k: v.default for k, v in self.data_labels.items()}
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
if self.data is not None:
|
if self.data is not None:
|
||||||
|
@ -143,7 +147,7 @@ class Options:
|
||||||
return self.data[item]
|
return self.data[item]
|
||||||
|
|
||||||
if item in self.data_labels:
|
if item in self.data_labels:
|
||||||
return self.data_labels[item][0]
|
return self.data_labels[item].default
|
||||||
|
|
||||||
return super(Options, self).__getattribute__(item)
|
return super(Options, self).__getattribute__(item)
|
||||||
|
|
||||||
|
@ -156,11 +160,6 @@ class Options:
|
||||||
self.data = json.load(file)
|
self.data = json.load(file)
|
||||||
|
|
||||||
|
|
||||||
def chunk(it, size):
|
|
||||||
it = iter(it)
|
|
||||||
return iter(lambda: tuple(islice(it, size)), ())
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt, verbose=False):
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
print(f"Loading model from {ckpt}")
|
print(f"Loading model from {ckpt}")
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
@ -181,36 +180,6 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(nn.Module):
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = model
|
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
|
||||||
x_in = torch.cat([x] * 2)
|
|
||||||
sigma_in = torch.cat([sigma] * 2)
|
|
||||||
cond_in = torch.cat([uncond, cond])
|
|
||||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
|
||||||
return uncond + (cond - uncond) * cond_scale
|
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler:
|
|
||||||
def __init__(self, m, funcname):
|
|
||||||
self.model = m
|
|
||||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
|
|
||||||
self.funcname = funcname
|
|
||||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
|
||||||
|
|
||||||
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
|
|
||||||
sigmas = self.model_wrap.get_sigmas(S)
|
|
||||||
x = x_T * sigmas[0]
|
|
||||||
model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
|
||||||
|
|
||||||
samples_ddim = self.func(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
|
|
||||||
|
|
||||||
return samples_ddim, None
|
|
||||||
|
|
||||||
|
|
||||||
def create_random_tensors(shape, seeds):
|
def create_random_tensors(shape, seeds):
|
||||||
xs = []
|
xs = []
|
||||||
for seed in seeds:
|
for seed in seeds:
|
||||||
|
@ -256,7 +225,7 @@ def plaintext_to_html(text):
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def load_GFPGAN():
|
def load_gfpgan():
|
||||||
model_name = 'GFPGANv1.3'
|
model_name = 'GFPGANv1.3'
|
||||||
model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
|
model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
|
||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
|
@ -358,7 +327,7 @@ def combine_grid(grid):
|
||||||
|
|
||||||
|
|
||||||
def draw_prompt_matrix(im, width, height, all_prompts):
|
def draw_prompt_matrix(im, width, height, all_prompts):
|
||||||
def wrap(text, d, font, line_length):
|
def wrap(text, font, line_length):
|
||||||
lines = ['']
|
lines = ['']
|
||||||
for word in text.split():
|
for word in text.split():
|
||||||
line = f'{lines[-1]} {word}'.strip()
|
line = f'{lines[-1]} {word}'.strip()
|
||||||
|
@ -368,16 +337,16 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||||
lines.append(word)
|
lines.append(word)
|
||||||
return '\n'.join(lines)
|
return '\n'.join(lines)
|
||||||
|
|
||||||
def draw_texts(pos, x, y, texts, sizes):
|
def draw_texts(pos, draw_x, draw_y, texts, sizes):
|
||||||
for i, (text, size) in enumerate(zip(texts, sizes)):
|
for i, (text, size) in enumerate(zip(texts, sizes)):
|
||||||
active = pos & (1 << i) != 0
|
active = pos & (1 << i) != 0
|
||||||
|
|
||||||
if not active:
|
if not active:
|
||||||
text = '\u0336'.join(text) + '\u0336'
|
text = '\u0336'.join(text) + '\u0336'
|
||||||
|
|
||||||
d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
|
d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
|
||||||
|
|
||||||
y += size[1] + line_spacing
|
draw_y += size[1] + line_spacing
|
||||||
|
|
||||||
fontsize = (width + height) // 25
|
fontsize = (width + height) // 25
|
||||||
line_spacing = fontsize // 2
|
line_spacing = fontsize // 2
|
||||||
|
@ -399,8 +368,8 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||||
d = ImageDraw.Draw(result)
|
d = ImageDraw.Draw(result)
|
||||||
|
|
||||||
boundary = math.ceil(len(prompts) / 2)
|
boundary = math.ceil(len(prompts) / 2)
|
||||||
prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]]
|
prompts_horiz = [wrap(x, fnt, width) for x in prompts[:boundary]]
|
||||||
prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]]
|
prompts_vert = [wrap(x, fnt, pad_left) for x in prompts[boundary:]]
|
||||||
|
|
||||||
sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
|
sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
|
||||||
sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
|
sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
|
||||||
|
@ -458,25 +427,6 @@ def resize_image(resize_mode, im, width, height):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def check_prompt_length(prompt, comments):
|
|
||||||
"""this function tests if prompt is too long, and if so, adds a message to comments"""
|
|
||||||
|
|
||||||
tokenizer = model.cond_stage_model.tokenizer
|
|
||||||
max_length = model.cond_stage_model.max_length
|
|
||||||
|
|
||||||
info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
|
|
||||||
ovf = info['overflowing_tokens'][0]
|
|
||||||
overflowing_count = ovf.shape[0]
|
|
||||||
if overflowing_count == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
vocab = {v: k for k, v in tokenizer.get_vocab().items()}
|
|
||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
|
||||||
overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
|
||||||
|
|
||||||
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func):
|
def wrap_gradio_call(func):
|
||||||
def f(*p1, **p2):
|
def f(*p1, **p2):
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
|
@ -494,7 +444,7 @@ def wrap_gradio_call(func):
|
||||||
GFPGAN = None
|
GFPGAN = None
|
||||||
if os.path.exists(cmd_opts.gfpgan_dir):
|
if os.path.exists(cmd_opts.gfpgan_dir):
|
||||||
try:
|
try:
|
||||||
GFPGAN = load_GFPGAN()
|
GFPGAN = load_gfpgan()
|
||||||
print("Loaded GFPGAN")
|
print("Loaded GFPGAN")
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error loading GFPGAN:", file=sys.stderr)
|
print("Error loading GFPGAN:", file=sys.stderr)
|
||||||
|
@ -506,11 +456,11 @@ class StableDiffuionModelHijack:
|
||||||
word_embeddings = {}
|
word_embeddings = {}
|
||||||
word_embeddings_checksums = {}
|
word_embeddings_checksums = {}
|
||||||
fixes = None
|
fixes = None
|
||||||
used_custom_terms = []
|
comments = None
|
||||||
dir_mtime = None
|
dir_mtime = None
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, dir, model):
|
def load_textual_inversion_embeddings(self, dirname, model):
|
||||||
mt = os.path.getmtime(dir)
|
mt = os.path.getmtime(dirname)
|
||||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -543,10 +493,10 @@ class StableDiffuionModelHijack:
|
||||||
self.ids_lookup[first_id] = []
|
self.ids_lookup[first_id] = []
|
||||||
self.ids_lookup[first_id].append((ids, name))
|
self.ids_lookup[first_id].append((ids, name))
|
||||||
|
|
||||||
for fn in os.listdir(dir):
|
for fn in os.listdir(dirname):
|
||||||
try:
|
try:
|
||||||
process_file(os.path.join(dir, fn), fn)
|
process_file(os.path.join(dirname, fn), fn)
|
||||||
except:
|
except Exception:
|
||||||
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
@ -561,10 +511,10 @@ class StableDiffuionModelHijack:
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
def __init__(self, wrapped, embeddings):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.embeddings = embeddings
|
self.hijack = hijack
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
self.max_length = wrapped.max_length
|
self.max_length = wrapped.max_length
|
||||||
self.token_mults = {}
|
self.token_mults = {}
|
||||||
|
@ -586,12 +536,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
self.embeddings.fixes = []
|
self.hijack.fixes = []
|
||||||
self.embeddings.used_custom_terms = []
|
self.hijack.comments = []
|
||||||
remade_batch_tokens = []
|
remade_batch_tokens = []
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
maxlen = self.wrapped.max_length - 2
|
maxlen = self.wrapped.max_length - 2
|
||||||
|
used_custom_terms = []
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
@ -611,7 +562,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
token = tokens[i]
|
token = tokens[i]
|
||||||
|
|
||||||
possible_matches = self.embeddings.ids_lookup.get(token, None)
|
possible_matches = self.hijack.ids_lookup.get(token, None)
|
||||||
|
|
||||||
mult_change = self.token_mults.get(token)
|
mult_change = self.token_mults.get(token)
|
||||||
if mult_change is not None:
|
if mult_change is not None:
|
||||||
|
@ -628,7 +579,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
multipliers.append(mult)
|
multipliers.append(mult)
|
||||||
i += len(ids) - 1
|
i += len(ids) - 1
|
||||||
found = True
|
found = True
|
||||||
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
|
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
||||||
break
|
break
|
||||||
|
|
||||||
if not found:
|
if not found:
|
||||||
|
@ -637,6 +588,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
if len(remade_tokens) > maxlen - 2:
|
||||||
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
ovf = remade_tokens[maxlen - 2:]
|
||||||
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
|
|
||||||
|
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||||
|
@ -645,9 +604,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||||
|
|
||||||
remade_batch_tokens.append(remade_tokens)
|
remade_batch_tokens.append(remade_tokens)
|
||||||
self.embeddings.fixes.append(fixes)
|
self.hijack.fixes.append(fixes)
|
||||||
batch_multipliers.append(multipliers)
|
batch_multipliers.append(multipliers)
|
||||||
|
|
||||||
|
if len(used_custom_terms) > 0:
|
||||||
|
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
|
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens)
|
outputs = self.wrapped.transformer(input_ids=tokens)
|
||||||
z = outputs.last_hidden_state
|
z = outputs.last_hidden_state
|
||||||
|
@ -679,71 +641,123 @@ class EmbeddingsWithFixes(nn.Module):
|
||||||
for offset, word in fixes:
|
for offset, word in fixes:
|
||||||
tensor[offset] = self.embeddings.word_embeddings[word]
|
tensor[offset] = self.embeddings.word_embeddings[word]
|
||||||
|
|
||||||
|
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
|
|
||||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
|
class StableDiffusionProcessing:
|
||||||
|
def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None):
|
||||||
|
self.outpath: str = outpath
|
||||||
|
self.prompt: str = prompt
|
||||||
|
self.seed: int = seed
|
||||||
|
self.sampler_index: int = sampler_index
|
||||||
|
self.batch_size: int = batch_size
|
||||||
|
self.n_iter: int = n_iter
|
||||||
|
self.steps: int = steps
|
||||||
|
self.cfg_scale: float = cfg_scale
|
||||||
|
self.width: int = width
|
||||||
|
self.height: int = height
|
||||||
|
self.prompt_matrix: bool = prompt_matrix
|
||||||
|
self.use_GFPGAN: bool = use_GFPGAN
|
||||||
|
self.do_not_save_grid: bool = do_not_save_grid
|
||||||
|
self.extra_generation_params: dict = extra_generation_params
|
||||||
|
|
||||||
|
def init(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sample(self, x, conditioning, unconditional_conditioning):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class VanillaStableDiffusionSampler:
|
||||||
|
def __init__(self, constructor):
|
||||||
|
self.sampler = constructor(sd_model)
|
||||||
|
|
||||||
|
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
|
||||||
|
samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
|
||||||
|
return samples_ddim
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiser(nn.Module):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.inner_model = model
|
||||||
|
|
||||||
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
sigma_in = torch.cat([sigma] * 2)
|
||||||
|
cond_in = torch.cat([uncond, cond])
|
||||||
|
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||||
|
return uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
|
|
||||||
|
class KDiffusionSampler:
|
||||||
|
def __init__(self, funcname):
|
||||||
|
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
|
||||||
|
self.funcname = funcname
|
||||||
|
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||||
|
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||||
|
|
||||||
|
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
|
||||||
|
sigmas = self.model_wrap.get_sigmas(p.steps)
|
||||||
|
x = x * sigmas[0]
|
||||||
|
|
||||||
|
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
||||||
|
return samples_ddim
|
||||||
|
|
||||||
|
|
||||||
|
def process_images(p: StableDiffusionProcessing):
|
||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
assert prompt is not None
|
prompt = p.prompt
|
||||||
|
model = sd_model
|
||||||
|
|
||||||
|
assert p.prompt is not None
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
if seed == -1:
|
seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
|
||||||
seed = random.randrange(4294967294)
|
|
||||||
seed = int(seed)
|
|
||||||
|
|
||||||
os.makedirs(outpath, exist_ok=True)
|
os.makedirs(p.outpath, exist_ok=True)
|
||||||
|
|
||||||
sample_path = os.path.join(outpath, "samples")
|
sample_path = os.path.join(p.outpath, "samples")
|
||||||
os.makedirs(sample_path, exist_ok=True)
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
base_count = len(os.listdir(sample_path))
|
base_count = len(os.listdir(sample_path))
|
||||||
grid_count = len(os.listdir(outpath)) - 1
|
grid_count = len(os.listdir(p.outpath)) - 1
|
||||||
|
|
||||||
comments = []
|
comments = []
|
||||||
|
|
||||||
prompt_matrix_parts = []
|
prompt_matrix_parts = []
|
||||||
if prompt_matrix:
|
if p.prompt_matrix:
|
||||||
all_prompts = []
|
all_prompts = []
|
||||||
prompt_matrix_parts = prompt.split("|")
|
prompt_matrix_parts = prompt.split("|")
|
||||||
combination_count = 2 ** (len(prompt_matrix_parts) - 1)
|
combination_count = 2 ** (len(prompt_matrix_parts) - 1)
|
||||||
for combination_num in range(combination_count):
|
for combination_num in range(combination_count):
|
||||||
selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1<<n)]
|
selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
|
||||||
|
|
||||||
if opts.prompt_matrix_add_to_start:
|
if opts.prompt_matrix_add_to_start:
|
||||||
selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
|
selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
|
||||||
else:
|
else:
|
||||||
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
|
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
|
||||||
|
|
||||||
all_prompts.append( ", ".join(selected_prompts))
|
all_prompts.append(", ".join(selected_prompts))
|
||||||
|
|
||||||
n_iter = math.ceil(len(all_prompts) / batch_size)
|
p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
|
||||||
all_seeds = len(all_prompts) * [seed]
|
all_seeds = len(all_prompts) * [seed]
|
||||||
|
|
||||||
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
|
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
|
||||||
else:
|
else:
|
||||||
|
all_prompts = p.batch_size * p.n_iter * [prompt]
|
||||||
if opts.verify_input:
|
|
||||||
try:
|
|
||||||
check_prompt_length(prompt, comments)
|
|
||||||
except:
|
|
||||||
import traceback
|
|
||||||
print("Error verifying input:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
all_prompts = batch_size * n_iter * [prompt]
|
|
||||||
all_seeds = [seed + x for x in range(len(all_prompts))]
|
all_seeds = [seed + x for x in range(len(all_prompts))]
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": steps,
|
"Steps": p.steps,
|
||||||
"Sampler": samplers[sampler_index].name,
|
"Sampler": samplers[p.sampler_index].name,
|
||||||
"CFG scale": cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Seed": seed,
|
"Seed": seed,
|
||||||
"GFPGAN": ("GFPGAN" if use_GFPGAN and GFPGAN is not None else None)
|
"GFPGAN": ("GFPGAN" if p.use_GFPGAN and GFPGAN is not None else None)
|
||||||
}
|
}
|
||||||
|
|
||||||
if extra_generation_params is not None:
|
if p.extra_generation_params is not None:
|
||||||
generation_params.update(extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
|
@ -755,32 +769,32 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
||||||
|
|
||||||
output_images = []
|
output_images = []
|
||||||
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
||||||
init_data = func_init()
|
p.init()
|
||||||
|
|
||||||
for n in range(n_iter):
|
for n in range(p.n_iter):
|
||||||
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
|
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
|
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
uc = model.get_learned_conditioning(len(prompts) * [""])
|
uc = model.get_learned_conditioning(len(prompts) * [""])
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
if len(model_hijack.used_custom_terms) > 0:
|
if len(model_hijack.comments) > 0:
|
||||||
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
|
comments += model_hijack.comments
|
||||||
|
|
||||||
# we manually generate all input noises because each one should have a specific seed
|
# we manually generate all input noises because each one should have a specific seed
|
||||||
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
|
||||||
|
|
||||||
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc)
|
samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
|
||||||
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
if prompt_matrix or opts.samples_save or opts.grid_save:
|
if p.prompt_matrix or opts.samples_save or opts.grid_save:
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
if use_GFPGAN and GFPGAN is not None:
|
if p.use_GFPGAN and GFPGAN is not None:
|
||||||
torch_gc()
|
torch_gc()
|
||||||
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
|
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
x_sample = restored_img
|
x_sample = restored_img
|
||||||
|
@ -791,44 +805,44 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
base_count += 1
|
base_count += 1
|
||||||
|
|
||||||
if (prompt_matrix or opts.grid_save) and not do_not_save_grid:
|
if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid:
|
||||||
if prompt_matrix:
|
if p.prompt_matrix:
|
||||||
grid = image_grid(output_images, batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
|
grid = image_grid(output_images, p.batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
|
grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
|
||||||
except:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
print("Error creating prompt_matrix text:", file=sys.stderr)
|
print("Error creating prompt_matrix text:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
output_images.insert(0, grid)
|
output_images.insert(0, grid)
|
||||||
else:
|
else:
|
||||||
grid = image_grid(output_images, batch_size)
|
grid = image_grid(output_images, p.batch_size)
|
||||||
|
|
||||||
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
torch_gc()
|
torch_gc()
|
||||||
return output_images, seed, infotext()
|
return output_images, seed, infotext()
|
||||||
|
|
||||||
|
|
||||||
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
outpath = opts.outdir or "outputs/txt2img-samples"
|
sampler = None
|
||||||
|
|
||||||
sampler = samplers[sampler_index].constructor(model)
|
def init(self):
|
||||||
|
self.sampler = samplers[self.sampler_index].constructor()
|
||||||
|
|
||||||
def init():
|
def sample(self, x, conditioning, unconditional_conditioning):
|
||||||
pass
|
samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
||||||
|
|
||||||
def sample(init_data, x, conditioning, unconditional_conditioning):
|
|
||||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
|
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
output_images, seed, info = process_images(
|
|
||||||
|
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||||
|
outpath = opts.outdir or "outputs/txt2img-samples"
|
||||||
|
|
||||||
|
p = StableDiffusionProcessingTxt2Img(
|
||||||
outpath=outpath,
|
outpath=outpath,
|
||||||
func_init=init,
|
|
||||||
func_sample=sample,
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
sampler_index=sampler_index,
|
sampler_index=sampler_index,
|
||||||
|
@ -842,7 +856,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
|
||||||
use_GFPGAN=use_GFPGAN
|
use_GFPGAN=use_GFPGAN
|
||||||
)
|
)
|
||||||
|
|
||||||
del sampler
|
output_images, seed, info = process_images(p)
|
||||||
|
|
||||||
return output_images, seed, plaintext_to_html(info)
|
return output_images, seed, plaintext_to_html(info)
|
||||||
|
|
||||||
|
@ -858,7 +872,7 @@ class Flagging(gr.FlaggingCallback):
|
||||||
os.makedirs("log/images", exist_ok=True)
|
os.makedirs("log/images", exist_ok=True)
|
||||||
|
|
||||||
# those must match the "txt2img" function
|
# those must match the "txt2img" function
|
||||||
prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
|
prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
|
||||||
|
|
||||||
filenames = []
|
filenames = []
|
||||||
|
|
||||||
|
@ -896,7 +910,6 @@ txt2img_interface = gr.Interface(
|
||||||
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
|
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
|
||||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||||
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
|
||||||
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
||||||
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
||||||
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
||||||
|
@ -914,40 +927,77 @@ txt2img_interface = gr.Interface(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
|
sampler = None
|
||||||
|
|
||||||
|
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.init_images = init_images
|
||||||
|
self.resize_mode: int = resize_mode
|
||||||
|
self.denoising_strength: float = denoising_strength
|
||||||
|
self.init_latent = None
|
||||||
|
|
||||||
|
def init(self):
|
||||||
|
self.sampler = samplers_for_img2img[self.sampler_index].constructor()
|
||||||
|
|
||||||
|
imgs = []
|
||||||
|
for img in self.init_images:
|
||||||
|
image = img.convert("RGB")
|
||||||
|
image = resize_image(self.resize_mode, image, self.width, self.height)
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = np.moveaxis(image, 2, 0)
|
||||||
|
imgs.append(image)
|
||||||
|
|
||||||
|
if len(imgs) == 1:
|
||||||
|
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
|
||||||
|
elif len(imgs) <= self.batch_size:
|
||||||
|
self.batch_size = len(imgs)
|
||||||
|
batch_images = np.array(imgs)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
||||||
|
|
||||||
|
image = torch.from_numpy(batch_images)
|
||||||
|
image = 2. * image - 1.
|
||||||
|
image = image.to(device)
|
||||||
|
|
||||||
|
self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))
|
||||||
|
|
||||||
|
def sample(self, x, conditioning, unconditional_conditioning):
|
||||||
|
t_enc = int(self.denoising_strength * self.steps)
|
||||||
|
|
||||||
|
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
|
||||||
|
noise = x * sigmas[self.steps - t_enc - 1]
|
||||||
|
|
||||||
|
xi = self.init_latent + noise
|
||||||
|
sigma_sched = sigmas[self.steps - t_enc - 1:]
|
||||||
|
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
|
||||||
|
return samples_ddim
|
||||||
|
|
||||||
|
|
||||||
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||||
outpath = opts.outdir or "outputs/img2img-samples"
|
outpath = opts.outdir or "outputs/img2img-samples"
|
||||||
|
|
||||||
sampler = samplers_for_img2img[sampler_index].constructor(model)
|
|
||||||
|
|
||||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
|
||||||
def init():
|
p = StableDiffusionProcessingImg2Img(
|
||||||
image = init_img.convert("RGB")
|
outpath=outpath,
|
||||||
image = resize_image(resize_mode, image, width, height)
|
prompt=prompt,
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
seed=seed,
|
||||||
image = image[None].transpose(0, 3, 1, 2)
|
sampler_index=sampler_index,
|
||||||
image = torch.from_numpy(image)
|
batch_size=batch_size,
|
||||||
|
n_iter=n_iter,
|
||||||
init_image = 2. * image - 1.
|
steps=ddim_steps,
|
||||||
init_image = init_image.to(device)
|
cfg_scale=cfg_scale,
|
||||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
width=width,
|
||||||
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
height=height,
|
||||||
|
prompt_matrix=prompt_matrix,
|
||||||
return init_latent,
|
use_GFPGAN=use_GFPGAN,
|
||||||
|
init_images=[init_img],
|
||||||
def sample(init_data, x, conditioning, unconditional_conditioning):
|
resize_mode=resize_mode,
|
||||||
t_enc = int(denoising_strength * ddim_steps)
|
denoising_strength=denoising_strength,
|
||||||
|
extra_generation_params={"Denoising Strength": denoising_strength}
|
||||||
x0, = init_data
|
)
|
||||||
|
|
||||||
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
|
||||||
noise = x * sigmas[ddim_steps - t_enc - 1]
|
|
||||||
|
|
||||||
xi = x0 + noise
|
|
||||||
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
|
||||||
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
|
|
||||||
samples_ddim = sampler.func(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
|
|
||||||
return samples_ddim
|
|
||||||
|
|
||||||
if loopback:
|
if loopback:
|
||||||
output_images, info = None, None
|
output_images, info = None, None
|
||||||
|
@ -955,32 +1005,19 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
||||||
initial_seed = None
|
initial_seed = None
|
||||||
|
|
||||||
for i in range(n_iter):
|
for i in range(n_iter):
|
||||||
output_images, seed, info = process_images(
|
p.n_iter = 1
|
||||||
outpath=outpath,
|
p.batch_size = 1
|
||||||
func_init=init,
|
p.do_not_save_grid = True
|
||||||
func_sample=sample,
|
|
||||||
prompt=prompt,
|
output_images, seed, info = process_images(p)
|
||||||
seed=seed,
|
|
||||||
sampler_index=sampler_index,
|
|
||||||
batch_size=1,
|
|
||||||
n_iter=1,
|
|
||||||
steps=ddim_steps,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
prompt_matrix=prompt_matrix,
|
|
||||||
use_GFPGAN=use_GFPGAN,
|
|
||||||
do_not_save_grid=True,
|
|
||||||
extra_generation_params={"Denoising Strength": denoising_strength},
|
|
||||||
)
|
|
||||||
|
|
||||||
if initial_seed is None:
|
if initial_seed is None:
|
||||||
initial_seed = seed
|
initial_seed = seed
|
||||||
|
|
||||||
init_img = output_images[0]
|
p.init_img = output_images[0]
|
||||||
seed = seed + 1
|
p.seed = seed + 1
|
||||||
denoising_strength = max(denoising_strength * 0.95, 0.1)
|
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
|
||||||
history.append(init_img)
|
history.append(output_images[0])
|
||||||
|
|
||||||
grid_count = len(os.listdir(outpath)) - 1
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
grid = image_grid(history, batch_size, force_n_rows=1)
|
grid = image_grid(history, batch_size, force_n_rows=1)
|
||||||
|
@ -1000,39 +1037,36 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
||||||
|
|
||||||
grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
|
grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
|
||||||
|
|
||||||
|
p.n_iter = 1
|
||||||
|
p.do_not_save_grid = True
|
||||||
|
|
||||||
print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.")
|
work = []
|
||||||
|
work_results = []
|
||||||
|
|
||||||
for y, h, row in grid.tiles:
|
for y, h, row in grid.tiles:
|
||||||
for tiledata in row:
|
for tiledata in row:
|
||||||
init_img = tiledata[2]
|
work.append(tiledata[2])
|
||||||
|
|
||||||
output_images, seed, info = process_images(
|
batch_count = math.ceil(len(work) / p.batch_size)
|
||||||
outpath=outpath,
|
print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")
|
||||||
func_init=init,
|
|
||||||
func_sample=sample,
|
|
||||||
prompt=prompt,
|
|
||||||
seed=seed,
|
|
||||||
sampler_index=sampler_index,
|
|
||||||
batch_size=1, # since process_images can't work with multiple different images we have to do this for now
|
|
||||||
n_iter=1,
|
|
||||||
steps=ddim_steps,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
prompt_matrix=prompt_matrix,
|
|
||||||
use_GFPGAN=use_GFPGAN,
|
|
||||||
do_not_save_grid=True,
|
|
||||||
extra_generation_params={"Denoising Strength": denoising_strength},
|
|
||||||
)
|
|
||||||
|
|
||||||
if initial_seed is None:
|
for i in range(batch_count):
|
||||||
initial_seed = seed
|
p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
|
||||||
initial_info = info
|
|
||||||
|
|
||||||
seed += 1
|
output_images, seed, info = process_images(p)
|
||||||
|
|
||||||
tiledata[2] = output_images[0]
|
if initial_seed is None:
|
||||||
|
initial_seed = seed
|
||||||
|
initial_info = info
|
||||||
|
|
||||||
|
p.seed = seed + 1
|
||||||
|
work_results += output_images
|
||||||
|
|
||||||
|
image_index = 0
|
||||||
|
for y, h, row in grid.tiles:
|
||||||
|
for tiledata in row:
|
||||||
|
tiledata[2] = work_results[image_index]
|
||||||
|
image_index += 1
|
||||||
|
|
||||||
combined_image = combine_grid(grid)
|
combined_image = combine_grid(grid)
|
||||||
|
|
||||||
|
@ -1044,25 +1078,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
||||||
info = initial_info
|
info = initial_info
|
||||||
|
|
||||||
else:
|
else:
|
||||||
output_images, seed, info = process_images(
|
output_images, seed, info = process_images(p)
|
||||||
outpath=outpath,
|
|
||||||
func_init=init,
|
|
||||||
func_sample=sample,
|
|
||||||
prompt=prompt,
|
|
||||||
seed=seed,
|
|
||||||
sampler_index=sampler_index,
|
|
||||||
batch_size=batch_size,
|
|
||||||
n_iter=n_iter,
|
|
||||||
steps=ddim_steps,
|
|
||||||
cfg_scale=cfg_scale,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
prompt_matrix=prompt_matrix,
|
|
||||||
use_GFPGAN=use_GFPGAN,
|
|
||||||
extra_generation_params={"Denoising Strength": denoising_strength},
|
|
||||||
)
|
|
||||||
|
|
||||||
del sampler
|
|
||||||
|
|
||||||
return output_images, seed, plaintext_to_html(info)
|
return output_images, seed, plaintext_to_html(info)
|
||||||
|
|
||||||
|
@ -1178,22 +1194,19 @@ def run_settings(*args):
|
||||||
|
|
||||||
def create_setting_component(key):
|
def create_setting_component(key):
|
||||||
def fun():
|
def fun():
|
||||||
return opts.data[key] if key in opts.data else opts.data_labels[key][0]
|
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||||
|
|
||||||
labelinfo = opts.data_labels[key]
|
info = opts.data_labels[key]
|
||||||
t = type(labelinfo[0])
|
t = type(info.default)
|
||||||
label = labelinfo[1]
|
|
||||||
if t == str:
|
if info.component is not None:
|
||||||
item = gr.Textbox(label=label, value=fun, lines=1)
|
item = info.component(label=info.label, value=fun, **(info.component_args or {}))
|
||||||
|
elif t == str:
|
||||||
|
item = gr.Textbox(label=info.label, value=fun, lines=1)
|
||||||
elif t == int:
|
elif t == int:
|
||||||
if len(labelinfo) == 5:
|
item = gr.Number(label=info.label, value=fun)
|
||||||
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
|
|
||||||
elif len(labelinfo) == 4:
|
|
||||||
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
|
|
||||||
else:
|
|
||||||
item = gr.Number(label=label, value=fun)
|
|
||||||
elif t == bool:
|
elif t == bool:
|
||||||
item = gr.Checkbox(label=label, value=fun)
|
item = gr.Checkbox(label=info.label, value=fun)
|
||||||
else:
|
else:
|
||||||
raise Exception(f'bad options item type: {str(t)} for key {key}')
|
raise Exception(f'bad options item type: {str(t)} for key {key}')
|
||||||
|
|
||||||
|
@ -1219,14 +1232,14 @@ interfaces = [
|
||||||
(settings_interface, "Settings"),
|
(settings_interface, "Settings"),
|
||||||
]
|
]
|
||||||
|
|
||||||
config = OmegaConf.load(cmd_opts.config)
|
sd_config = OmegaConf.load(cmd_opts.config)
|
||||||
model = load_model_from_config(config, cmd_opts.ckpt)
|
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
||||||
|
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
model = (model if cmd_opts.no_half else model.half()).to(device)
|
sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
|
||||||
|
|
||||||
model_hijack = StableDiffuionModelHijack()
|
model_hijack = StableDiffuionModelHijack()
|
||||||
model_hijack.hijack(model)
|
model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
demo = gr.TabbedInterface(
|
demo = gr.TabbedInterface(
|
||||||
interface_list=[x[0] for x in interfaces],
|
interface_list=[x[0] for x in interfaces],
|
||||||
|
|
Loading…
Reference in a new issue