diff options
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index b824b5bf..95a17093 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -17,6 +17,7 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
@@ -182,11 +183,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor:
if attr.device != devices.device:
-
- if devices.has_mps():
- attr = attr.to(device="mps", dtype=torch.float32)
- else:
- attr = attr.to(devices.device)
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)
|