author     AUTOMATIC1111 <16777216c@gmail.com>  2023-08-05 08:01:38 +0300
committer  AUTOMATIC1111 <16777216c@gmail.com>  2023-08-05 08:01:38 +0300
commit     ef1698fd6dbd6387341a1eeeded068ff1476ee50 (patch)
tree       ddaa0cf76e8cf95b93f63909a026ae3d5eab460a /modules/sd_hijack.py
parent     0fae47e97445df4e7de4d85538a80917fc2a2457 (diff)
parent     c613416af375092f55b9bc8649c949e95d250c44 (diff)
Merge branch 'dev' into extra-networks-always-visible
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--  modules/sd_hijack.py  62
1 file changed, 53 insertions, 9 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6b5aae4b..9ad98199 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,11 +2,10 @@ import torch
from torch.nn.functional import silu
from types import MethodType
-import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -15,6 +14,11 @@ import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -25,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
+
+sd_hijack_inpainting.do_inpainting_hijack()
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
@@ -56,6 +64,9 @@ def apply_optimizations(option=None):
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+ sgm.modules.diffusionmodules.model.nonlinearity = silu
+ sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
if current_optimizer is not None:
current_optimizer.undo()
current_optimizer = None
@@ -89,6 +100,10 @@ def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+ sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
def fix_checkpoint():
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -152,12 +167,13 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
def __init__(self):
+ import modules.textual_inversion.textual_inversion
+
self.extra_generation_params = {}
self.comments = []
+ self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
@@ -168,6 +184,32 @@ class StableDiffusionModelHijack:
undo_optimizations()
def hijack(self, m):
+ conditioner = getattr(m, 'conditioner', None)
+ if conditioner:
+ text_cond_models = []
+
+ for i in range(len(conditioner.embedders)):
+ embedder = conditioner.embedders[i]
+ typename = type(embedder).__name__
+ if typename == 'FrozenOpenCLIPEmbedder':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+ if typename == 'FrozenCLIPEmbedder':
+ model_embeddings = embedder.transformer.text_model.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+ if typename == 'FrozenOpenCLIPEmbedder2':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+
+ if len(text_cond_models) == 1:
+ m.cond_stage_model = text_cond_models[0]
+ else:
+ m.cond_stage_model = conditioner
+
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
@@ -205,7 +247,7 @@ class StableDiffusionModelHijack:
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
@@ -254,10 +296,11 @@ class StableDiffusionModelHijack:
class EmbeddingsWithFixes(torch.nn.Module):
- def __init__(self, wrapped, embeddings):
+ def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
+ self.textual_inversion_key = textual_inversion_key
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
@@ -271,7 +314,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
- emb = devices.cond_cast_unet(embedding.vec)
+ vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+ emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
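
The switch from no-op lambdas to shared.ldm_print relies on Python's name resolution: assigning a module-level `print` shadows the builtin for every function defined in that module, so the library's log lines can be routed through a single gate instead of silenced outright. A minimal, self-contained sketch of that trick (the `noisy` module here is a made-up stand-in, not webui code):

import types

noisy = types.ModuleType('noisy')
exec("def talk():\n    print('spam')", noisy.__dict__)

noisy.talk()                   # prints: spam
noisy.print = lambda *a: None  # shadows the builtin inside the module
noisy.talk()                   # prints nothing; the webui assigns shared.ldm_print here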
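The new conditioner branch in hijack() wraps each recognized text encoder of an SDXL-style conditioner in place, then exposes a lone wrapped encoder as cond_stage_model, or the whole conditioner when there are several. A hedged sketch of that pattern, with a hypothetical DummyWrapper standing in for the real *WithCustomWords wrapper classes:

class DummyWrapper:
    # hypothetical stand-in for the webui wrapper classes
    def __init__(self, wrapped, hijack):
        self.wrapped = wrapped
        self.hijack = hijack

def hijack_conditioner(m, hijack, wrappers):
    conditioner = getattr(m, 'conditioner', None)
    if not conditioner:
        return
    text_cond_models = []
    for i, embedder in enumerate(conditioner.embedders):
        wrapper = wrappers.get(type(embedder).__name__)
        if wrapper is None:
            continue  # leave unrecognized embedders untouched
        conditioner.embedders[i] = wrapper(embedder, hijack)
        text_cond_models.append(conditioner.embedders[i])
    # a single text encoder is exposed directly; otherwise keep the conditioner
    m.cond_stage_model = text_cond_models[0] if len(text_cond_models) == 1 else conditioner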
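The textual_inversion_key parameter exists because SDXL textual-inversion embeddings carry one tensor per text encoder, stored as a dict keyed 'clip_l'/'clip_g', while SD1/SD2 embeddings are a single tensor. A minimal sketch of the lookup the patched forward() performs (FakeEmbedding is an illustrative stand-in; only .vec matters):

import torch

class FakeEmbedding:
    def __init__(self, vec):
        self.vec = vec  # a tensor (SD1/SD2) or a dict of tensors (SDXL)

def resolve_vec(embedding, textual_inversion_key='clip_l'):
    # mirrors the patched line: pick the per-encoder tensor when present
    if isinstance(embedding.vec, dict):
        return embedding.vec[textual_inversion_key]
    return embedding.vec

sd1_emb = FakeEmbedding(torch.zeros(4, 768))
sdxl_emb = FakeEmbedding({'clip_l': torch.zeros(4, 768), 'clip_g': torch.zeros(4, 1280)})
print(resolve_vec(sd1_emb).shape)             # torch.Size([4, 768])
print(resolve_vec(sdxl_emb, 'clip_g').shape)  # torch.Size([4, 1280])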
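The surrounding torch.cat splice is unchanged by the patch but easier to read in isolation: it overwrites the token embeddings following `offset` with the learned vectors, clamping emb_len so the sequence length never changes. A runnable sketch under assumed CLIP-like shapes (77 tokens, 768 dims):

import torch

def splice(tensor, emb, offset):
    # tensor: [seq_len, dim] token embeddings; emb: [n, dim] learned vectors
    emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
    return torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

tokens = torch.zeros(77, 768)
vectors = torch.ones(4, 768)
out = splice(tokens, vectors, offset=10)
print(out.shape)         # torch.Size([77, 768]), length preserved
print(out[11:15].sum())  # tensor(3072.), the 4 embedding rows landed at offset+1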