aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index cfa5f0eb..9ad98199 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,11 +2,10 @@ import torch
from torch.nn.functional import silu
from types import MethodType
-import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -30,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
+
+sd_hijack_inpainting.do_inpainting_hijack()
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
@@ -164,12 +167,13 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
def __init__(self):
+ import modules.textual_inversion.textual_inversion
+
self.extra_generation_params = {}
self.comments = []
+ self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):