aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 1eb0644a..92874a79 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -8,9 +8,9 @@ from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
-from modules.shared import cmd_opts
+from modules.shared import opts, device, cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip
from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -68,6 +68,10 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+def fix_checkpoint():
+ ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
class StableDiffusionModelHijack:
fixes = None
@@ -90,6 +94,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model
apply_optimizations()
+ fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@@ -179,11 +184,7 @@ def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != devices.device:
-
- if devices.has_mps():
- attr = attr.to(device="mps", dtype=torch.float32)
- else:
- attr = attr.to(devices.device)
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)