Merge pull request #5194 from brkirch/autocast-and-mps-randn-fixes
Use devices.autocast() and fix MPS randn issues
This commit is contained in:
commit
a2feaa95fc
8 changed files with 29 additions and 31 deletions
|
@ -66,24 +66,15 @@ dtype_vae = torch.float16
|
||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
# PyTorch currently doesn't handle seeding randomness correctly when the Metal (MPS) backend is used.
|
|
||||||
if device.type == 'mps':
|
|
||||||
generator = torch.Generator(device=cpu)
|
|
||||||
generator.manual_seed(seed)
|
|
||||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
|
||||||
return noise
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
if device.type == 'mps':
|
||||||
|
return torch.randn(shape, device=cpu).to(device)
|
||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
|
||||||
def randn_without_seed(shape):
|
def randn_without_seed(shape):
|
||||||
# PyTorch currently doesn't handle seeding randomness correctly when the Metal (MPS) backend is used.
|
|
||||||
if device.type == 'mps':
|
if device.type == 'mps':
|
||||||
generator = torch.Generator(device=cpu)
|
return torch.randn(shape, device=cpu).to(device)
|
||||||
noise = torch.randn(shape, generator=generator, device=cpu).to(device)
|
|
||||||
return noise
|
|
||||||
|
|
||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -495,7 +495,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with devices.autocast():
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
if tag_drop_out != 0 or shuffle_tags:
|
if tag_drop_out != 0 or shuffle_tags:
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
|
|
|
@ -148,8 +148,7 @@ class InterrogateModels:
|
||||||
|
|
||||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||||
|
|
||||||
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
with torch.no_grad(), devices.autocast():
|
||||||
with torch.no_grad(), precision_scope("cuda"):
|
|
||||||
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
|
||||||
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
|
@ -183,11 +183,7 @@ def register_buffer(self, name, attr):
|
||||||
|
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
if attr.device != devices.device:
|
if attr.device != devices.device:
|
||||||
|
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
|
||||||
if devices.has_mps():
|
|
||||||
attr = attr.to(device="mps", dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
attr = attr.to(devices.device)
|
|
||||||
|
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import inspect
|
import inspect
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
|
import torchsde._brownian.brownian_interval
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
from modules import prompt_parser, devices, processing, images
|
from modules import prompt_parser, devices, processing, images
|
||||||
|
@ -364,7 +365,23 @@ class TorchHijack:
|
||||||
if noise.shape == x.shape:
|
if noise.shape == x.shape:
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
return torch.randn_like(x)
|
if x.device.type == 'mps':
|
||||||
|
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||||
|
else:
|
||||||
|
return torch.randn_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
# MPS fix for randn in torchsde
|
||||||
|
def torchsde_randn(size, dtype, device, seed):
    """Return seeded standard-normal noise of the given size and dtype on `device`.

    When `device` is an MPS device, the generator is created and the noise is
    sampled on the CPU, and the result is then moved to `device`. For any
    other device, the generator and the noise are created directly on it.

    `seed` is converted with `int()` before being used to seed the generator.
    """
    on_mps = device.type == 'mps'

    # On MPS, draw the samples on the CPU and transfer them afterwards.
    sampling_device = devices.cpu if on_mps else device

    rng = torch.Generator(sampling_device).manual_seed(int(seed))
    noise = torch.randn(size, dtype=dtype, device=sampling_device, generator=rng)

    if on_mps:
        return noise.to(device)
    return noise
|
||||||
|
|
||||||
|
|
||||||
|
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler:
|
class KDiffusionSampler:
|
||||||
|
@ -415,8 +432,7 @@ class KDiffusionSampler:
|
||||||
self.model_wrap.step = 0
|
self.model_wrap.step = 0
|
||||||
self.eta = p.eta or opts.eta_ancestral
|
self.eta = p.eta or opts.eta_ancestral
|
||||||
|
|
||||||
if self.sampler_noises is not None:
|
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
|
|
||||||
|
|
||||||
extra_params_kwargs = {}
|
extra_params_kwargs = {}
|
||||||
for param_name in self.extra_params:
|
for param_name in self.extra_params:
|
||||||
|
|
|
@ -13,10 +13,6 @@ from modules.swinir_model_arch import SwinIR as net
|
||||||
from modules.swinir_model_arch_v2 import Swin2SR as net2
|
from modules.swinir_model_arch_v2 import Swin2SR as net2
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
precision_scope = (
|
|
||||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerSwinIR(Upscaler):
|
class UpscalerSwinIR(Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
|
@ -112,7 +108,7 @@ def upscale(
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
img = torch.from_numpy(img).float()
|
img = torch.from_numpy(img).float()
|
||||||
img = img.unsqueeze(0).to(devices.device_swinir)
|
img = img.unsqueeze(0).to(devices.device_swinir)
|
||||||
with torch.no_grad(), precision_scope("cuda"):
|
with torch.no_grad(), devices.autocast():
|
||||||
_, _, h_old, w_old = img.size()
|
_, _, h_old, w_old = img.size()
|
||||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
w_pad = (w_old // window_size + 1) * window_size - w_old
|
||||||
|
|
|
@ -82,7 +82,7 @@ class PersonalizedBase(Dataset):
|
||||||
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
|
||||||
latent_sample = None
|
latent_sample = None
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with devices.autocast():
|
||||||
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
|
||||||
|
|
||||||
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
|
||||||
|
@ -101,7 +101,7 @@ class PersonalizedBase(Dataset):
|
||||||
entry.cond_text = self.create_text(filename_text)
|
entry.cond_text = self.create_text(filename_text)
|
||||||
|
|
||||||
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
|
||||||
with torch.autocast("cuda"):
|
with devices.autocast():
|
||||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||||
|
|
||||||
self.dataset.append(entry)
|
self.dataset.append(entry)
|
||||||
|
|
|
@ -316,7 +316,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
with torch.autocast("cuda"):
|
with devices.autocast():
|
||||||
# c = stack_conds(batch.cond).to(devices.device)
|
# c = stack_conds(batch.cond).to(devices.device)
|
||||||
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
|
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
|
||||||
# print(mask)
|
# print(mask)
|
||||||
|
|
Loading…
Reference in a new issue