Fixed copying mistake
parent 8e7097d06a
commit 0719c10bf1
1 changed file with 25 additions and 54 deletions
@@ -19,63 +19,35 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like
 # https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
 # =================================================================================================
 @torch.no_grad()
-def sample(
-    self,
-    S,
-    batch_size,
-    shape,
-    conditioning=None,
-    callback=None,
-    normals_sequence=None,
-    img_callback=None,
-    quantize_x0=False,
-    eta=0.,
-    mask=None,
-    x0=None,
-    temperature=1.,
-    noise_dropout=0.,
-    score_corrector=None,
-    corrector_kwargs=None,
-    verbose=True,
-    x_T=None,
-    log_every_t=100,
-    unconditional_guidance_scale=1.,
-    unconditional_conditioning=None,
-    # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
-    **kwargs
-):
+def sample(self,
+           S,
+           batch_size,
+           shape,
+           conditioning=None,
+           callback=None,
+           normals_sequence=None,
+           img_callback=None,
+           quantize_x0=False,
+           eta=0.,
+           mask=None,
+           x0=None,
+           temperature=1.,
+           noise_dropout=0.,
+           score_corrector=None,
+           corrector_kwargs=None,
+           verbose=True,
+           x_T=None,
+           log_every_t=100,
+           unconditional_guidance_scale=1.,
+           unconditional_conditioning=None,
+           # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+           **kwargs
+           ):
     if conditioning is not None:
         if isinstance(conditioning, dict):
             ctmp = conditioning[list(conditioning.keys())[0]]
             while isinstance(ctmp, list):
-                ctmp = elf.inpainting_fill == 2:
-                    self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
-                elif self.inpainting_fill == 3:
-                    self.init_latent = self.init_latent * self.mask
-
-                if self.image_mask is not None:
-                    conditioning_mask = np.array(self.image_mask.convert("L"))
-                    conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
-                    conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
-
-                    # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
-                    conditioning_mask = torch.round(conditioning_mask)
-                else:
-                    conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
-                # Create another latent image, this time with a masked version of the original input.
-                conditioning_mask = conditioning_mask.to(image.device)
-                conditioning_image = image * (1.0 - conditioning_mask)
-                conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
-                # Create the concatenated conditioning tensor to be fed to `c_concat`
-                conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
-                conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
-                self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
-                self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
-
-            def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
-                x = create_random_tensors([opctmp[0]
+                ctmp = ctmp[0]
             cbs = ctmp.shape[0]
             if cbs != batch_size:
                 print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
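The restored line matters because the pasted block had replaced the loop body with "ctmp = elf.inpainting_fill == 2:", a syntax error. All the loop is meant to do is peel list nesting off the conditioning until a tensor remains, so its first dimension can be checked against the batch size. A minimal sketch of the fixed logic in isolation (the c_crossattn key and the tensor shape are assumptions for illustration, not taken from this diff):

import torch

# Hypothetical conditioning dict, shaped like what the sampler receives:
# values may be tensors or (nested) lists of tensors (assumed layout).
conditioning = {"c_crossattn": [torch.randn(2, 77, 768)]}
batch_size = 2

# The fixed loop: unwrap list nesting until a tensor remains.
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
    ctmp = ctmp[0]

cbs = ctmp.shape[0]  # batch size implied by the conditioning
if cbs != batch_size:
    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")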
@@ -106,7 +78,6 @@ def sample(
     )
     return samples, intermediates

-
 @torch.no_grad()
 def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                   temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
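For context, the stray block deleted in the first hunk appears to be inpainting-setup code from elsewhere in the webui: it rounds the user's mask to hard 0/1 values, encodes a masked copy of the source image through the first-stage VAE, and concatenates mask and masked latent into the tensor fed to the model as c_concat. A self-contained sketch of that technique, assuming a loaded sd_model exposing encode_first_stage/get_first_stage_encoding as in latent diffusion (the function name and its arguments here are hypothetical):

import numpy as np
import torch

def build_image_conditioning(sd_model, image, image_mask, latent_shape):
    # image: [B, 3, H, W] source tensor; image_mask: PIL "L" mask or None.
    if image_mask is not None:
        mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
        mask = torch.from_numpy(mask[None, None])
        # The inpainting model takes a discretized mask, so round to 0.0 or 1.0.
        mask = torch.round(mask)
    else:
        mask = torch.ones(1, 1, *image.shape[-2:])

    mask = mask.to(image.device)
    # Encode a masked copy of the original input into latent space.
    conditioning_image = sd_model.get_first_stage_encoding(
        sd_model.encode_first_stage(image * (1.0 - mask)))

    # Downsample the mask to latent resolution and build the c_concat tensor.
    mask = torch.nn.functional.interpolate(mask, size=latent_shape[-2:])
    mask = mask.expand(conditioning_image.shape[0], -1, -1, -1)
    return torch.cat([mask, conditioning_image], dim=1)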