aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/cmd_args.py3
-rw-r--r--modules/processing.py44
-rw-r--r--modules/sd_models.py29
-rw-r--r--modules/shared.py6
4 files changed, 49 insertions, 33 deletions
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 4314f97b..8e5a7fab 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -103,5 +103,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
# token merging / tomesd
-parser.add_argument("--token-merging", action='store_true', help="Provides generation speedup by merging redundant tokens. (compatible with --xformers)", default=False)
-parser.add_argument("--token-merging-ratio", type=float, help="Adjusts ratio of merged to untouched tokens. Range: (0.0-1.0], Defaults to 0.5", default=0.5)
+parser.add_argument("--token-merging", action='store_true', help="Provides speed and memory improvements by merging redundant tokens. This has a more pronounced effect on higher resolutions.", default=False)
diff --git a/modules/processing.py b/modules/processing.py
index 55735572..670a7a28 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -501,26 +501,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
- if opts.token_merging and not opts.token_merging_hr_only:
- print("applying token merging to all passes")
- tomesd.apply_patch(
- p.sd_model,
- ratio=opts.token_merging_ratio,
- max_downsample=opts.token_merging_maximum_down_sampling,
- sx=opts.token_merging_stride_x,
- sy=opts.token_merging_stride_y,
- use_rand=opts.token_merging_random,
- merge_attn=opts.token_merging_merge_attention,
- merge_crossattn=opts.token_merging_merge_cross_attention,
- merge_mlp=opts.token_merging_merge_mlp
- )
+ if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
+ print("\nApplying token merging\n")
+ sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
res = process_images_inner(p)
finally:
# undo model optimizations made by tomesd
- if opts.token_merging:
- print('removing token merging model optimizations')
+ if opts.token_merging or cmd_opts.token_merging:
+ print('\nRemoving token merging model optimizations\n')
tomesd.remove_patch(p.sd_model)
# restore opts to original state
@@ -959,20 +949,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
devices.torch_gc()
# apply token merging optimizations from tomesd for high-res pass
- # check if hr_only so we don't redundantly apply patch
- if opts.token_merging and opts.token_merging_hr_only:
- print("applying token merging for high-res pass")
- tomesd.apply_patch(
- self.sd_model,
- ratio=opts.token_merging_ratio,
- max_downsample=opts.token_merging_maximum_down_sampling,
- sx=opts.token_merging_stride_x,
- sy=opts.token_merging_stride_y,
- use_rand=opts.token_merging_random,
- merge_attn=opts.token_merging_merge_attention,
- merge_crossattn=opts.token_merging_merge_cross_attention,
- merge_mlp=opts.token_merging_merge_mlp
- )
+ # check if hr_only so we are not redundantly patching
+ if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
+ # case where user wants to use separate merge ratios
+ if not opts.token_merging_hr_only:
+ # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
+ print('Temporarily reverting token merging optimizations in preparation for next pass')
+ tomesd.remove_patch(self.sd_model)
+
+ print("\nApplying token merging for high-res pass\n")
+ sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 87c49b83..696a2333 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -16,6 +16,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -545,4 +546,30 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
- return sd_model \ No newline at end of file
+ return sd_model
+
+
+def apply_token_merging(sd_model, hr: bool):
+ """
+ Applies speed and memory optimizations from tomesd.
+
+ Args:
+ hr (bool): True if called in the context of a high-res pass
+ """
+
+ ratio = shared.opts.token_merging_ratio
+ if hr:
+ ratio = shared.opts.token_merging_ratio_hr
+ print("effective hr pass merge ratio is "+str(ratio))
+
+ tomesd.apply_patch(
+ sd_model,
+ ratio=ratio,
+ max_downsample=shared.opts.token_merging_maximum_down_sampling,
+ sx=shared.opts.token_merging_stride_x,
+ sy=shared.opts.token_merging_stride_y,
+ use_rand=shared.opts.token_merging_random,
+ merge_attn=shared.opts.token_merging_merge_attention,
+ merge_crossattn=shared.opts.token_merging_merge_cross_attention,
+ merge_mlp=shared.opts.token_merging_merge_mlp
+ )
diff --git a/modules/shared.py b/modules/shared.py
index d7379e24..c7572e98 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -429,7 +429,7 @@ options_templates.update(options_section((None, "Hidden options"), {
options_templates.update(options_section(('token_merging', 'Token Merging'), {
"token_merging": OptionInfo(
- False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)",
+ 0.5, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
gr.Checkbox
),
"token_merging_ratio": OptionInfo(
@@ -440,6 +440,10 @@ options_templates.update(options_section(('token_merging', 'Token Merging'), {
True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
gr.Checkbox
),
+ "token_merging_ratio_hr": OptionInfo(
+ 0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
+ gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
+ ),
# More advanced/niche settings:
"token_merging_random": OptionInfo(
True, "Use random perturbations - Disabling might help with certain samplers",