added samples to img2img

fixed a bug with sampler selection (oops)
This commit is contained in:
AUTOMATIC 2022-08-26 14:10:40 +03:00
parent 155dd2fc0c
commit 21765c17e6

View file

@ -64,7 +64,7 @@ css_hide_progressbar = """
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
*[SamplerData(x[0], lambda model: KDiffusionSampler(model, x[1])) for x in [
*[SamplerData(x[0], lambda m, funcname=x[1]: KDiffusionSampler(m, funcname)) for x in [
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('Euler', 'sample_euler'),
@ -72,9 +72,10 @@ samplers = [
('DPM 2', 'sample_dpm_2'),
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])],
SamplerData('DDIM', lambda model: DDIMSampler(model)),
SamplerData('PLMS', lambda model: PLMSSampler(model)),
SamplerData('DDIM', lambda m: DDIMSampler(model)),
SamplerData('PLMS', lambda m: PLMSSampler(model)),
]
samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
@ -197,14 +198,14 @@ class KDiffusionSampler:
self.model = m
self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap)
fun = getattr(k_diffusion.sampling, self.funcname)
samples_ddim = fun(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
samples_ddim = self.func(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
return samples_ddim, None
@ -810,10 +811,10 @@ txt2img_interface = gr.Interface(
)
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples"
sampler = KDiffusionSampler(model, 'sample_lms')
sampler = samplers_for_img2img[sampler_index].constructor(model)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
@ -842,7 +843,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
xi = x0 + noise
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
samples_ddim = k_diffusion.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
samples_ddim = sampler.func(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim
if loopback:
@ -919,6 +920,7 @@ img2img_interface = gr.Interface(
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),