emergency fix for #1199
This commit is contained in:
parent
15f333a266
commit
2ab64ec81a
1 changed files with 13 additions and 12 deletions
|
@ -3,6 +3,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import inspect
|
||||||
|
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
|
@ -38,11 +39,11 @@ samplers = [
|
||||||
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|
||||||
|
|
||||||
sampler_extra_params = {
|
sampler_extra_params = {
|
||||||
'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'],
|
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_euler_ancestral':['eta'],
|
'sample_euler_ancestral': ['eta'],
|
||||||
'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'],
|
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'],
|
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_dpm_2_ancestral':['eta'],
|
'sample_dpm_2_ancestral': ['eta'],
|
||||||
}
|
}
|
||||||
|
|
||||||
def setup_img2img_steps(p, steps=None):
|
def setup_img2img_steps(p, steps=None):
|
||||||
|
@ -231,7 +232,7 @@ class KDiffusionSampler:
|
||||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||||
self.funcname = funcname
|
self.funcname = funcname
|
||||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||||
self.extra_params = sampler_extra_params.get(funcname,[])
|
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.sampler_noise_index = 0
|
self.sampler_noise_index = 0
|
||||||
|
@ -278,9 +279,9 @@ class KDiffusionSampler:
|
||||||
k_diffusion.sampling.torch = TorchHijack(self)
|
k_diffusion.sampling.torch = TorchHijack(self)
|
||||||
|
|
||||||
extra_params_kwargs = {}
|
extra_params_kwargs = {}
|
||||||
for val in self.extra_params:
|
for param_name in self.extra_params:
|
||||||
if hasattr(p,val):
|
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||||
extra_params_kwargs[val] = getattr(p,val)
|
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||||
|
|
||||||
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||||
|
|
||||||
|
@ -300,9 +301,9 @@ class KDiffusionSampler:
|
||||||
k_diffusion.sampling.torch = TorchHijack(self)
|
k_diffusion.sampling.torch = TorchHijack(self)
|
||||||
|
|
||||||
extra_params_kwargs = {}
|
extra_params_kwargs = {}
|
||||||
for val in self.extra_params:
|
for param_name in self.extra_params:
|
||||||
if hasattr(p,val):
|
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||||
extra_params_kwargs[val] = getattr(p,val)
|
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||||
|
|
||||||
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue