Merge pull request #4233 from thesved/patch-1
Make DDIM and PLMS work on Mac OS
This commit is contained in:
commit
bb2e2c82ce
1 changed files with 18 additions and 1 deletions
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import modules.devices as devices
|
||||
|
||||
from einops import repeat
|
||||
from omegaconf import ListConfig
|
||||
|
@ -316,6 +317,20 @@ class LatentInpaintDiffusion(LatentDiffusion):
|
|||
self.concat_keys = concat_keys
|
||||
|
||||
|
||||
# =================================================================================================
|
||||
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
|
||||
# =================================================================================================
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
optimal_type = devices.get_optimal_device()
|
||||
if attr.device != optimal_type:
|
||||
if getattr(torch, 'has_mps', False):
|
||||
attr = attr.to(device="mps", dtype=torch.float32)
|
||||
else:
|
||||
attr = attr.to(optimal_type)
|
||||
setattr(self, name, attr)
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
|
||||
|
||||
|
@ -326,6 +341,8 @@ def do_inpainting_hijack():
|
|||
|
||||
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
|
||||
|
|
Loading…
Reference in a new issue