From 399b229783a7b5fddab0a258740b4d59d668ee12 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 24 Dec 2022 09:03:45 +0300 Subject: [PATCH] eliminate duplicated code add an option to samplers for skipping next to last sigma --- modules/sd_samplers.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 1a1b8919..d26e48dc 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -23,16 +23,16 @@ samplers_k_diffusion = [ ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), ('Heun', 'sample_heun', ['k_heun'], {}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}), ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), + ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), + ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}), ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}), ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}), @@ -444,9 +444,7 @@ class KDiffusionSampler: return extra_params_kwargs - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = setup_img2img_steps(p, steps) - + def get_sigmas(self, p, steps): if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': @@ -454,9 +452,16 @@ class KDiffusionSampler: else: sigmas = self.model_wrap.get_sigmas(steps) - if self.funcname in ['sample_dpm_2_ancestral', 'sample_dpm_2']: + if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False): sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + return sigmas + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = setup_img2img_steps(p, steps) + + sigmas = self.get_sigmas(p, steps) + sigma_sched = sigmas[steps - t_enc - 1:] xi = x + noise * sigma_sched[0] @@ -488,18 +493,10 @@ class KDiffusionSampler: def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): steps = steps or p.steps - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) + sigmas = self.get_sigmas(p, steps) x = x * sigmas[0] - if self.funcname in ['sample_dpm_2_ancestral', 'sample_dpm_2']: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - extra_params_kwargs = self.initialize(p) if 'sigma_min' in inspect.signature(self.func).parameters: extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()