Merge pull request #3917 from MartinCairnsSQL/adjust-ddim-uniform-steps
Certain step counts for DDIM cause out of bounds error
This commit is contained in:
commit
40b3a7e8a5
1 changed files with 15 additions and 13 deletions
|
@ -1,5 +1,6 @@
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from math import floor
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -205,17 +206,22 @@ class VanillaStableDiffusionSampler:
|
||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_steps_if_invalid(self, p, num_steps):
|
||||||
|
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||||
|
valid_step = 999 / (1000 // num_steps)
|
||||||
|
if valid_step == floor(valid_step):
|
||||||
|
return int(valid_step) + 1
|
||||||
|
|
||||||
|
return num_steps
|
||||||
|
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
steps, t_enc = setup_img2img_steps(p, steps)
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
steps = self.adjust_steps_if_invalid(p, steps)
|
||||||
self.initialize(p)
|
self.initialize(p)
|
||||||
|
|
||||||
# existing code fails with certain step counts, like 9
|
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||||
try:
|
|
||||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
|
||||||
except Exception:
|
|
||||||
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
|
||||||
|
|
||||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||||
|
|
||||||
self.init_latent = x
|
self.init_latent = x
|
||||||
|
@ -239,18 +245,14 @@ class VanillaStableDiffusionSampler:
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
steps = steps or p.steps
|
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||||
|
|
||||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||||
if image_conditioning is not None:
|
if image_conditioning is not None:
|
||||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||||
|
|
||||||
# existing code fails with certain step counts, like 9
|
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||||
try:
|
|
||||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
|
||||||
except Exception:
|
|
||||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
|
||||||
|
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue