move DDIM/PLMS fix for OSX out of the file with inpainting code.
This commit is contained in:
parent
bb2e2c82ce
commit
7ba3923d5b
2 changed files with 24 additions and 17 deletions
|
@ -14,6 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
|
import ldm.models.diffusion.ddim
|
||||||
|
import ldm.models.diffusion.plms
|
||||||
|
|
||||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
|
@ -406,3 +408,24 @@ def add_circular_option_to_conv_2d():
|
||||||
|
|
||||||
|
|
||||||
model_hijack = StableDiffusionModelHijack()
|
model_hijack = StableDiffusionModelHijack()
|
||||||
|
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
"""
|
||||||
|
Fix register buffer bug for Mac OS.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != devices.device:
|
||||||
|
|
||||||
|
# would this not break cuda when torch adds has_mps() to main version?
|
||||||
|
if getattr(torch, 'has_mps', False):
|
||||||
|
attr = attr.to(device="mps", dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
attr = attr.to(devices.device)
|
||||||
|
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
|
||||||
|
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||||
|
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import torch
|
import torch
|
||||||
import modules.devices as devices
|
|
||||||
|
|
||||||
from einops import repeat
|
from einops import repeat
|
||||||
from omegaconf import ListConfig
|
from omegaconf import ListConfig
|
||||||
|
@ -317,20 +316,6 @@ class LatentInpaintDiffusion(LatentDiffusion):
|
||||||
self.concat_keys = concat_keys
|
self.concat_keys = concat_keys
|
||||||
|
|
||||||
|
|
||||||
# =================================================================================================
|
|
||||||
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
|
|
||||||
# =================================================================================================
|
|
||||||
def register_buffer(self, name, attr):
|
|
||||||
if type(attr) == torch.Tensor:
|
|
||||||
optimal_type = devices.get_optimal_device()
|
|
||||||
if attr.device != optimal_type:
|
|
||||||
if getattr(torch, 'has_mps', False):
|
|
||||||
attr = attr.to(device="mps", dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
attr = attr.to(optimal_type)
|
|
||||||
setattr(self, name, attr)
|
|
||||||
|
|
||||||
|
|
||||||
def should_hijack_inpainting(checkpoint_info):
|
def should_hijack_inpainting(checkpoint_info):
|
||||||
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
|
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
|
||||||
|
|
||||||
|
@ -341,8 +326,7 @@ def do_inpainting_hijack():
|
||||||
|
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
|
||||||
|
|
||||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||||
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||||
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
|
||||||
|
|
Loading…
Reference in a new issue