From c9579b51a610f34a57c296ec5cd8796db8d7690b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 25 Aug 2022 23:31:44 +0300 Subject: [PATCH] extra samplers from K-diffusion --- webui.py | 55 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/webui.py b/webui.py index 4e586c02..40764178 100644 --- a/webui.py +++ b/webui.py @@ -1,4 +1,6 @@ import argparse, os, sys, glob +from collections import namedtuple + import torch import torch.nn as nn import numpy as np @@ -16,7 +18,7 @@ import time import json import traceback -import k_diffusion as K +import k_diffusion.sampling from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler @@ -60,6 +62,19 @@ css_hide_progressbar = """ .meta-text { display:none!important; } """ +SamplerData = namedtuple('SamplerData', ['name', 'constructor']) +samplers = [ + *[SamplerData(x[0], lambda model: KDiffusionSampler(model, x[1])) for x in [ + ('LMS', 'sample_lms'), + ('Heun', 'sample_heun'), + ('Euler', 'sample_euler'), + ('Euler ancestral', 'sample_euler_ancestral'), + ('DPM 2', 'sample_dpm_2'), + ('DPM 2 Ancestral', 'sample_dpm_2_ancestral'), + ] if hasattr(k_diffusion.sampling, x[1])], + SamplerData('DDIM', lambda model: DDIMSampler(model)), + SamplerData('PLMS', lambda model: PLMSSampler(model)), +] class Options: @@ -142,16 +157,18 @@ class CFGDenoiser(nn.Module): class KDiffusionSampler: - def __init__(self, m): + def __init__(self, m, funcname): self.model = m - self.model_wrap = K.external.CompVisDenoiser(m) + self.model_wrap = k_diffusion.external.CompVisDenoiser(m) + self.funcname = funcname def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T): sigmas = self.model_wrap.get_sigmas(S) x = x_T * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model_wrap) - samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False) + fun = getattr(k_diffusion.sampling, self.funcname) + samples_ddim = fun(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False) return samples_ddim, None @@ -526,7 +543,7 @@ def get_learned_conditioning_with_embeddings(model, prompts): return model.get_learned_conditioning(prompts) -def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False): +def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False): """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" assert prompt is not None @@ -579,7 +596,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, def infotext(): return f""" {prompt} -Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} +Steps: {steps}, Sampler: {samplers[sampler_index].name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} """.strip() + "".join(["\n\n" + x for x in comments]) if os.path.exists(cmd_opts.embeddings_dir): @@ -645,17 +662,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', return output_images, seed, infotext() -def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): +def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): outpath = opts.outdir or "outputs/txt2img-samples" - if sampler_name == 'PLMS': - sampler = PLMSSampler(model) - elif sampler_name == 'DDIM': - sampler = DDIMSampler(model) - elif sampler_name == 'k-diffusion': - sampler = KDiffusionSampler(model) - else: - raise Exception("Unknown sampler: " + sampler_name) + sampler = samplers[sampler_index].constructor(model) def init(): pass @@ -670,7 +680,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, p func_sample=sample, prompt=prompt, seed=seed, - sampler_name=sampler_name, + sampler_index=sampler_index, batch_size=batch_size, n_iter=n_iter, steps=ddim_steps, @@ -732,7 +742,7 @@ txt2img_interface = gr.Interface( inputs=[ gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1), gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50), - gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"), + gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), @@ -756,7 +766,7 @@ txt2img_interface = gr.Interface( def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): outpath = opts.outdir or "outputs/img2img-samples" - sampler = KDiffusionSampler(model) + sampler = KDiffusionSampler(model, 'sample_lms') assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' @@ -785,7 +795,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat xi = x0 + noise sigma_sched = sigmas[ddim_steps - t_enc - 1:] model_wrap_cfg = CFGDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) + samples_ddim = k_diffusion.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) return samples_ddim if loopback: @@ -800,7 +810,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat func_sample=sample, prompt=prompt, seed=seed, - sampler_name='k-diffusion', + sampler_index=0, batch_size=1, n_iter=1, steps=ddim_steps, @@ -835,7 +845,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat func_sample=sample, prompt=prompt, seed=seed, - sampler_name='k-diffusion', + sampler_index=0, batch_size=batch_size, n_iter=n_iter, steps=ddim_steps, @@ -877,10 +887,10 @@ img2img_interface = gr.Interface( gr.Number(label='Seed'), gr.HTML(), ], - title="Stable Diffusion Image-to-Image", allow_flagging="never", ) + def run_GFPGAN(image, strength): image = image.convert("RGB") @@ -904,7 +914,6 @@ gfpgan_interface = gr.Interface( gr.Number(label='Seed', visible=False), gr.HTML(), ], - title="GFPGAN", description="Fix faces on images", allow_flagging="never", )