Fix error when batch count > 1
This commit is contained in:
parent
9e27af76d1
commit
f55a7e04d8
1 changed files with 5 additions and 4 deletions
|
@ -269,14 +269,15 @@ class KDiffusionSampler:
|
||||||
|
|
||||||
return sigmas
|
return sigmas
|
||||||
|
|
||||||
def create_noise_sampler(self, x, sigmas, seeds):
|
def create_noise_sampler(self, x, sigmas, p):
|
||||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seeds)
|
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||||
|
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
@ -302,7 +303,7 @@ class KDiffusionSampler:
|
||||||
extra_params_kwargs['sigmas'] = sigma_sched
|
extra_params_kwargs['sigmas'] = sigma_sched
|
||||||
|
|
||||||
if self.funcname == 'sample_dpmpp_sde':
|
if self.funcname == 'sample_dpmpp_sde':
|
||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p.all_seeds)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
|
@ -337,7 +338,7 @@ class KDiffusionSampler:
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
|
|
||||||
if self.funcname == 'sample_dpmpp_sde':
|
if self.funcname == 'sample_dpmpp_sde':
|
||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p.all_seeds)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
|
|
Loading…
Reference in a new issue