aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-11-19 15:56:31 +0800
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-11-19 15:56:31 +0800
commit043d2edcf6a543f236f1f3cb70ac72e7b3b357b6 (patch)
treee13f5bacdbbe7c8940194b9c45cbda025da662a8 /modules/devices.py
parentf383af2729ec2d1969200218577ab19dd78f7d48 (diff)
Better naming
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 03e7bdb7..c19a7f40 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs):
@contextlib.contextmanager
-def manual_autocast():
+def manual_cast():
for module_type in patch_module_list:
org_forward = module_type.forward
module_type.forward = manual_cast_forward
@@ -148,10 +148,10 @@ def autocast(disable=False):
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
- return manual_autocast()
+ return manual_cast()
if has_mps() and shared.cmd_opts.precision != "full":
- return manual_autocast()
+ return manual_cast()
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()