remove dependence on TQDM for sampler progress/interrupt functionality
This commit is contained in:
parent
ec1924ee57
commit
cbf15edbf9
2 changed files with 58 additions and 55 deletions
|
@ -402,12 +402,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||||
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
|
|
||||||
# if we are interrupted, sample returns just noise
|
|
||||||
# use the image collected previously in sampler loop
|
|
||||||
samples_ddim = shared.state.current_latent
|
|
||||||
|
|
||||||
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
||||||
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
|
@ -98,25 +98,8 @@ def store_latent(decoded):
|
||||||
shared.state.current_image = sample_to_image(decoded)
|
shared.state.current_image = sample_to_image(decoded)
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptedException(BaseException):
|
||||||
def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
pass
|
||||||
state.sampling_steps = len(sequence)
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
|
||||||
|
|
||||||
for x in seq:
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
break
|
|
||||||
|
|
||||||
yield x
|
|
||||||
|
|
||||||
state.sampling_step += 1
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
|
|
||||||
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
|
||||||
ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class VanillaStableDiffusionSampler:
|
class VanillaStableDiffusionSampler:
|
||||||
|
@ -128,14 +111,32 @@ class VanillaStableDiffusionSampler:
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
self.stop_at = None
|
||||||
self.eta = None
|
self.eta = None
|
||||||
self.default_eta = 0.0
|
self.default_eta = 0.0
|
||||||
self.config = None
|
self.config = None
|
||||||
|
self.last_latent = None
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
def number_of_needed_noises(self, p):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def launch_sampling(self, steps, func):
|
||||||
|
state.sampling_steps = steps
|
||||||
|
state.sampling_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except InterruptedException:
|
||||||
|
return self.last_latent
|
||||||
|
|
||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
|
if state.interrupted or state.skipped:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
if self.stop_at is not None and self.step > self.stop_at:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
|
|
||||||
|
@ -159,11 +160,16 @@ class VanillaStableDiffusionSampler:
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
store_latent(self.init_latent * self.mask + self.nmask * res[1])
|
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||||
else:
|
else:
|
||||||
store_latent(res[1])
|
self.last_latent = res[1]
|
||||||
|
|
||||||
|
store_latent(self.last_latent)
|
||||||
|
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
state.sampling_step = self.step
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def initialize(self, p):
|
def initialize(self, p):
|
||||||
|
@ -192,7 +198,7 @@ class VanillaStableDiffusionSampler:
|
||||||
self.init_latent = x
|
self.init_latent = x
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
|
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
@ -206,9 +212,9 @@ class VanillaStableDiffusionSampler:
|
||||||
|
|
||||||
# existing code fails with certain step counts, like 9
|
# existing code fails with certain step counts, like 9
|
||||||
try:
|
try:
|
||||||
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||||
except Exception:
|
except Exception:
|
||||||
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
|
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||||
|
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
|
@ -223,6 +229,9 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
|
if state.interrupted or state.skipped:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
|
@ -268,25 +277,6 @@ class CFGDenoiser(torch.nn.Module):
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
|
||||||
def extended_trange(sampler, count, *args, **kwargs):
|
|
||||||
state.sampling_steps = count
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
|
||||||
|
|
||||||
for x in seq:
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
break
|
|
||||||
|
|
||||||
if sampler.stop_at is not None and x > sampler.stop_at:
|
|
||||||
break
|
|
||||||
|
|
||||||
yield x
|
|
||||||
|
|
||||||
state.sampling_step += 1
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
|
|
||||||
class TorchHijack:
|
class TorchHijack:
|
||||||
def __init__(self, kdiff_sampler):
|
def __init__(self, kdiff_sampler):
|
||||||
self.kdiff_sampler = kdiff_sampler
|
self.kdiff_sampler = kdiff_sampler
|
||||||
|
@ -314,9 +304,28 @@ class KDiffusionSampler:
|
||||||
self.eta = None
|
self.eta = None
|
||||||
self.default_eta = 1.0
|
self.default_eta = 1.0
|
||||||
self.config = None
|
self.config = None
|
||||||
|
self.last_latent = None
|
||||||
|
|
||||||
def callback_state(self, d):
|
def callback_state(self, d):
|
||||||
store_latent(d["denoised"])
|
step = d['i']
|
||||||
|
latent = d["denoised"]
|
||||||
|
store_latent(latent)
|
||||||
|
self.last_latent = latent
|
||||||
|
|
||||||
|
if self.stop_at is not None and step > self.stop_at:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
state.sampling_step = step
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
def launch_sampling(self, steps, func):
|
||||||
|
state.sampling_steps = steps
|
||||||
|
state.sampling_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except InterruptedException:
|
||||||
|
return self.last_latent
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
def number_of_needed_noises(self, p):
|
||||||
return p.steps
|
return p.steps
|
||||||
|
@ -339,9 +348,6 @@ class KDiffusionSampler:
|
||||||
self.sampler_noise_index = 0
|
self.sampler_noise_index = 0
|
||||||
self.eta = p.eta or opts.eta_ancestral
|
self.eta = p.eta or opts.eta_ancestral
|
||||||
|
|
||||||
if hasattr(k_diffusion.sampling, 'trange'):
|
|
||||||
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
|
|
||||||
|
|
||||||
if self.sampler_noises is not None:
|
if self.sampler_noises is not None:
|
||||||
k_diffusion.sampling.torch = TorchHijack(self)
|
k_diffusion.sampling.torch = TorchHijack(self)
|
||||||
|
|
||||||
|
@ -383,8 +389,9 @@ class KDiffusionSampler:
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
|
|
||||||
return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
@ -406,6 +413,8 @@ class KDiffusionSampler:
|
||||||
extra_params_kwargs['n'] = steps
|
extra_params_kwargs['n'] = steps
|
||||||
else:
|
else:
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
|
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue