diff --git a/webui.py b/webui.py
index 13e5112a..8de1bcf2 100644
--- a/webui.py
+++ b/webui.py
@@ -1,14 +1,13 @@
-import argparse, os, sys, glob
+import argparse
+import os
+import sys
 from collections import namedtuple
-
 import torch
 import torch.nn as nn
 import numpy as np
 import gradio as gr
 from omegaconf import OmegaConf
 from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
-from itertools import islice
-from einops import rearrange, repeat
 from torch import autocast
 import mimetypes
 import random
@@ -22,14 +21,13 @@ import k_diffusion.sampling
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
-import ldm.modules.encoders.modules

 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
     from transformers import logging
     logging.set_verbosity_error()
-except:
+except Exception:
     pass

 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
@@ -41,13 +39,13 @@
 opt_C = 4
 opt_f = 8

 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
-invalid_filename_chars = '<>:"/\|?*\n'
+invalid_filename_chars = '<>:"/\\|?*\n'
 config_filename = "config.json"

 parser = argparse.ArgumentParser()
 parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
 parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
-parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
+parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
 parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
 parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
 parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
@@ -64,7 +62,7 @@ css_hide_progressbar = """

 SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
 samplers = [
-    *[SamplerData(x[0], lambda m, funcname=x[1]: KDiffusionSampler(m, funcname)) for x in [
+    *[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
         ('LMS', 'sample_lms'),
         ('Heun', 'sample_heun'),
         ('Euler', 'sample_euler'),
@@ -72,8 +70,8 @@ samplers = [
         ('DPM 2', 'sample_dpm_2'),
         ('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
     ] if hasattr(k_diffusion.sampling, x[1])],
-    SamplerData('DDIM', lambda m: DDIMSampler(model)),
-    SamplerData('PLMS', lambda m: PLMSSampler(model)),
+    SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
+    SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
 ]
 samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
@@ -102,7 +100,7 @@ try:
         ),
     ]
     have_realesrgan = True
-except:
+except Exception:
     print("Error loading Real-ESRGAN:", file=sys.stderr)
     print(traceback.format_exc(), file=sys.stderr)

@@ -111,24 +109,30 @@
 class Options:
+    class OptionInfo:
+        def __init__(self, default=None, label="", component=None, component_args=None):
+            self.default = default
+            self.label = label
+            self.component = component
+            self.component_args = component_args
+
     data = None
     data_labels = {
-        "outdir": ("", "Output dictectory; if empty, defaults to 'outputs/*'"),
-        "samples_save": (True, "Save indiviual samples"),
-        "samples_format": ('png', 'File format for indiviual samples'),
-        "grid_save": (True, "Save image grids"),
-        "grid_format": ('png', 'File format for grids'),
-        "grid_extended_filename": (False, "Add extended info (seed, prompt) to filename when saving grid"),
-        "n_rows": (-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", -1, 16),
-        "jpeg_quality": (80, "Quality for saved jpeg images", 1, 100),
-        "verify_input": (True, "Check input, and produce warning if it's too long"),
-        "enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
-        "prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
-        "sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
+        "outdir": OptionInfo("", "Output directory; if empty, defaults to 'outputs/*'"),
+        "samples_save": OptionInfo(True, "Save individual samples"),
+        "samples_format": OptionInfo('png', 'File format for individual samples'),
+        "grid_save": OptionInfo(True, "Save image grids"),
+        "grid_format": OptionInfo('png', 'File format for grids'),
+        "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
+        "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+        "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
+        "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
+        "prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
+        "sd_upscale_overlap": OptionInfo(64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth the transition from one tile to another", gr.Slider, {"minimum": 0, "maximum": 256, "step": 16}),
     }

     def __init__(self):
-        self.data = {k: v[0] for k, v in self.data_labels.items()}
+        self.data = {k: v.default for k, v in self.data_labels.items()}

     def __setattr__(self, key, value):
         if self.data is not None:
@@ -143,7 +147,7 @@ class Options:
             return self.data[item]

         if item in self.data_labels:
-            return self.data_labels[item][0]
+            return self.data_labels[item].default

         return super(Options, self).__getattribute__(item)

@@ -156,11 +160,6 @@ class Options:
             self.data = json.load(file)


-def chunk(it, size):
-    it = iter(it)
-    return iter(lambda: tuple(islice(it, size)), ())
-
-
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
@@ -181,36 +180,6 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model


-class CFGDenoiser(nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.inner_model = model
-
-    def forward(self, x, sigma, uncond, cond, cond_scale):
-        x_in = torch.cat([x] * 2)
-        sigma_in = torch.cat([sigma] * 2)
-        cond_in = torch.cat([uncond, cond])
-        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
-        return uncond + (cond - uncond) * cond_scale
-
-
-class KDiffusionSampler:
-    def __init__(self, m, funcname):
-        self.model = m
-        self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
-        self.funcname = funcname
-        self.func = getattr(k_diffusion.sampling, self.funcname)
-
-    def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
-        sigmas = self.model_wrap.get_sigmas(S)
-        x = x_T * sigmas[0]
-        model_wrap_cfg = CFGDenoiser(self.model_wrap)
-
-        samples_ddim = self.func(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
-
-        return samples_ddim, None
-
-
 def create_random_tensors(shape, seeds):
     xs = []
     for seed in seeds:
@@ -256,7 +225,7 @@ def plaintext_to_html(text):
     return text


-def load_GFPGAN():
+def load_gfpgan():
     model_name = 'GFPGANv1.3'
     model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
     if not os.path.isfile(model_path):
@@ -358,7 +327,7 @@ def combine_grid(grid):


 def draw_prompt_matrix(im, width, height, all_prompts):
-    def wrap(text, d, font, line_length):
+    def wrap(text, font, line_length):
         lines = ['']
         for word in text.split():
             line = f'{lines[-1]} {word}'.strip()
@@ -368,16 +337,16 @@ def draw_prompt_matrix(im, width, height, all_prompts):
                 lines.append(word)
         return '\n'.join(lines)

-    def draw_texts(pos, x, y, texts, sizes):
+    def draw_texts(pos, draw_x, draw_y, texts, sizes):
         for i, (text, size) in enumerate(zip(texts, sizes)):
             active = pos & (1 << i) != 0

             if not active:
                 text = '\u0336'.join(text) + '\u0336'

-            d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
+            d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")

-            y += size[1] + line_spacing
+            draw_y += size[1] + line_spacing

     fontsize = (width + height) // 25
     line_spacing = fontsize // 2
@@ -399,8 +368,8 @@ def draw_prompt_matrix(im, width, height, all_prompts):
     d = ImageDraw.Draw(result)

     boundary = math.ceil(len(prompts) / 2)
-    prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]]
-    prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]]
+    prompts_horiz = [wrap(x, fnt, width) for x in prompts[:boundary]]
+    prompts_vert = [wrap(x, fnt, pad_left) for x in prompts[boundary:]]

     sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
     sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
@@ -458,25 +427,6 @@ def resize_image(resize_mode, im, width, height):
     return res


-def check_prompt_length(prompt, comments):
-    """this function tests if prompt is too long, and if so, adds a message to comments"""
-
-    tokenizer = model.cond_stage_model.tokenizer
-    max_length = model.cond_stage_model.max_length
-
-    info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
-    ovf = info['overflowing_tokens'][0]
-    overflowing_count = ovf.shape[0]
-    if overflowing_count == 0:
-        return
-
-    vocab = {v: k for k, v in tokenizer.get_vocab().items()}
-    overflowing_words = [vocab.get(int(x), "") for x in ovf]
-    overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-
-    comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
-
 def wrap_gradio_call(func):
     def f(*p1, **p2):
         t = time.perf_counter()
@@ -494,7 +444,7 @@ def wrap_gradio_call(func):
 GFPGAN = None
 if os.path.exists(cmd_opts.gfpgan_dir):
     try:
-        GFPGAN = load_GFPGAN()
+        GFPGAN = load_gfpgan()
         print("Loaded GFPGAN")
     except Exception:
         print("Error loading GFPGAN:", file=sys.stderr)
@@ -506,11 +456,11 @@ class StableDiffuionModelHijack:
     word_embeddings = {}
     word_embeddings_checksums = {}
     fixes = None
-    used_custom_terms = []
+    comments = None
     dir_mtime = None

-    def load_textual_inversion_embeddings(self, dir, model):
-        mt = os.path.getmtime(dir)
+    def load_textual_inversion_embeddings(self, dirname, model):
+        mt = os.path.getmtime(dirname)
         if self.dir_mtime is not None and mt <= self.dir_mtime:
             return

@@ -543,10 +493,10 @@ class StableDiffuionModelHijack:
                 self.ids_lookup[first_id] = []
             self.ids_lookup[first_id].append((ids, name))

-        for fn in os.listdir(dir):
+        for fn in os.listdir(dirname):
             try:
-                process_file(os.path.join(dir, fn), fn)
-            except:
+                process_file(os.path.join(dirname, fn), fn)
+            except Exception:
                 print(f"Error loading emedding {fn}:", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
                 continue
@@ -561,10 +511,10 @@ class StableDiffuionModelHijack:


 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
-    def __init__(self, wrapped, embeddings):
+    def __init__(self, wrapped, hijack):
         super().__init__()
         self.wrapped = wrapped
-        self.embeddings = embeddings
+        self.hijack = hijack
         self.tokenizer = wrapped.tokenizer
         self.max_length = wrapped.max_length
         self.token_mults = {}
@@ -586,12 +536,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             self.token_mults[ident] = mult

     def forward(self, text):
-        self.embeddings.fixes = []
-        self.embeddings.used_custom_terms = []
+        self.hijack.fixes = []
+        self.hijack.comments = []
         remade_batch_tokens = []
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
         maxlen = self.wrapped.max_length - 2
+        used_custom_terms = []

         cache = {}
         batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@@ -611,7 +562,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             while i < len(tokens):
                 token = tokens[i]

-                possible_matches = self.embeddings.ids_lookup.get(token, None)
+                possible_matches = self.hijack.ids_lookup.get(token, None)

                 mult_change = self.token_mults.get(token)
                 if mult_change is not None:
@@ -628,7 +579,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                             multipliers.append(mult)
                         i += len(ids) - 1
                         found = True
-                        self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
+                        used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
                         break

                 if not found:
@@ -637,6 +588,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):

                 i += 1

+            if len(remade_tokens) > maxlen - 2:
+                vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+                ovf = remade_tokens[maxlen - 2:]
+                overflowing_words = [vocab.get(int(x), "") for x in ovf]
+                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+
+                self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
             remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
             remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
             cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@@ -645,9 +604,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

             remade_batch_tokens.append(remade_tokens)
-            self.embeddings.fixes.append(fixes)
+            self.hijack.fixes.append(fixes)
             batch_multipliers.append(multipliers)

+        if len(used_custom_terms) > 0:
+            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
         tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
         outputs = self.wrapped.transformer(input_ids=tokens)
         z = outputs.last_hidden_state
@@ -679,71 +641,123 @@ class EmbeddingsWithFixes(nn.Module):
                 for offset, word in fixes:
                     tensor[offset] = self.embeddings.word_embeddings[word]
-
         return inputs_embeds


-def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
+class StableDiffusionProcessing:
+    def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None):
+        self.outpath: str = outpath
+        self.prompt: str = prompt
+        self.seed: int = seed
+        self.sampler_index: int = sampler_index
+        self.batch_size: int = batch_size
+        self.n_iter: int = n_iter
+        self.steps: int = steps
+        self.cfg_scale: float = cfg_scale
+        self.width: int = width
+        self.height: int = height
+        self.prompt_matrix: bool = prompt_matrix
+        self.use_GFPGAN: bool = use_GFPGAN
+        self.do_not_save_grid: bool = do_not_save_grid
+        self.extra_generation_params: dict = extra_generation_params
+
+    def init(self):
+        pass
+
+    def sample(self, x, conditioning, unconditional_conditioning):
+        raise NotImplementedError()
+
+
+class VanillaStableDiffusionSampler:
+    def __init__(self, constructor):
+        self.sampler = constructor(sd_model)
+
+    def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
+        samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+        return samples_ddim
+
+
+class CFGDenoiser(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        x_in = torch.cat([x] * 2)
+        sigma_in = torch.cat([sigma] * 2)
+        cond_in = torch.cat([uncond, cond])
+        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        return uncond + (cond - uncond) * cond_scale
+
+
+class KDiffusionSampler:
+    def __init__(self, funcname):
+        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
+        self.funcname = funcname
+        self.func = getattr(k_diffusion.sampling, self.funcname)
+        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+
+    def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
+        sigmas = self.model_wrap.get_sigmas(p.steps)
+        x = x * sigmas[0]
+
+        samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
+        return samples_ddim
+
+
+def process_images(p: StableDiffusionProcessing):
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
-    assert prompt is not None
+    prompt = p.prompt
+    model = sd_model
+
+    assert p.prompt is not None
     torch_gc()

-    if seed == -1:
-        seed = random.randrange(4294967294)
-    seed = int(seed)
+    seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)

-    os.makedirs(outpath, exist_ok=True)
+    os.makedirs(p.outpath, exist_ok=True)

-    sample_path = os.path.join(outpath, "samples")
+    sample_path = os.path.join(p.outpath, "samples")
     os.makedirs(sample_path, exist_ok=True)
     base_count = len(os.listdir(sample_path))
-    grid_count = len(os.listdir(outpath)) - 1
+    grid_count = len(os.listdir(p.outpath)) - 1

     comments = []

     prompt_matrix_parts = []
-    if prompt_matrix:
+    if p.prompt_matrix:
         all_prompts = []
         prompt_matrix_parts = prompt.split("|")
         combination_count = 2 ** (len(prompt_matrix_parts) - 1)
         for combination_num in range(combination_count):
-            selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]

-    if len(model_hijack.used_custom_terms) > 0:
-        comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
+    if len(model_hijack.comments) > 0:
+        comments += model_hijack.comments

     # we manually generate all input noises because each one should have a specific seed
-    x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
+    x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)

-    samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc)
+    samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)

     x_samples_ddim = model.decode_first_stage(samples_ddim)
     x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

-    if prompt_matrix or opts.samples_save or opts.grid_save:
+    if p.prompt_matrix or opts.samples_save or opts.grid_save:
         for i, x_sample in enumerate(x_samples_ddim):
-            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+            x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
             x_sample = x_sample.astype(np.uint8)

-            if use_GFPGAN and GFPGAN is not None:
+            if p.use_GFPGAN and GFPGAN is not None:
                 torch_gc()
                 cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
                 x_sample = restored_img
@@ -791,44 +805,44 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
             output_images.append(image)
             base_count += 1

-    if (prompt_matrix or opts.grid_save) and not do_not_save_grid:
-        if prompt_matrix:
-            grid = image_grid(output_images, batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
+    if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid:
+        if p.prompt_matrix:
+            grid = image_grid(output_images, p.batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))

             try:
-                grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
-            except:
+                grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
+            except Exception:
                 import traceback
                 print("Error creating prompt_matrix text:", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)

             output_images.insert(0, grid)
         else:
-            grid = image_grid(output_images, batch_size)
+            grid = image_grid(output_images, p.batch_size)

-        save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
+        save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
         grid_count += 1

     torch_gc()
     return output_images, seed, infotext()


-def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
-    outpath = opts.outdir or "outputs/txt2img-samples"
+class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
+    sampler = None

-    sampler = samplers[sampler_index].constructor(model)
+    def init(self):
+        self.sampler = samplers[self.sampler_index].constructor()

-    def init():
-        pass
-
-    def sample(init_data, x, conditioning, unconditional_conditioning):
-        samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
+    def sample(self, x, conditioning, unconditional_conditioning):
+        samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
         return samples_ddim

-    output_images, seed, info = process_images(
+
+def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
+    outpath = opts.outdir or "outputs/txt2img-samples"
+
+    p = StableDiffusionProcessingTxt2Img(
         outpath=outpath,
-        func_init=init,
-        func_sample=sample,
         prompt=prompt,
         seed=seed,
         sampler_index=sampler_index,
@@ -842,7 +856,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
         use_GFPGAN=use_GFPGAN
     )

-    del sampler
+    output_images, seed, info = process_images(p)

     return output_images, seed, plaintext_to_html(info)

@@ -858,7 +872,7 @@ class Flagging(gr.FlaggingCallback):
         os.makedirs("log/images", exist_ok=True)

         # those must match the "txt2img" function
-        prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
+        prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data

         filenames = []

@@ -896,7 +910,6 @@ txt2img_interface = gr.Interface(
         gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
         gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
         gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
-        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
         gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
         gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
         gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
@@ -914,40 +927,77 @@ txt2img_interface = gr.Interface(
 )


+class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
+    sampler = None
+
+    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs):
+        super().__init__(**kwargs)
+
+        self.init_images = init_images
+        self.resize_mode: int = resize_mode
+        self.denoising_strength: float = denoising_strength
+        self.init_latent = None
+
+    def init(self):
+        self.sampler = samplers_for_img2img[self.sampler_index].constructor()
+
+        imgs = []
+        for img in self.init_images:
+            image = img.convert("RGB")
+            image = resize_image(self.resize_mode, image, self.width, self.height)
+            image = np.array(image).astype(np.float32) / 255.0
+            image = np.moveaxis(image, 2, 0)
+            imgs.append(image)
+
+        if len(imgs) == 1:
+            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
+        elif len(imgs) <= self.batch_size:
+            self.batch_size = len(imgs)
+            batch_images = np.array(imgs)
+        else:
+            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
+
+        image = torch.from_numpy(batch_images)
+        image = 2. * image - 1.
+        image = image.to(device)
+
+        self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))
+
+    def sample(self, x, conditioning, unconditional_conditioning):
+        t_enc = int(self.denoising_strength * self.steps)
+
+        sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
+        noise = x * sigmas[self.steps - t_enc - 1]
+
+        xi = self.init_latent + noise
+        sigma_sched = sigmas[self.steps - t_enc - 1:]
+        samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
+        return samples_ddim
+
+
 def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
     outpath = opts.outdir or "outputs/img2img-samples"

-    sampler = samplers_for_img2img[sampler_index].constructor(model)
-
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

-    def init():
-        image = init_img.convert("RGB")
-        image = resize_image(resize_mode, image, width, height)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-
-        init_image = 2. * image - 1.
-        init_image = init_image.to(device)
-        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
-        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
-
-        return init_latent,
-
-    def sample(init_data, x, conditioning, unconditional_conditioning):
-        t_enc = int(denoising_strength * ddim_steps)
-
-        x0, = init_data
-
-        sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
-        noise = x * sigmas[ddim_steps - t_enc - 1]
-
-        xi = x0 + noise
-        sigma_sched = sigmas[ddim_steps - t_enc - 1:]
-        model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
-        samples_ddim = sampler.func(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
-        return samples_ddim
+    p = StableDiffusionProcessingImg2Img(
+        outpath=outpath,
+        prompt=prompt,
+        seed=seed,
+        sampler_index=sampler_index,
+        batch_size=batch_size,
+        n_iter=n_iter,
+        steps=ddim_steps,
+        cfg_scale=cfg_scale,
+        width=width,
+        height=height,
+        prompt_matrix=prompt_matrix,
+        use_GFPGAN=use_GFPGAN,
+        init_images=[init_img],
+        resize_mode=resize_mode,
+        denoising_strength=denoising_strength,
+        extra_generation_params={"Denoising Strength": denoising_strength}
+    )

     if loopback:
         output_images, info = None, None
@@ -955,32 +1005,19 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
         initial_seed = None

         for i in range(n_iter):
-            output_images, seed, info = process_images(
-                outpath=outpath,
-                func_init=init,
-                func_sample=sample,
-                prompt=prompt,
-                seed=seed,
-                sampler_index=sampler_index,
-                batch_size=1,
-                n_iter=1,
-                steps=ddim_steps,
-                cfg_scale=cfg_scale,
-                width=width,
-                height=height,
-                prompt_matrix=prompt_matrix,
-                use_GFPGAN=use_GFPGAN,
-                do_not_save_grid=True,
-                extra_generation_params={"Denoising Strength": denoising_strength},
-            )
+            p.n_iter = 1
+            p.batch_size = 1
+            p.do_not_save_grid = True
+
+            output_images, seed, info = process_images(p)

             if initial_seed is None:
                 initial_seed = seed

-            init_img = output_images[0]
-            seed = seed + 1
-            denoising_strength = max(denoising_strength * 0.95, 0.1)
-            history.append(init_img)
+            p.init_images = [output_images[0]]
+            p.seed = seed + 1
+            p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
+            history.append(output_images[0])

         grid_count = len(os.listdir(outpath)) - 1
         grid = image_grid(history, batch_size, force_n_rows=1)
@@ -1000,39 +1037,36 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG

         grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)

+        p.n_iter = 1
+        p.do_not_save_grid = True

-        print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.")
+        work = []
+        work_results = []

         for y, h, row in grid.tiles:
             for tiledata in row:
-                init_img = tiledata[2]
+                work.append(tiledata[2])

-                output_images, seed, info = process_images(
-                    outpath=outpath,
-                    func_init=init,
-                    func_sample=sample,
-                    prompt=prompt,
-                    seed=seed,
-                    sampler_index=sampler_index,
-                    batch_size=1,  # since process_images can't work with multiple different images we have to do this for now
-                    n_iter=1,
-                    steps=ddim_steps,
-                    cfg_scale=cfg_scale,
-                    width=width,
-                    height=height,
-                    prompt_matrix=prompt_matrix,
-                    use_GFPGAN=use_GFPGAN,
-                    do_not_save_grid=True,
-                    extra_generation_params={"Denoising Strength": denoising_strength},
-                )
+        batch_count = math.ceil(len(work) / p.batch_size)
+        print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")

-                if initial_seed is None:
-                    initial_seed = seed
-                    initial_info = info
+        for i in range(batch_count):
+            p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]

-                seed += 1
+            output_images, seed, info = process_images(p)

-                tiledata[2] = output_images[0]
+            if initial_seed is None:
+                initial_seed = seed
+                initial_info = info
+
+            p.seed = seed + 1
+            work_results += output_images
+
+        image_index = 0
+        for y, h, row in grid.tiles:
+            for tiledata in row:
+                tiledata[2] = work_results[image_index]
+                image_index += 1

         combined_image = combine_grid(grid)

@@ -1044,25 +1078,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
         info = initial_info

     else:
-        output_images, seed, info = process_images(
-            outpath=outpath,
-            func_init=init,
-            func_sample=sample,
-            prompt=prompt,
-            seed=seed,
-            sampler_index=sampler_index,
-            batch_size=batch_size,
-            n_iter=n_iter,
-            steps=ddim_steps,
-            cfg_scale=cfg_scale,
-            width=width,
-            height=height,
-            prompt_matrix=prompt_matrix,
-            use_GFPGAN=use_GFPGAN,
-            extra_generation_params={"Denoising Strength": denoising_strength},
-        )
-
-    del sampler
+        output_images, seed, info = process_images(p)

     return output_images, seed, plaintext_to_html(info)

@@ -1178,22 +1194,19 @@ def run_settings(*args):

 def create_setting_component(key):
     def fun():
-        return opts.data[key] if key in opts.data else opts.data_labels[key][0]
+        return opts.data[key] if key in opts.data else opts.data_labels[key].default

-    labelinfo = opts.data_labels[key]
-    t = type(labelinfo[0])
-    label = labelinfo[1]
-    if t == str:
-        item = gr.Textbox(label=label, value=fun, lines=1)
+    info = opts.data_labels[key]
+    t = type(info.default)
+
+    if info.component is not None:
+        item = info.component(label=info.label, value=fun, **(info.component_args or {}))
+    elif t == str:
+        item = gr.Textbox(label=info.label, value=fun, lines=1)
     elif t == int:
-        if len(labelinfo) == 5:
-            item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
-        elif len(labelinfo) == 4:
-            item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
-        else:
-            item = gr.Number(label=label, value=fun)
+        item = gr.Number(label=info.label, value=fun)
     elif t == bool:
-        item = gr.Checkbox(label=label, value=fun)
+        item = gr.Checkbox(label=info.label, value=fun)
     else:
         raise Exception(f'bad options item type: {str(t)} for key {key}')
@@ -1219,14 +1232,14 @@ interfaces = [
     (settings_interface, "Settings"),
 ]

-config = OmegaConf.load(cmd_opts.config)
-model = load_model_from_config(config, cmd_opts.ckpt)
+sd_config = OmegaConf.load(cmd_opts.config)
+sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)

 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-model = (model if cmd_opts.no_half else model.half()).to(device)
+sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)

 model_hijack = StableDiffuionModelHijack()
-model_hijack.hijack(model)
+model_hijack.hijack(sd_model)

 demo = gr.TabbedInterface(
     interface_list=[x[0] for x in interfaces],
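
Note for reviewers: the heart of this diff is the switch from the func_init/func_sample callback pair to a template-method hierarchy. process_images now receives a single StableDiffusionProcessing object, and txt2img/img2img supply subclasses that override init() and sample(). The following is a minimal, self-contained sketch of that pattern only; the class and function names mirror the diff, but MockTxt2Img and every body are simplified stand-ins, not the real implementation.

    # Standalone sketch of the template-method pattern this diff introduces.
    # Mock bodies only; nothing here imports webui.py, torch, or a real model.

    class StableDiffusionProcessing:
        def __init__(self, prompt="", steps=50, batch_size=1):
            self.prompt = prompt
            self.steps = steps
            self.batch_size = batch_size

        def init(self):
            # per-mode setup; in the diff, img2img overrides this to encode its init images
            pass

        def sample(self, x, conditioning, unconditional_conditioning):
            raise NotImplementedError()


    class MockTxt2Img(StableDiffusionProcessing):
        # hypothetical stand-in for StableDiffusionProcessingTxt2Img
        def sample(self, x, conditioning, unconditional_conditioning):
            return [f"denoised({v} | {conditioning})" for v in x]


    def process_images(p: StableDiffusionProcessing):
        # shared loop: set up once, then delegate the mode-specific sampling step
        p.init()
        noise = list(range(p.batch_size))  # stand-in for create_random_tensors
        return p.sample(noise, conditioning=p.prompt, unconditional_conditioning="")


    if __name__ == "__main__":
        p = MockTxt2Img(prompt="a photo of a cat", steps=20, batch_size=2)
        print(process_images(p))  # ['denoised(0 | a photo of a cat)', 'denoised(1 | a photo of a cat)']

Keeping the parameters on one mutable object is also what lets the loopback and SD-upscale paths above reuse a single p across iterations, adjusting p.seed, p.init_images, and p.denoising_strength between process_images calls instead of rebuilding a long argument list each time.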