aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_hijack.py9
-rw-r--r--modules/sd_models.py6
2 files changed, 14 insertions, 1 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 2848a251..5945b7c2 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -245,6 +245,7 @@ class StableDiffusionModelHijack:
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1:
@@ -263,6 +264,14 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
+ def undo_hijack(self, m):
+ if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
+ m.cond_stage_model = m.cond_stage_model.wrapped
+
+ model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+ if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
+ model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7a5edced..eb21e498 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -137,7 +137,7 @@ def load_model():
def reload_model_weights(sd_model, info=None):
- from modules import lowvram, devices
+ from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename:
@@ -148,8 +148,12 @@ def reload_model_weights(sd_model, info=None):
else:
sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(sd_model)
+
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
+ sd_hijack.model_hijack.hijack(sd_model)
+
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)