Make DDIM and PLMS work on Mac OS
Fix register_buffer error on Mac OS
This commit is contained in:
parent
c2465f67db
commit
86b7fc6e5e
1 changed files with 18 additions and 1 deletions
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
import modules.devices as devices
|
||||||
|
|
||||||
from einops import repeat
|
from einops import repeat
|
||||||
from omegaconf import ListConfig
|
from omegaconf import ListConfig
|
||||||
|
@ -314,6 +315,20 @@ class LatentInpaintDiffusion(LatentDiffusion):
|
||||||
self.masked_image_key = masked_image_key
|
self.masked_image_key = masked_image_key
|
||||||
assert self.masked_image_key in concat_keys
|
assert self.masked_image_key in concat_keys
|
||||||
self.concat_keys = concat_keys
|
self.concat_keys = concat_keys
|
||||||
|
|
||||||
|
|
||||||
|
# =================================================================================================
|
||||||
|
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
|
||||||
|
# =================================================================================================
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
optimal_type = devices.get_optimal_device()
|
||||||
|
if attr.device != optimal_type:
|
||||||
|
if getattr(torch, 'has_mps', False):
|
||||||
|
attr = attr.to(device="mps", dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
attr = attr.to(optimal_type)
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
|
||||||
def should_hijack_inpainting(checkpoint_info):
|
def should_hijack_inpainting(checkpoint_info):
|
||||||
|
@ -326,6 +341,8 @@ def do_inpainting_hijack():
|
||||||
|
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||||
|
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||||
|
|
||||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||||
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||||
|
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||||
|
|
Loading…
Reference in a new issue