aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py36
1 files changed, 20 insertions, 16 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e612be10..4bd8783e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -583,23 +583,27 @@ def unload_model_weights(sd_model=None, info=None):
return sd_model
-def apply_token_merging(sd_model, hr: bool):
+def apply_token_merging(sd_model, token_merging_ratio):
"""
Applies speed and memory optimizations from tomesd.
-
- Args:
- hr (bool): True if called in the context of a high-res pass
"""
- ratio = shared.opts.token_merging_ratio
- if hr:
- ratio = shared.opts.token_merging_ratio_hr
-
- tomesd.apply_patch(
- sd_model,
- ratio=ratio,
- use_rand=False, # can cause issues with some samplers
- merge_attn=True,
- merge_crossattn=False,
- merge_mlp=False
- )
+ current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
+
+ if current_token_merging_ratio == token_merging_ratio:
+ return
+
+ if current_token_merging_ratio > 0:
+ tomesd.remove_patch(sd_model)
+
+ if token_merging_ratio > 0:
+ tomesd.apply_patch(
+ sd_model,
+ ratio=token_merging_ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
+
+ sd_model.applied_token_merged_ratio = token_merging_ratio