Merge pull request #8958 from MrCheeze/variations-model

Add support for the unclip (Variations) models, unclip-h and unclip-l

commit f1db987e6a
8 changed files with 88 additions and 30 deletions
@@ -235,7 +235,7 @@ def prepare_environment():
     codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
 
-    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
+    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
BIN  models/karlo/ViT-L-14_stats.th (new file)
Binary file not shown.
@@ -55,12 +55,12 @@ def setup_for_low_vram(sd_model, use_medvram):
     if hasattr(sd_model.cond_stage_model, 'model'):
         sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
 
-    # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
+    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
     # send the model to GPU. Then put modules back. the modules will be in CPU.
-    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
+    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
+    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
     sd_model.to(devices.device)
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
+    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
 
     # register hooks for those the first three models
     sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
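Note on the hunk above: the lowvram trick works by temporarily detaching the big submodules, moving the (now small) wrapper to the GPU, and reattaching the detached modules so they stay on the CPU. A minimal standalone sketch of the same pattern (module names here are illustrative, not the webui API):

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.small = nn.Linear(4, 4)      # meant to live on the GPU
        self.big = nn.Linear(4096, 4096)  # meant to stay on the CPU

model = Wrapper()
device = "cuda" if torch.cuda.is_available() else "cpu"

# same trick as setup_for_low_vram: unhook the big module, move the rest,
# then hook it back -- nn.Module.to() only touches registered submodules
stored = model.big
model.big = None
model.to(device)
model.big = stored  # still on the CPU

print(model.small.weight.device, model.big.weight.device)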
@@ -69,6 +69,8 @@ def setup_for_low_vram(sd_model, use_medvram):
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
     if sd_model.depth_model:
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
+    if sd_model.embedder:
+        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
     parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
 
     if hasattr(sd_model.cond_stage_model, 'model'):
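The new embedder branch reuses the same send_me_to_gpu mechanism as the other modules: a forward pre-hook that pulls a module onto the GPU right before its forward runs. A minimal sketch of that mechanism, simplified to a single module (not the exact webui implementation, which also evicts the previously active module):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

def send_me_to_gpu(module, _inputs):
    # fires right before module.forward(), so weights arrive just in time
    module.to(device)

embedder = nn.Linear(768, 768)  # stand-in for sd_model.embedder
embedder.register_forward_pre_hook(send_me_to_gpu)

out = embedder(torch.randn(1, 768, device=device))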
@@ -78,11 +78,7 @@ def apply_overlay(image, paste_loc, index, overlays):
 
 
 def txt2img_image_conditioning(sd_model, x, width, height):
-    if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
-        # Dummy zero conditioning if we're not using inpainting model.
-        # Still takes up a bit of memory, but no encoder call.
-        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
-        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
 
-    # The "masked-image" in this case will just be all zeros since the entire image is masked.
-    image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+        # The "masked-image" in this case will just be all zeros since the entire image is masked.
+        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
@@ -94,6 +90,16 @@ def txt2img_image_conditioning(sd_model, x, width, height):
 
-    return image_conditioning
+        return image_conditioning
+
+    elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
+
+        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
+
+    else:
+        # Dummy zero conditioning if we're not using inpainting or unclip models.
+        # Still takes up a bit of memory, but no encoder call.
+        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
 
 
 class StableDiffusionProcessing:
     """
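Why the dummy is 2*sd_model.noise_augmentor.time_embed.dim wide: for unCLIP models the adm conditioning is the CLIP image embedding concatenated with a noise-level embedding of the same width, so plain txt2img needs a zero tensor twice the embedding width. A toy shape check, with 1024 standing in for the embedding dim (the real value comes from the loaded config):

import torch

x = torch.randn(4, 4, 64, 64)  # latent batch; shapes illustrative
embed_dim = 1024               # stands in for sd_model.noise_augmentor.time_embed.dim

# unCLIP (crossattn-adm): dummy c_adm = [image embedding | noise-level embedding]
c_adm = x.new_zeros(x.shape[0], 2 * embed_dim, dtype=x.dtype, device=x.device)
assert c_adm.shape == (4, 2048)

# inpainting (hybrid/concat): dummy 1x1, 5-channel image conditioning
c_concat = x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
assert c_concat.shape == (4, 5, 1, 1)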
@@ -190,6 +196,14 @@ class StableDiffusionProcessing:
 
         return conditioning_image
 
+    def unclip_image_conditioning(self, source_image):
+        c_adm = self.sd_model.embedder(source_image)
+        if self.sd_model.noise_augmentor is not None:
+            noise_level = 0 # TODO: Allow other noise levels?
+            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
+            c_adm = torch.cat((c_adm, noise_level_emb), 1)
+        return c_adm
+
     def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
         self.is_using_inpainting_conditioning = True
 
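For context on the new method: the einops repeat call just broadcasts the scalar noise level across the batch before the noise augmentor embeds it, and the augmented embedding is then concatenated with the noise-level embedding along dim 1. A hedged sketch of that tensor plumbing with a stubbed augmentor (shapes illustrative; the real module comes from the Stability repo's unclip config):

import torch
from einops import repeat

def stub_noise_augmentor(c_adm, noise_level):
    # stand-in: returns the (noised) embedding plus an embedding of the
    # noise level, with matching batch size, like the real augmentor
    return c_adm, torch.zeros_like(c_adm)

c_adm = torch.randn(2, 1024)  # what self.sd_model.embedder(source_image) yields
noise_level = 0               # the PR hardcodes 0 (TODO: allow other levels)
levels = repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0])
c_adm, noise_level_emb = stub_noise_augmentor(c_adm, noise_level=levels)
c_adm = torch.cat((c_adm, noise_level_emb), 1)
assert c_adm.shape == (2, 2048)  # matches the 2*embed_dim dummy above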
@@ -241,6 +255,9 @@ class StableDiffusionProcessing:
         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
             return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
 
+        if self.sampler.conditioning_key == "crossattn-adm":
+            return self.unclip_image_conditioning(source_image)
+
         # Dummy zero conditioning if we're not using inpainting or depth model.
         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
 
@@ -383,6 +383,14 @@ def repair_config(sd_config):
     elif shared.cmd_opts.upcast_sampling:
         sd_config.model.params.unet_config.params.use_fp16 = True
 
+    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+
+    # For UnCLIP-L, override the hardcoded karlo directory
+    if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
+        karlo_path = os.path.join(paths.models_path, 'karlo')
+        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
+
 
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
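The clip_stats_path override is a plain string replace on the path baked into the unclip yaml, redirecting it at the webui models directory (where models/karlo/ViT-L-14_stats.th is now shipped, per the binary file added above). A minimal sketch with hypothetical values; the real ones come from the loaded yaml and paths.models_path:

import os

models_path = "/webui/models"  # stand-in for paths.models_path
clip_stats_path = "checkpoints/karlo_models/ViT-L-14_stats.th"  # from the yaml

karlo_path = os.path.join(models_path, 'karlo')
clip_stats_path = clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
print(clip_stats_path)  # /webui/models/karlo/ViT-L-14_stats.th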
@@ -14,6 +14,8 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
+config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
+config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
 config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
@@ -65,9 +67,14 @@ def is_using_v_parameterization_for_sd2(state_dict):
 def guess_model_config_from_state_dict(sd, filename):
     sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
+    sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
 
     if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
         return config_depth_model
+    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
+        return config_unclip
+    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
+        return config_unopenclip
 
     if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
         if diffusion_model_input.shape[1] == 9:
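The detection keys off the width of the image embedder's final LayerNorm: unclip-l embeds images with CLIP ViT-L (width 768), while unclip-h uses OpenCLIP ViT-H (width 1024). A small self-contained sketch of the same check against a bare state dict (string config names stand in for the paths defined above):

import torch

def guess_unclip_config(sd):
    # ln_final width distinguishes the two unCLIP embedders
    w = sd.get('embedder.model.ln_final.weight', None)
    if w is not None and w.shape[0] == 768:
        return "config_unclip"       # CLIP ViT-L embedder -> unclip-l
    if w is not None and w.shape[0] == 1024:
        return "config_unopenclip"   # OpenCLIP ViT-H embedder -> unclip-h
    return None

assert guess_unclip_config({'embedder.model.ln_final.weight': torch.zeros(768)}) == "config_unclip"
assert guess_unclip_config({}) is None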
@@ -70,7 +70,12 @@ class VanillaStableDiffusionSampler:
 
         # Have to unwrap the inpainting conditioning here to perform pre-processing
         image_conditioning = None
+        uc_image_conditioning = None
         if isinstance(cond, dict):
-            image_conditioning = cond["c_concat"][0]
+            if self.conditioning_key == "crossattn-adm":
+                image_conditioning = cond["c_adm"]
+                uc_image_conditioning = unconditional_conditioning["c_adm"]
+            else:
+                image_conditioning = cond["c_concat"][0]
             cond = cond["c_crossattn"][0]
             unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
@@ -98,6 +103,10 @@ class VanillaStableDiffusionSampler:
         # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
         # Note that they need to be lists because it just concatenates them later.
         if image_conditioning is not None:
-            cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
-            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+            if self.conditioning_key == "crossattn-adm":
+                cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
+                unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
+            else:
+                cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+                unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
 
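The key difference from the inpainting path: "c_adm" takes a bare tensor while "c_concat" takes a single-element list, and the unconditional branch gets its own c_adm (unwrapped earlier into uc_image_conditioning) rather than reusing the conditional image conditioning. A sketch of the two resulting dict shapes (tensors illustrative):

import torch

cond = torch.randn(2, 77, 1024)            # text crossattn conditioning
image_conditioning = torch.randn(2, 2048)  # c_adm from the unclip embedder
uc_image_conditioning = torch.zeros(2, 2048)

# unCLIP: c_adm is a bare tensor, and uncond carries its own embedding
unclip_cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
unclip_uncond = {"c_adm": uc_image_conditioning, "c_crossattn": [cond]}

# inpainting: image conditioning is a list under c_concat, shared by both
inpaint_cond = {"c_concat": [torch.randn(2, 5, 64, 64)], "c_crossattn": [cond]}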
@@ -176,6 +185,10 @@ class VanillaStableDiffusionSampler:
 
         # Wrap the conditioning models with additional image conditioning for inpainting model
         if image_conditioning is not None:
-            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
-            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+            if self.conditioning_key == "crossattn-adm":
+                conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
+                unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
+            else:
+                conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+                unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
 
@@ -195,6 +208,10 @@ class VanillaStableDiffusionSampler:
         # Wrap the conditioning models with additional image conditioning for inpainting model
         # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
         if image_conditioning is not None:
-            conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
-            unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
+            if self.conditioning_key == "crossattn-adm":
+                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
+                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
+            else:
+                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
 
@@ -92,14 +92,21 @@ class CFGDenoiser(torch.nn.Module):
         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
 
+        if shared.sd_model.model.conditioning_key == "crossattn-adm":
+            image_uncond = torch.zeros_like(image_cond)
+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+        else:
+            image_uncond = image_cond
+            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+
         if not is_edit_model:
             x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
             sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
-            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
         else:
             x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
             sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
-            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
+            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
 
         denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
         cfg_denoiser_callback(denoiser_params)
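make_condition_dict centralizes the cond-dict layout so the inner_model call sites below need no branches of their own, and image_uncond makes the unconditional half of the batch correct for unCLIP (zeros, rather than the conditional image embedding that the concat path reuses). A sketch of the dispatch, with a stand-in conditioning_key (the real one comes from shared.sd_model.model):

import torch

conditioning_key = "crossattn-adm"  # assumption for the sketch; or e.g. "concat"
image_cond = torch.randn(2, 2048)

if conditioning_key == "crossattn-adm":
    image_uncond = torch.zeros_like(image_cond)  # unCLIP uncond: zero embedding
    make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
    image_uncond = image_cond                    # concat models reuse the image cond
    make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}

d = make_condition_dict([torch.randn(2, 77, 1024)], image_uncond)
assert "c_adm" in d and "c_concat" not in d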
@@ -116,13 +123,13 @@ class CFGDenoiser(torch.nn.Module):
                 cond_in = torch.cat([tensor, uncond, uncond])
 
             if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
+                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
             else:
                 x_out = torch.zeros_like(x_in)
                 for batch_offset in range(0, x_out.shape[0], batch_size):
                     a = batch_offset
                     b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
         else:
             x_out = torch.zeros_like(x_in)
             batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@@ -135,9 +142,9 @@ class CFGDenoiser(torch.nn.Module):
                 else:
                     c_crossattn = torch.cat([tensor[a:b]], uncond)
 
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
 
-            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
 
         denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
         cfg_denoised_callback(denoised_params)