aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py25
1 files changed, 14 insertions, 11 deletions
diff --git a/modules/devices.py b/modules/devices.py
index d7c905c2..03e7bdb7 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool:
if device_id is None:
device_id = get_cuda_device_id()
return (
- torch.cuda.get_device_capability(device_id) == (7, 5)
+ torch.cuda.get_device_capability(device_id) == (7, 5)
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
)
def get_cuda_device_id():
return (
- int(shared.cmd_opts.device_id)
- if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+ int(shared.cmd_opts.device_id)
+ if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
else 0
) or torch.cuda.current_device()
@@ -116,16 +116,19 @@ patch_module_list = [
torch.nn.LayerNorm,
]
+
+def manual_cast_forward(self, *args, **kwargs):
+ org_dtype = next(self.parameters()).dtype
+ self.to(dtype)
+ args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+ result = self.org_forward(*args, **kwargs)
+ self.to(org_dtype)
+ return result
+
+
@contextlib.contextmanager
def manual_autocast():
- def manual_cast_forward(self, *args, **kwargs):
- org_dtype = next(self.parameters()).dtype
- self.to(dtype)
- args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
- kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
- result = self.org_forward(*args, **kwargs)
- self.to(org_dtype)
- return result
for module_type in patch_module_list:
org_forward = module_type.forward
module_type.forward = manual_cast_forward