aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/devices.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/modules/devices.py b/modules/devices.py
index c05f2b35..d7c905c2 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -121,6 +121,8 @@ def manual_autocast():
def manual_cast_forward(self, *args, **kwargs):
org_dtype = next(self.parameters()).dtype
self.to(dtype)
+ args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
result = self.org_forward(*args, **kwargs)
self.to(org_dtype)
return result
@@ -136,7 +138,6 @@ def manual_autocast():
def autocast(disable=False):
- print(fp8, dtype, shared.cmd_opts.precision, device)
if disable:
return contextlib.nullcontext()
@@ -146,6 +147,9 @@ def autocast(disable=False):
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
return manual_autocast()
+ if has_mps() and shared.cmd_opts.precision != "full":
+ return manual_autocast()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()