Add a check and explanation for tensor with all NaNs.

This commit is contained in:
AUTOMATIC 2023-01-16 22:59:46 +03:00
parent 52f6e94338
commit 9991967f40
3 changed files with 33 additions and 0 deletions

View file

@ -106,6 +106,33 @@ def autocast(disable=False):
return torch.autocast("cuda") return torch.autocast("cuda")
class NansException(Exception):
pass
def test_for_nans(x, where):
from modules import shared
if not torch.all(torch.isnan(x)).item():
return
if where == "unet":
message = "A tensor with all NaNs was produced in Unet."
if not shared.cmd_opts.no_half:
message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
elif where == "vae":
message = "A tensor with all NaNs was produced in VAE."
if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
else:
message = "A tensor with all NaNs was produced."
raise NansException(message)
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs): def tensor_to_fix(self, *args, **kwargs):
@ -156,3 +183,4 @@ if has_mps():
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )

View file

@ -608,6 +608,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
devices.test_for_nans(x, "vae")
x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

View file

@ -351,6 +351,8 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt": if opts.live_preview_content == "Prompt":
store_latent(x_out[0:uncond.shape[0]]) store_latent(x_out[0:uncond.shape[0]])
elif opts.live_preview_content == "Negative prompt": elif opts.live_preview_content == "Negative prompt":