progress bar description for k-diffsuion for 88393097
This commit is contained in:
parent
49fcdbefa3
commit
2d5689a051
1 changed files with 9 additions and 1 deletions
10
webui.py
10
webui.py
|
@ -35,6 +35,7 @@ import traceback
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
import signal
|
import signal
|
||||||
|
import tqdm
|
||||||
|
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
@ -842,6 +843,7 @@ class StableDiffusionProcessing:
|
||||||
self.extra_generation_params: dict = extra_generation_params
|
self.extra_generation_params: dict = extra_generation_params
|
||||||
self.overlay_images = overlay_images
|
self.overlay_images = overlay_images
|
||||||
self.paste_to = None
|
self.paste_to = None
|
||||||
|
self.progress_info = ""
|
||||||
|
|
||||||
def init(self):
|
def init(self):
|
||||||
pass
|
pass
|
||||||
|
@ -917,7 +919,6 @@ class CFGDenoiser(nn.Module):
|
||||||
|
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler:
|
class KDiffusionSampler:
|
||||||
def __init__(self, funcname):
|
def __init__(self, funcname):
|
||||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
|
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
|
||||||
|
@ -938,12 +939,18 @@ class KDiffusionSampler:
|
||||||
self.model_wrap_cfg.nmask = p.nmask
|
self.model_wrap_cfg.nmask = p.nmask
|
||||||
self.model_wrap_cfg.init_latent = p.init_latent
|
self.model_wrap_cfg.init_latent = p.init_latent
|
||||||
|
|
||||||
|
if hasattr(k_diffusion.sampling, 'trange'):
|
||||||
|
k_diffusion.sampling.trange = lambda *args, **kwargs: tqdm.tqdm(range(*args), desc=p.progress_info, **kwargs)
|
||||||
|
|
||||||
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
||||||
|
|
||||||
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
|
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
|
||||||
sigmas = self.model_wrap.get_sigmas(p.steps)
|
sigmas = self.model_wrap.get_sigmas(p.steps)
|
||||||
x = x * sigmas[0]
|
x = x * sigmas[0]
|
||||||
|
|
||||||
|
if hasattr(k_diffusion.sampling, 'trange'):
|
||||||
|
k_diffusion.sampling.trange = lambda *args, **kwargs: tqdm.tqdm(range(*args), desc=p.progress_info, **kwargs)
|
||||||
|
|
||||||
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
|
@ -1030,6 +1037,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
# we manually generate all input noises because each one should have a specific seed
|
# we manually generate all input noises because each one should have a specific seed
|
||||||
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
|
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
|
||||||
|
|
||||||
|
p.progress_info = f"Batch {n+1} out of {p.n_iter}"
|
||||||
samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
|
samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
|
||||||
|
|
||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
|
|
Loading…
Reference in a new issue