aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/xpu_specific.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index d933c790..ec1ad100 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -48,3 +48,6 @@ if has_xpu:
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+ CondFunc('torch.bmm',
+ lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
+ lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)