Apply hijacks in ddpm_edit for upcast sampling

To avoid import errors, ddpm_edit hijacks are done after an instruct pix2pix model is loaded.
This commit is contained in:
brkirch 2023-02-07 00:05:54 -05:00
parent 4738486d8f
commit 2016733814
2 changed files with 14 additions and 0 deletions

View file

@ -104,6 +104,9 @@ class StableDiffusionModelHijack:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
self.optimization_method = apply_optimizations() self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model self.clip = m.cond_stage_model

View file

@ -44,6 +44,7 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
with devices.autocast(): with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
class GELUHijack(torch.nn.GELU, torch.nn.Module): class GELUHijack(torch.nn.GELU, torch.nn.Module):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
torch.nn.GELU.__init__(self, *args, **kwargs) torch.nn.GELU.__init__(self, *args, **kwargs)
@ -53,6 +54,16 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
else: else:
return torch.nn.GELU.forward(self, x) return torch.nn.GELU.forward(self, x)
# Holds the value returned by the apply_model CondFunc registration once
# hijack_ddpm_edit() has run; stays None until then.
ddpm_edit_hijack = None


def hijack_ddpm_edit():
    """Register the ddpm_edit hijacks used for upcast sampling, at most once.

    Registers CondFunc hooks for ``decode_first_stage``,
    ``encode_first_stage`` and ``apply_model`` on
    ``modules.models.diffusion.ddpm_edit.LatentDiffusion``.  The hooks are
    registered from this function rather than at import time so that the
    ddpm_edit module is only touched after an instruct-pix2pix model has been
    loaded, which avoids import errors.

    The module-level ``ddpm_edit_hijack`` doubles as the "already done"
    marker: while it holds a truthy value, calling this function again does
    nothing.
    """
    global ddpm_edit_hijack

    # Guard: a previous call already registered the hooks.
    if ddpm_edit_hijack:
        return

    # Both first-stage entry points share the same substitute/condition pair.
    first_stage_targets = (
        'modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage',
        'modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage',
    )
    for target in first_stage_targets:
        CondFunc(target, first_stage_sub, first_stage_cond)

    # Registered last; its return value is what marks the hijack as applied.
    ddpm_edit_hijack = CondFunc(
        'modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model',
        apply_model,
        unet_needs_upcast,
    )
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)