aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_hijack_inpainting.py19
1 files changed, 18 insertions, 1 deletions
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index fd92a335..202b42cf 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,4 +1,5 @@
import torch
+import modules.devices as devices
from einops import repeat
from omegaconf import ListConfig
@@ -314,6 +315,20 @@ class LatentInpaintDiffusion(LatentDiffusion):
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
+
+
+# =================================================================================================
+# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
+# =================================================================================================
+def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ optimal_type = devices.get_optimal_device()
+ if attr.device != optimal_type:
+ if getattr(torch, 'has_mps', False):
+ attr = attr.to(device="mps", dtype=torch.float32)
+ else:
+ attr = attr.to(optimal_type)
+ setattr(self, name, attr)
def should_hijack_inpainting(checkpoint_info):
@@ -326,6 +341,8 @@ def do_inpainting_hijack():
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+ ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
- ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms \ No newline at end of file
+ ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
+ ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer