aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py
index ac3ae0c9..524ec7af 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -106,6 +106,36 @@ def autocast(disable=False):
return torch.autocast("cuda")
+class NansException(Exception):
+ pass
+
+
+def test_for_nans(x, where):
+ from modules import shared
+
+ if shared.cmd_opts.disable_nan_check:
+ return
+
+ if not torch.all(torch.isnan(x)).item():
+ return
+
+ if where == "unet":
+ message = "A tensor with all NaNs was produced in Unet."
+
+ if not shared.cmd_opts.no_half:
+ message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+
+ elif where == "vae":
+ message = "A tensor with all NaNs was produced in VAE."
+
+ if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
+ message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
+ else:
+ message = "A tensor with all NaNs was produced."
+
+ raise NansException(message)
+
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs):
@@ -159,3 +189,4 @@ if has_mps():
torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
orig_narrow = torch.narrow
torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
+