aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack_inpainting.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-11-11 18:20:18 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-11-11 18:20:18 +0300
commit7ba3923d5b494b7756d0b12f33acb3716d830b9a (patch)
treef3fba1191df7b3c99975f3e9047cdc06a036e7e6 /modules/sd_hijack_inpainting.py
parentbb2e2c82ce886843f11339571f9a70d4c5f2a09d (diff)
move DDIM/PLMS fix for OSX out of the file with inpainting code.
Diffstat (limited to 'modules/sd_hijack_inpainting.py')
-rw-r--r--modules/sd_hijack_inpainting.py18
1 files changed, 1 insertions, 17 deletions
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 202b42cf..46714a4f 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,5 +1,4 @@
import torch
-import modules.devices as devices
from einops import repeat
from omegaconf import ListConfig
@@ -315,20 +314,6 @@ class LatentInpaintDiffusion(LatentDiffusion):
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
-
-
-# =================================================================================================
-# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
-# =================================================================================================
-def register_buffer(self, name, attr):
- if type(attr) == torch.Tensor:
- optimal_type = devices.get_optimal_device()
- if attr.device != optimal_type:
- if getattr(torch, 'has_mps', False):
- attr = attr.to(device="mps", dtype=torch.float32)
- else:
- attr = attr.to(optimal_type)
- setattr(self, name, attr)
def should_hijack_inpainting(checkpoint_info):
@@ -341,8 +326,7 @@ def do_inpainting_hijack():
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
- ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
- ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
+