merge CFGDenoiserEdit and CFGDenoiser into single object
This commit is contained in:
parent
127bfb6c41
commit
72dd5785d9
1 changed files with 37 additions and 96 deletions
|
@ -41,90 +41,6 @@ sampler_extra_params = {
|
|||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
}
|
||||
|
||||
class CFGDenoiserEdit(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised[i] = out_uncond[cond_index] + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index]) + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index])
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": torch.cat([tensor[a:b]], uncond) , "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
|
||||
return denoised
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
|
@ -141,6 +57,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
|
@ -152,19 +69,36 @@ class CFGDenoiser(torch.nn.Module):
|
|||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
|
@ -173,7 +107,10 @@ class CFGDenoiser(torch.nn.Module):
|
|||
sigma_in = denoiser_params.sigma
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
if not is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
else:
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||
|
@ -189,7 +126,13 @@ class CFGDenoiser(torch.nn.Module):
|
|||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = [tensor[a:b]]
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||
|
||||
|
@ -200,7 +143,10 @@ class CFGDenoiser(torch.nn.Module):
|
|||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
if not is_edit_model:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
else:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
@ -280,12 +226,10 @@ class KDiffusionSampler:
|
|||
return p.steps
|
||||
|
||||
def initialize(self, p):
|
||||
if shared.sd_model.cond_stage_key == "edit" and getattr(p, 'image_cfg_scale', None) != 1:
|
||||
self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap)
|
||||
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
@ -355,9 +299,6 @@ class KDiffusionSampler:
|
|||
'cond_scale': p.cfg_scale,
|
||||
}
|
||||
|
||||
if hasattr(p, 'image_cfg_scale') and p.image_cfg_scale != 1 and p.image_cfg_scale != None:
|
||||
extra_args['image_cfg_scale'] = p.image_cfg_scale
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
|
Loading…
Reference in a new issue