Diffstat (limited to 'modules')
-rw-r--r--  modules/devices.py  42
1 file changed, 33 insertions(+), 9 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index ad36f656..9e1f207c 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -132,6 +132,21 @@ patch_module_list = [
 ]
 
 
+def cast_output(result):
+    if isinstance(result, tuple):
+        result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
+    elif isinstance(result, torch.Tensor):
+        result = result.to(dtype_inference)
+    return result
+
+
+def autocast_with_cast_output(self, *args, **kwargs):
+    result = self.org_forward(*args, **kwargs)
+    if dtype_inference != dtype:
+        result = cast_output(result)
+    return result
+
+
 def manual_cast_forward(target_dtype):
     def forward_wrapper(self, *args, **kwargs):
         if any(
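
For illustration, a minimal sketch of what the new cast_output helper does, runnable on its own; dtype_inference here is a stand-in for the module-level global defined near the top of modules/devices.py:

    import torch

    dtype_inference = torch.float32  # stand-in for the modules/devices.py global

    def cast_output(result):
        # Cast a bare tensor, or every tensor inside a tuple, to the inference dtype.
        if isinstance(result, tuple):
            return tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
        if isinstance(result, torch.Tensor):
            return result.to(dtype_inference)
        return result

    half = torch.ones(2, dtype=torch.float16)
    print(cast_output(half).dtype)                       # torch.float32
    print([t.dtype for t in cast_output((half, half))])  # [torch.float32, torch.float32]
    print(cast_output((half, "meta"))[1])                # non-tensor items pass through unchanged
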
@@ -149,15 +164,7 @@ def manual_cast_forward(target_dtype):
             self.to(org_dtype)
 
         if target_dtype != dtype_inference:
-            if isinstance(result, tuple):
-                result = tuple(
-                    i.to(dtype_inference)
-                    if isinstance(i, torch.Tensor)
-                    else i
-                    for i in result
-                )
-            elif isinstance(result, torch.Tensor):
-                result = result.to(dtype_inference)
+            result = cast_output(result)
         return result
     return forward_wrapper
 
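
The manual_cast_forward hunk above is a pure refactor: the inlined tuple/tensor casting moves into the shared cast_output helper. A stripped-down sketch of that manual-cast path (it skips the kwargs casting and parameter-dtype probing the real wrapper does), runnable on CPU with bfloat16 standing in for the half-precision compute dtype used on XPU/MPS/old CUDA:

    import torch
    import torch.nn as nn

    dtype_inference = torch.float32  # stand-in for the modules/devices.py global

    def cast_output(result):
        if isinstance(result, tuple):
            return tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
        return result.to(dtype_inference) if isinstance(result, torch.Tensor) else result

    def manual_cast_forward(target_dtype):
        def forward_wrapper(self, *args, **kwargs):
            # Cast inputs down to the compute dtype, run the original forward,
            # then hand the output back in the inference dtype.
            args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a for a in args]
            result = self.org_forward(*args, **kwargs)
            if target_dtype != dtype_inference:
                result = cast_output(result)  # the refactored tail of the wrapper
            return result
        return forward_wrapper

    nn.Linear.org_forward = nn.Linear.forward
    nn.Linear.forward = manual_cast_forward(torch.bfloat16)
    layer = nn.Linear(4, 4).to(torch.bfloat16)
    print(layer(torch.randn(1, 4)).dtype)  # torch.float32
    nn.Linear.forward = nn.Linear.org_forward  # undo the patch
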
@@ -178,6 +185,20 @@ def manual_cast(target_dtype):
             module_type.forward = module_type.org_forward
 
 
+@contextlib.contextmanager
+def precision_full_with_autocast(autocast_ctx):
+    for module_type in patch_module_list:
+        org_forward = module_type.forward
+        module_type.forward = autocast_with_cast_output
+        module_type.org_forward = org_forward
+    try:
+        with autocast_ctx:
+            yield None
+    finally:
+        for module_type in patch_module_list:
+            module_type.forward = module_type.org_forward
+
+
 def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
@@ -191,6 +212,9 @@ def autocast(disable=False):
     if has_xpu() or has_mps() or cuda_no_autocast():
         return manual_cast(dtype_inference)
 
+    if dtype_inference == torch.float32 and dtype != torch.float32:
+        return precision_full_with_autocast(torch.autocast("cuda"))
+
     return torch.autocast("cuda")
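
Taken together, the new branch in autocast() covers the case where the user wants full-precision outputs (dtype_inference is fp32) but the model weights are half precision: torch.autocast keeps running the fast mixed-precision kernels while every patched module hands its result back as fp32. A self-contained usage sketch, assuming a CUDA device and using stand-ins for the module globals (the real patch_module_list also covers Conv2d, LayerNorm, GroupNorm, and friends; on a CPU-only box, torch.autocast("cpu", dtype=torch.bfloat16) exercises the same path):

    import contextlib
    import torch
    import torch.nn as nn

    dtype, dtype_inference = torch.float16, torch.float32  # stand-ins
    patch_module_list = [nn.Linear]

    def cast_output(result):
        if isinstance(result, tuple):
            return tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result)
        return result.to(dtype_inference) if isinstance(result, torch.Tensor) else result

    def autocast_with_cast_output(self, *args, **kwargs):
        result = self.org_forward(*args, **kwargs)
        if dtype_inference != dtype:
            result = cast_output(result)
        return result

    @contextlib.contextmanager
    def precision_full_with_autocast(autocast_ctx):
        # Patch each module class so its forward casts outputs back to
        # dtype_inference, then run the body under the supplied autocast context.
        for module_type in patch_module_list:
            module_type.org_forward = module_type.forward
            module_type.forward = autocast_with_cast_output
        try:
            with autocast_ctx:
                yield
        finally:
            for module_type in patch_module_list:
                module_type.forward = module_type.org_forward

    linear = nn.Linear(4, 4).cuda()  # fp32 weights
    x = torch.randn(1, 4, device="cuda")
    with precision_full_with_autocast(torch.autocast("cuda")):
        y = linear(x)
    print(y.dtype)  # torch.float32: the matmul ran in fp16, the output was cast back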