extra samplers from K-diffusion

This commit is contained in:
AUTOMATIC 2022-08-25 23:31:44 +03:00
parent 91dc8710ec
commit c9579b51a6

View file

@ -1,4 +1,6 @@
import argparse, os, sys, glob import argparse, os, sys, glob
from collections import namedtuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
@ -16,7 +18,7 @@ import time
import json import json
import traceback import traceback
import k_diffusion as K import k_diffusion.sampling
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
@ -60,6 +62,19 @@ css_hide_progressbar = """
.meta-text { display:none!important; } .meta-text { display:none!important; }
""" """
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
*[SamplerData(x[0], lambda model: KDiffusionSampler(model, x[1])) for x in [
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('Euler', 'sample_euler'),
('Euler ancestral', 'sample_euler_ancestral'),
('DPM 2', 'sample_dpm_2'),
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])],
SamplerData('DDIM', lambda model: DDIMSampler(model)),
SamplerData('PLMS', lambda model: PLMSSampler(model)),
]
class Options: class Options:
@ -142,16 +157,18 @@ class CFGDenoiser(nn.Module):
class KDiffusionSampler: class KDiffusionSampler:
def __init__(self, m): def __init__(self, m, funcname):
self.model = m self.model = m
self.model_wrap = K.external.CompVisDenoiser(m) self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
self.funcname = funcname
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T): def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
sigmas = self.model_wrap.get_sigmas(S) sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0] x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap) model_wrap_cfg = CFGDenoiser(self.model_wrap)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False) fun = getattr(k_diffusion.sampling, self.funcname)
samples_ddim = fun(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
return samples_ddim, None return samples_ddim, None
@ -526,7 +543,7 @@ def get_learned_conditioning_with_embeddings(model, prompts):
return model.get_learned_conditioning(prompts) return model.get_learned_conditioning(prompts)
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False): def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None assert prompt is not None
@ -579,7 +596,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
def infotext(): def infotext():
return f""" return f"""
{prompt} {prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} Steps: {steps}, Sampler: {samplers[sampler_index].name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip() + "".join(["\n\n" + x for x in comments]) """.strip() + "".join(["\n\n" + x for x in comments])
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir):
@ -645,17 +662,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
return output_images, seed, infotext() return output_images, seed, infotext()
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
outpath = opts.outdir or "outputs/txt2img-samples" outpath = opts.outdir or "outputs/txt2img-samples"
if sampler_name == 'PLMS': sampler = samplers[sampler_index].constructor(model)
sampler = PLMSSampler(model)
elif sampler_name == 'DDIM':
sampler = DDIMSampler(model)
elif sampler_name == 'k-diffusion':
sampler = KDiffusionSampler(model)
else:
raise Exception("Unknown sampler: " + sampler_name)
def init(): def init():
pass pass
@ -670,7 +680,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, p
func_sample=sample, func_sample=sample,
prompt=prompt, prompt=prompt,
seed=seed, seed=seed,
sampler_name=sampler_name, sampler_index=sampler_index,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=ddim_steps, steps=ddim_steps,
@ -732,7 +742,7 @@ txt2img_interface = gr.Interface(
inputs=[ inputs=[
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1), gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50), gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"), gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
@ -756,7 +766,7 @@ txt2img_interface = gr.Interface(
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples" outpath = opts.outdir or "outputs/img2img-samples"
sampler = KDiffusionSampler(model) sampler = KDiffusionSampler(model, 'sample_lms')
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
@ -785,7 +795,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
xi = x0 + noise xi = x0 + noise
sigma_sched = sigmas[ddim_steps - t_enc - 1:] sigma_sched = sigmas[ddim_steps - t_enc - 1:]
model_wrap_cfg = CFGDenoiser(sampler.model_wrap) model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) samples_ddim = k_diffusion.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim return samples_ddim
if loopback: if loopback:
@ -800,7 +810,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
func_sample=sample, func_sample=sample,
prompt=prompt, prompt=prompt,
seed=seed, seed=seed,
sampler_name='k-diffusion', sampler_index=0,
batch_size=1, batch_size=1,
n_iter=1, n_iter=1,
steps=ddim_steps, steps=ddim_steps,
@ -835,7 +845,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
func_sample=sample, func_sample=sample,
prompt=prompt, prompt=prompt,
seed=seed, seed=seed,
sampler_name='k-diffusion', sampler_index=0,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=ddim_steps, steps=ddim_steps,
@ -877,10 +887,10 @@ img2img_interface = gr.Interface(
gr.Number(label='Seed'), gr.Number(label='Seed'),
gr.HTML(), gr.HTML(),
], ],
title="Stable Diffusion Image-to-Image",
allow_flagging="never", allow_flagging="never",
) )
def run_GFPGAN(image, strength): def run_GFPGAN(image, strength):
image = image.convert("RGB") image = image.convert("RGB")
@ -904,7 +914,6 @@ gfpgan_interface = gr.Interface(
gr.Number(label='Seed', visible=False), gr.Number(label='Seed', visible=False),
gr.HTML(), gr.HTML(),
], ],
title="GFPGAN",
description="Fix faces on images", description="Fix faces on images",
allow_flagging="never", allow_flagging="never",
) )