fixed all lines PyCharm was nagging me about

fixed input verification not working properly with long textual inversion tokens in some cases (plus it will prevent incorrect outputs for forks that use the :::: prompt weighing method)
changed process_images to object class with same fields as args it was previously accepting
changed options system to make it possible to explicitly specify gradio objects with args
This commit is contained in:
AUTOMATIC 2022-08-27 21:32:28 +03:00
parent 4e0fdca2f4
commit c30aee2f4b

563
webui.py
View file

@ -1,14 +1,13 @@
import argparse, os, sys, glob import argparse
import os
import sys
from collections import namedtuple from collections import namedtuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import gradio as gr import gradio as gr
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from itertools import islice
from einops import rearrange, repeat
from torch import autocast from torch import autocast
import mimetypes import mimetypes
import random import random
@ -22,14 +21,13 @@ import k_diffusion.sampling
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
import ldm.modules.encoders.modules
try: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging from transformers import logging
logging.set_verbosity_error() logging.set_verbosity_error()
except: except Exception:
pass pass
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
@ -41,13 +39,13 @@ opt_C = 4
opt_f = 8 opt_f = 8
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
invalid_filename_chars = '<>:"/\|?*\n' invalid_filename_chars = '<>:"/\\|?*\n'
config_filename = "config.json" config_filename = "config.json"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",) parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",) parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
@ -64,7 +62,7 @@ css_hide_progressbar = """
SamplerData = namedtuple('SamplerData', ['name', 'constructor']) SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [ samplers = [
*[SamplerData(x[0], lambda m, funcname=x[1]: KDiffusionSampler(m, funcname)) for x in [ *[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
('LMS', 'sample_lms'), ('LMS', 'sample_lms'),
('Heun', 'sample_heun'), ('Heun', 'sample_heun'),
('Euler', 'sample_euler'), ('Euler', 'sample_euler'),
@ -72,8 +70,8 @@ samplers = [
('DPM 2', 'sample_dpm_2'), ('DPM 2', 'sample_dpm_2'),
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'), ('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])], ] if hasattr(k_diffusion.sampling, x[1])],
SamplerData('DDIM', lambda m: DDIMSampler(model)), SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
SamplerData('PLMS', lambda m: PLMSSampler(model)), SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
] ]
samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS'] samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
@ -102,7 +100,7 @@ try:
), ),
] ]
have_realesrgan = True have_realesrgan = True
except: except Exception:
print("Error loading Real-ESRGAN:", file=sys.stderr) print("Error loading Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
@ -111,24 +109,30 @@ except:
class Options: class Options:
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
data = None data = None
data_labels = { data_labels = {
"outdir": ("", "Output dictectory; if empty, defaults to 'outputs/*'"), "outdir": OptionInfo("", "Output dictectory; if empty, defaults to 'outputs/*'"),
"samples_save": (True, "Save indiviual samples"), "samples_save": OptionInfo(True, "Save indiviual samples"),
"samples_format": ('png', 'File format for indiviual samples'), "samples_format": OptionInfo('png', 'File format for indiviual samples'),
"grid_save": (True, "Save image grids"), "grid_save": OptionInfo(True, "Save image grids"),
"grid_format": ('png', 'File format for grids'), "grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": (False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"n_rows": (-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", -1, 16), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"jpeg_quality": (80, "Quality for saved jpeg images", 1, 100), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"verify_input": (True, "Check input, and produce warning if it's too long"), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"), "prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
"prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"), "sd_upscale_overlap": OptionInfo(64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", gr.Slider, {"minimum": 0, "maximum": 256, "step": 16}),
"sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
} }
def __init__(self): def __init__(self):
self.data = {k: v[0] for k, v in self.data_labels.items()} self.data = {k: v.default for k, v in self.data_labels.items()}
def __setattr__(self, key, value): def __setattr__(self, key, value):
if self.data is not None: if self.data is not None:
@ -143,7 +147,7 @@ class Options:
return self.data[item] return self.data[item]
if item in self.data_labels: if item in self.data_labels:
return self.data_labels[item][0] return self.data_labels[item].default
return super(Options, self).__getattribute__(item) return super(Options, self).__getattribute__(item)
@ -156,11 +160,6 @@ class Options:
self.data = json.load(file) self.data = json.load(file)
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False): def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}") print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu")
@ -181,36 +180,6 @@ def load_model_from_config(config, ckpt, verbose=False):
return model return model
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
class KDiffusionSampler:
def __init__(self, m, funcname):
self.model = m
self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap)
samples_ddim = self.func(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
return samples_ddim, None
def create_random_tensors(shape, seeds): def create_random_tensors(shape, seeds):
xs = [] xs = []
for seed in seeds: for seed in seeds:
@ -256,7 +225,7 @@ def plaintext_to_html(text):
return text return text
def load_GFPGAN(): def load_gfpgan():
model_name = 'GFPGANv1.3' model_name = 'GFPGANv1.3'
model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth') model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
@ -358,7 +327,7 @@ def combine_grid(grid):
def draw_prompt_matrix(im, width, height, all_prompts): def draw_prompt_matrix(im, width, height, all_prompts):
def wrap(text, d, font, line_length): def wrap(text, font, line_length):
lines = [''] lines = ['']
for word in text.split(): for word in text.split():
line = f'{lines[-1]} {word}'.strip() line = f'{lines[-1]} {word}'.strip()
@ -368,16 +337,16 @@ def draw_prompt_matrix(im, width, height, all_prompts):
lines.append(word) lines.append(word)
return '\n'.join(lines) return '\n'.join(lines)
def draw_texts(pos, x, y, texts, sizes): def draw_texts(pos, draw_x, draw_y, texts, sizes):
for i, (text, size) in enumerate(zip(texts, sizes)): for i, (text, size) in enumerate(zip(texts, sizes)):
active = pos & (1 << i) != 0 active = pos & (1 << i) != 0
if not active: if not active:
text = '\u0336'.join(text) + '\u0336' text = '\u0336'.join(text) + '\u0336'
d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center") d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
y += size[1] + line_spacing draw_y += size[1] + line_spacing
fontsize = (width + height) // 25 fontsize = (width + height) // 25
line_spacing = fontsize // 2 line_spacing = fontsize // 2
@ -399,8 +368,8 @@ def draw_prompt_matrix(im, width, height, all_prompts):
d = ImageDraw.Draw(result) d = ImageDraw.Draw(result)
boundary = math.ceil(len(prompts) / 2) boundary = math.ceil(len(prompts) / 2)
prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]] prompts_horiz = [wrap(x, fnt, width) for x in prompts[:boundary]]
prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]] prompts_vert = [wrap(x, fnt, pad_left) for x in prompts[boundary:]]
sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]] sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]] sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
@ -458,25 +427,6 @@ def resize_image(resize_mode, im, width, height):
return res return res
def check_prompt_length(prompt, comments):
"""this function tests if prompt is too long, and if so, adds a message to comments"""
tokenizer = model.cond_stage_model.tokenizer
max_length = model.cond_stage_model.max_length
info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
ovf = info['overflowing_tokens'][0]
overflowing_count = ovf.shape[0]
if overflowing_count == 0:
return
vocab = {v: k for k, v in tokenizer.get_vocab().items()}
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
def wrap_gradio_call(func): def wrap_gradio_call(func):
def f(*p1, **p2): def f(*p1, **p2):
t = time.perf_counter() t = time.perf_counter()
@ -494,7 +444,7 @@ def wrap_gradio_call(func):
GFPGAN = None GFPGAN = None
if os.path.exists(cmd_opts.gfpgan_dir): if os.path.exists(cmd_opts.gfpgan_dir):
try: try:
GFPGAN = load_GFPGAN() GFPGAN = load_gfpgan()
print("Loaded GFPGAN") print("Loaded GFPGAN")
except Exception: except Exception:
print("Error loading GFPGAN:", file=sys.stderr) print("Error loading GFPGAN:", file=sys.stderr)
@ -506,11 +456,11 @@ class StableDiffuionModelHijack:
word_embeddings = {} word_embeddings = {}
word_embeddings_checksums = {} word_embeddings_checksums = {}
fixes = None fixes = None
used_custom_terms = [] comments = None
dir_mtime = None dir_mtime = None
def load_textual_inversion_embeddings(self, dir, model): def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dir) mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime: if self.dir_mtime is not None and mt <= self.dir_mtime:
return return
@ -543,10 +493,10 @@ class StableDiffuionModelHijack:
self.ids_lookup[first_id] = [] self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name)) self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dir): for fn in os.listdir(dirname):
try: try:
process_file(os.path.join(dir, fn), fn) process_file(os.path.join(dirname, fn), fn)
except: except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr) print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
continue continue
@ -561,10 +511,10 @@ class StableDiffuionModelHijack:
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, embeddings): def __init__(self, wrapped, hijack):
super().__init__() super().__init__()
self.wrapped = wrapped self.wrapped = wrapped
self.embeddings = embeddings self.hijack = hijack
self.tokenizer = wrapped.tokenizer self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length self.max_length = wrapped.max_length
self.token_mults = {} self.token_mults = {}
@ -586,12 +536,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.token_mults[ident] = mult self.token_mults[ident] = mult
def forward(self, text): def forward(self, text):
self.embeddings.fixes = [] self.hijack.fixes = []
self.embeddings.used_custom_terms = [] self.hijack.comments = []
remade_batch_tokens = [] remade_batch_tokens = []
id_start = self.wrapped.tokenizer.bos_token_id id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length - 2 maxlen = self.wrapped.max_length - 2
used_custom_terms = []
cache = {} cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@ -611,7 +562,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens): while i < len(tokens):
token = tokens[i] token = tokens[i]
possible_matches = self.embeddings.ids_lookup.get(token, None) possible_matches = self.hijack.ids_lookup.get(token, None)
mult_change = self.token_mults.get(token) mult_change = self.token_mults.get(token)
if mult_change is not None: if mult_change is not None:
@ -628,7 +579,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append(mult) multipliers.append(mult)
i += len(ids) - 1 i += len(ids) - 1
found = True found = True
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word])) used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break break
if not found: if not found:
@ -637,6 +588,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
i += 1 i += 1
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers) cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@ -645,9 +604,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens) remade_batch_tokens.append(remade_tokens)
self.embeddings.fixes.append(fixes) self.hijack.fixes.append(fixes)
batch_multipliers.append(multipliers) batch_multipliers.append(multipliers)
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device) tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens) outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state z = outputs.last_hidden_state
@ -679,71 +641,123 @@ class EmbeddingsWithFixes(nn.Module):
for offset, word in fixes: for offset, word in fixes:
tensor[offset] = self.embeddings.word_embeddings[word] tensor[offset] = self.embeddings.word_embeddings[word]
return inputs_embeds return inputs_embeds
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None): class StableDiffusionProcessing:
def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None):
self.outpath: str = outpath
self.prompt: str = prompt
self.seed: int = seed
self.sampler_index: int = sampler_index
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
self.cfg_scale: float = cfg_scale
self.width: int = width
self.height: int = height
self.prompt_matrix: bool = prompt_matrix
self.use_GFPGAN: bool = use_GFPGAN
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params
def init(self):
pass
def sample(self, x, conditioning, unconditional_conditioning):
raise NotImplementedError()
class VanillaStableDiffusionSampler:
def __init__(self, constructor):
self.sampler = constructor(sd_model)
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
return samples_ddim
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
class KDiffusionSampler:
def __init__(self, funcname):
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0]
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
return samples_ddim
def process_images(p: StableDiffusionProcessing):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None prompt = p.prompt
model = sd_model
assert p.prompt is not None
torch_gc() torch_gc()
if seed == -1: seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
seed = random.randrange(4294967294)
seed = int(seed)
os.makedirs(outpath, exist_ok=True) os.makedirs(p.outpath, exist_ok=True)
sample_path = os.path.join(outpath, "samples") sample_path = os.path.join(p.outpath, "samples")
os.makedirs(sample_path, exist_ok=True) os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path)) base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(p.outpath)) - 1
comments = [] comments = []
prompt_matrix_parts = [] prompt_matrix_parts = []
if prompt_matrix: if p.prompt_matrix:
all_prompts = [] all_prompts = []
prompt_matrix_parts = prompt.split("|") prompt_matrix_parts = prompt.split("|")
combination_count = 2 ** (len(prompt_matrix_parts) - 1) combination_count = 2 ** (len(prompt_matrix_parts) - 1)
for combination_num in range(combination_count): for combination_num in range(combination_count):
selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1<<n)] selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
if opts.prompt_matrix_add_to_start: if opts.prompt_matrix_add_to_start:
selected_prompts = selected_prompts + [prompt_matrix_parts[0]] selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
else: else:
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
all_prompts.append( ", ".join(selected_prompts)) all_prompts.append(", ".join(selected_prompts))
n_iter = math.ceil(len(all_prompts) / batch_size) p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
all_seeds = len(all_prompts) * [seed] all_seeds = len(all_prompts) * [seed]
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
else: else:
all_prompts = p.batch_size * p.n_iter * [prompt]
if opts.verify_input:
try:
check_prompt_length(prompt, comments)
except:
import traceback
print("Error verifying input:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
all_prompts = batch_size * n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))] all_seeds = [seed + x for x in range(len(all_prompts))]
generation_params = { generation_params = {
"Steps": steps, "Steps": p.steps,
"Sampler": samplers[sampler_index].name, "Sampler": samplers[p.sampler_index].name,
"CFG scale": cfg_scale, "CFG scale": p.cfg_scale,
"Seed": seed, "Seed": seed,
"GFPGAN": ("GFPGAN" if use_GFPGAN and GFPGAN is not None else None) "GFPGAN": ("GFPGAN" if p.use_GFPGAN and GFPGAN is not None else None)
} }
if extra_generation_params is not None: if p.extra_generation_params is not None:
generation_params.update(extra_generation_params) generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
@ -755,32 +769,32 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
output_images = [] output_images = []
with torch.no_grad(), autocast("cuda"), model.ema_scope(): with torch.no_grad(), autocast("cuda"), model.ema_scope():
init_data = func_init() p.init()
for n in range(n_iter): for n in range(p.n_iter):
prompts = all_prompts[n * batch_size:(n + 1) * batch_size] prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = all_seeds[n * batch_size:(n + 1) * batch_size] seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
uc = model.get_learned_conditioning(len(prompts) * [""]) uc = model.get_learned_conditioning(len(prompts) * [""])
c = model.get_learned_conditioning(prompts) c = model.get_learned_conditioning(prompts)
if len(model_hijack.used_custom_terms) > 0: if len(model_hijack.comments) > 0:
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms])) comments += model_hijack.comments
# we manually generate all input noises because each one should have a specific seed # we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds) x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc) samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if prompt_matrix or opts.samples_save or opts.grid_save: if p.prompt_matrix or opts.samples_save or opts.grid_save:
for i, x_sample in enumerate(x_samples_ddim): for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
if use_GFPGAN and GFPGAN is not None: if p.use_GFPGAN and GFPGAN is not None:
torch_gc() torch_gc()
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
x_sample = restored_img x_sample = restored_img
@ -791,44 +805,44 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
output_images.append(image) output_images.append(image)
base_count += 1 base_count += 1
if (prompt_matrix or opts.grid_save) and not do_not_save_grid: if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid:
if prompt_matrix: if p.prompt_matrix:
grid = image_grid(output_images, batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) grid = image_grid(output_images, p.batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
try: try:
grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
except: except Exception:
import traceback import traceback
print("Error creating prompt_matrix text:", file=sys.stderr) print("Error creating prompt_matrix text:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
output_images.insert(0, grid) output_images.insert(0, grid)
else: else:
grid = image_grid(output_images, batch_size) grid = image_grid(output_images, p.batch_size)
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
grid_count += 1 grid_count += 1
torch_gc() torch_gc()
return output_images, seed, infotext() return output_images, seed, infotext()
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
outpath = opts.outdir or "outputs/txt2img-samples" sampler = None
sampler = samplers[sampler_index].constructor(model) def init(self):
self.sampler = samplers[self.sampler_index].constructor()
def init(): def sample(self, x, conditioning, unconditional_conditioning):
pass samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
def sample(init_data, x, conditioning, unconditional_conditioning):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
return samples_ddim return samples_ddim
output_images, seed, info = process_images(
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
outpath = opts.outdir or "outputs/txt2img-samples"
p = StableDiffusionProcessingTxt2Img(
outpath=outpath, outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt, prompt=prompt,
seed=seed, seed=seed,
sampler_index=sampler_index, sampler_index=sampler_index,
@ -842,7 +856,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
use_GFPGAN=use_GFPGAN use_GFPGAN=use_GFPGAN
) )
del sampler output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info) return output_images, seed, plaintext_to_html(info)
@ -858,7 +872,7 @@ class Flagging(gr.FlaggingCallback):
os.makedirs("log/images", exist_ok=True) os.makedirs("log/images", exist_ok=True)
# those must match the "txt2img" function # those must match the "txt2img" function
prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
filenames = [] filenames = []
@ -896,7 +910,6 @@ txt2img_interface = gr.Interface(
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"), gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1), gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
@ -914,40 +927,77 @@ txt2img_interface = gr.Interface(
) )
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
self.resize_mode: int = resize_mode
self.denoising_strength: float = denoising_strength
self.init_latent = None
def init(self):
self.sampler = samplers_for_img2img[self.sampler_index].constructor()
imgs = []
for img in self.init_images:
image = img.convert("RGB")
image = resize_image(self.resize_mode, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
imgs.append(image)
if len(imgs) == 1:
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
else:
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
image = image.to(device)
self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))
def sample(self, x, conditioning, unconditional_conditioning):
t_enc = int(self.denoising_strength * self.steps)
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
noise = x * sigmas[self.steps - t_enc - 1]
xi = self.init_latent + noise
sigma_sched = sigmas[self.steps - t_enc - 1:]
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
return samples_ddim
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples" outpath = opts.outdir or "outputs/img2img-samples"
sampler = samplers_for_img2img[sampler_index].constructor(model)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
def init(): p = StableDiffusionProcessingImg2Img(
image = init_img.convert("RGB") outpath=outpath,
image = resize_image(resize_mode, image, width, height) prompt=prompt,
image = np.array(image).astype(np.float32) / 255.0 seed=seed,
image = image[None].transpose(0, 3, 1, 2) sampler_index=sampler_index,
image = torch.from_numpy(image) batch_size=batch_size,
n_iter=n_iter,
init_image = 2. * image - 1. steps=ddim_steps,
init_image = init_image.to(device) cfg_scale=cfg_scale,
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) width=width,
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space height=height,
prompt_matrix=prompt_matrix,
return init_latent, use_GFPGAN=use_GFPGAN,
init_images=[init_img],
def sample(init_data, x, conditioning, unconditional_conditioning): resize_mode=resize_mode,
t_enc = int(denoising_strength * ddim_steps) denoising_strength=denoising_strength,
extra_generation_params={"Denoising Strength": denoising_strength}
x0, = init_data )
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
noise = x * sigmas[ddim_steps - t_enc - 1]
xi = x0 + noise
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
samples_ddim = sampler.func(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim
if loopback: if loopback:
output_images, info = None, None output_images, info = None, None
@ -955,32 +1005,19 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
initial_seed = None initial_seed = None
for i in range(n_iter): for i in range(n_iter):
output_images, seed, info = process_images( p.n_iter = 1
outpath=outpath, p.batch_size = 1
func_init=init, p.do_not_save_grid = True
func_sample=sample,
prompt=prompt, output_images, seed, info = process_images(p)
seed=seed,
sampler_index=sampler_index,
batch_size=1,
n_iter=1,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
do_not_save_grid=True,
extra_generation_params={"Denoising Strength": denoising_strength},
)
if initial_seed is None: if initial_seed is None:
initial_seed = seed initial_seed = seed
init_img = output_images[0] p.init_img = output_images[0]
seed = seed + 1 p.seed = seed + 1
denoising_strength = max(denoising_strength * 0.95, 0.1) p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
history.append(init_img) history.append(output_images[0])
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1) grid = image_grid(history, batch_size, force_n_rows=1)
@ -1000,39 +1037,36 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap) grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
p.n_iter = 1
p.do_not_save_grid = True
print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.") work = []
work_results = []
for y, h, row in grid.tiles: for y, h, row in grid.tiles:
for tiledata in row: for tiledata in row:
init_img = tiledata[2] work.append(tiledata[2])
output_images, seed, info = process_images( batch_count = math.ceil(len(work) / p.batch_size)
outpath=outpath, print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=1, # since process_images can't work with multiple different images we have to do this for now
n_iter=1,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
do_not_save_grid=True,
extra_generation_params={"Denoising Strength": denoising_strength},
)
if initial_seed is None: for i in range(batch_count):
initial_seed = seed p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
initial_info = info
seed += 1 output_images, seed, info = process_images(p)
tiledata[2] = output_images[0] if initial_seed is None:
initial_seed = seed
initial_info = info
p.seed = seed + 1
work_results += output_images
image_index = 0
for y, h, row in grid.tiles:
for tiledata in row:
tiledata[2] = work_results[image_index]
image_index += 1
combined_image = combine_grid(grid) combined_image = combine_grid(grid)
@ -1044,25 +1078,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
info = initial_info info = initial_info
else: else:
output_images, seed, info = process_images( output_images, seed, info = process_images(p)
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
extra_generation_params={"Denoising Strength": denoising_strength},
)
del sampler
return output_images, seed, plaintext_to_html(info) return output_images, seed, plaintext_to_html(info)
@ -1178,22 +1194,19 @@ def run_settings(*args):
def create_setting_component(key): def create_setting_component(key):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key][0] return opts.data[key] if key in opts.data else opts.data_labels[key].default
labelinfo = opts.data_labels[key] info = opts.data_labels[key]
t = type(labelinfo[0]) t = type(info.default)
label = labelinfo[1]
if t == str: if info.component is not None:
item = gr.Textbox(label=label, value=fun, lines=1) item = info.component(label=info.label, value=fun, **(info.component_args or {}))
elif t == str:
item = gr.Textbox(label=info.label, value=fun, lines=1)
elif t == int: elif t == int:
if len(labelinfo) == 5: item = gr.Number(label=info.label, value=fun)
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
elif len(labelinfo) == 4:
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
else:
item = gr.Number(label=label, value=fun)
elif t == bool: elif t == bool:
item = gr.Checkbox(label=label, value=fun) item = gr.Checkbox(label=info.label, value=fun)
else: else:
raise Exception(f'bad options item type: {str(t)} for key {key}') raise Exception(f'bad options item type: {str(t)} for key {key}')
@ -1219,14 +1232,14 @@ interfaces = [
(settings_interface, "Settings"), (settings_interface, "Settings"),
] ]
config = OmegaConf.load(cmd_opts.config) sd_config = OmegaConf.load(cmd_opts.config)
model = load_model_from_config(config, cmd_opts.ckpt) sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device) sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
model_hijack = StableDiffuionModelHijack() model_hijack = StableDiffuionModelHijack()
model_hijack.hijack(model) model_hijack.hijack(sd_model)
demo = gr.TabbedInterface( demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces], interface_list=[x[0] for x in interfaces],