aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorArturo Albacete <aalbacetef@gmail.com>2024-01-20 21:15:57 +0100
committerArturo Albacete <aalbacetef@gmail.com>2024-01-20 21:15:57 +0100
commitd0b65e148bdc3d35f3f8ee38310ba55152ab4880 (patch)
tree1485c4f2df143b4c3d7e958a232dc9ddb2205280 /modules/devices.py
parent315e40a49c32438551ed6b66138acdf664ecdbc8 (diff)
parentf939bce845ae07536b1c920618743af83e0b01ec (diff)
merge dev
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 0321d12c..dfffaf24 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -164,7 +164,11 @@ def manual_cast_forward(target_dtype):
@contextlib.contextmanager
def manual_cast(target_dtype):
+ applied = False
for module_type in patch_module_list:
+ if hasattr(module_type, "org_forward"):
+ continue
+ applied = True
org_forward = module_type.forward
if module_type == torch.nn.MultiheadAttention and has_xpu():
module_type.forward = manual_cast_forward(torch.float32)
@@ -174,8 +178,11 @@ def manual_cast(target_dtype):
try:
yield None
finally:
- for module_type in patch_module_list:
- module_type.forward = module_type.org_forward
+ if applied:
+ for module_type in patch_module_list:
+ if hasattr(module_type, "org_forward"):
+ module_type.forward = module_type.org_forward
+ delattr(module_type, "org_forward")
def autocast(disable=False):