use selected device instead of always cuda for UniPC sampler

This commit is contained in:
AUTOMATIC 2023-03-11 11:56:05 +03:00
parent a11ce2b96c
commit f261a4a53c

View file

@ -3,7 +3,8 @@
import torch import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
from modules import shared from modules import shared, devices
class UniPCSampler(object): class UniPCSampler(object):
def __init__(self, model, **kwargs): def __init__(self, model, **kwargs):
@ -16,8 +17,8 @@ class UniPCSampler(object):
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != devices.device:
attr = attr.to(torch.device("cuda")) attr = attr.to(devices.device)
setattr(self, name, attr) setattr(self, name, attr)
def set_hooks(self, before_sample, after_sample, after_update): def set_hooks(self, before_sample, after_sample, after_update):