aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py47
1 files changed, 37 insertions, 10 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 36f643e1..4bd8783e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,9 +15,9 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -87,8 +87,7 @@ class CheckpointInfo:
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
+ from transformers import logging, CLIPModel # noqa: F401
logging.set_verbosity_error()
except Exception:
@@ -167,7 +166,7 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
-
+
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -239,7 +238,7 @@ def read_metadata_from_safetensors(filename):
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
- except Exception as e:
+ except Exception:
pass
return res
@@ -374,7 +373,7 @@ def enable_midas_autodownload():
if not os.path.exists(path):
if not os.path.exists(midas_path):
mkdir(midas_path)
-
+
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
@@ -415,6 +414,9 @@ class SdModelData:
def get_sd_model(self):
if self.sd_model is None:
with self.lock:
+ if self.sd_model is not None:
+ return self.sd_model
+
try:
load_model()
except Exception as e:
@@ -467,7 +469,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
- except Exception as e:
+ except Exception:
pass
if sd_model is None:
@@ -538,13 +540,12 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
- checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- except Exception as e:
+ except Exception:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
@@ -565,7 +566,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model:
@@ -580,3 +581,29 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, token_merging_ratio):
+ """
+ Applies speed and memory optimizations from tomesd.
+ """
+
+ current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
+
+ if current_token_merging_ratio == token_merging_ratio:
+ return
+
+ if current_token_merging_ratio > 0:
+ tomesd.remove_patch(sd_model)
+
+ if token_merging_ratio > 0:
+ tomesd.apply_patch(
+ sd_model,
+ ratio=token_merging_ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
+
+ sd_model.applied_token_merged_ratio = token_merging_ratio