aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorKohakuBlueleaf <apolloyeh0123@gmail.com>2024-01-09 22:39:39 +0800
committerKohakuBlueleaf <apolloyeh0123@gmail.com>2024-01-09 22:39:39 +0800
commit42e6df723c68af775b73c9fa4f43f99345348689 (patch)
tree26f55dcda9cba2d1522001ad25d336e17a50e7bb /modules
parent209c26a1cb9e4be357ab3c5e7613caf3cbc26183 (diff)
Fix bugs when arg dtype doesn't match
Diffstat (limited to 'modules')
-rw-r--r--modules/devices.py25
1 files changed, 10 insertions, 15 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 6edfb127..e0574052 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -134,24 +134,19 @@ patch_module_list = [
def manual_cast_forward(target_dtype):
def forward_wrapper(self, *args, **kwargs):
+ if any(
+ isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
+ for arg in args
+ ):
+ args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+
org_dtype = torch_utils.get_param(self).dtype
- if not target_dtype == org_dtype == dtype_inference:
+ if org_dtype != target_dtype:
self.to(target_dtype)
- args = [
- arg.to(target_dtype)
- if isinstance(arg, torch.Tensor)
- else arg
- for arg in args
- ]
- kwargs = {
- k: v.to(target_dtype)
- if isinstance(v, torch.Tensor)
- else v
- for k, v in kwargs.items()
- }
-
result = self.org_forward(*args, **kwargs)
- self.to(org_dtype)
+ if org_dtype != target_dtype:
+ self.to(org_dtype)
if target_dtype != dtype_inference:
if isinstance(result, tuple):