aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 03e7bdb7..c19a7f40 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs):
@contextlib.contextmanager
-def manual_autocast():
+def manual_cast():
for module_type in patch_module_list:
org_forward = module_type.forward
module_type.forward = manual_cast_forward
@@ -148,10 +148,10 @@ def autocast(disable=False):
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
- return manual_autocast()
+ return manual_cast()
if has_mps() and shared.cmd_opts.precision != "full":
- return manual_autocast()
+ return manual_cast()
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()