aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/processing.py35
-rw-r--r--modules/sd_models.py7
-rw-r--r--modules/shared.py44
3 files changed, 79 insertions, 7 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 6d9c6a8d..e115aadd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -29,6 +29,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
+import tomesd
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -500,9 +501,28 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
+ if opts.token_merging:
+
+ if p.hr_second_pass_steps < 1 and not opts.token_merging_hr_only:
+ tomesd.apply_patch(
+ p.sd_model,
+ ratio=opts.token_merging_ratio,
+ max_downsample=opts.token_merging_maximum_down_sampling,
+ sx=opts.token_merging_stride_x,
+ sy=opts.token_merging_stride_y,
+ use_rand=opts.token_merging_random,
+ merge_attn=opts.token_merging_merge_attention,
+ merge_crossattn=opts.token_merging_merge_cross_attention,
+ merge_mlp=opts.token_merging_merge_mlp
+ )
+
res = process_images_inner(p)
finally:
+ # undo model optimizations made by tomesd
+ if opts.token_merging:
+ tomesd.remove_patch(p.sd_model)
+
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@@ -938,6 +958,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
+ # apply token merging optimizations from tomesd for high-res pass
+ # check if hr_only so we don't redundantly apply patch
+ if opts.token_merging and opts.token_merging_hr_only:
+ tomesd.apply_patch(
+ self.sd_model,
+ ratio=opts.token_merging_ratio,
+ max_downsample=opts.token_merging_maximum_down_sampling,
+ sx=opts.token_merging_stride_x,
+ sy=opts.token_merging_stride_y,
+ use_rand=opts.token_merging_random,
+ merge_attn=opts.token_merging_merge_attention,
+ merge_crossattn=opts.token_merging_merge_cross_attention,
+ merge_mlp=opts.token_merging_merge_mlp
+ )
+
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
return samples
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 2c05ec17..87c49b83 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -431,13 +431,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
- if shared.cmd_opts.token_merging:
- import tomesd
- ratio = shared.cmd_opts.token_merging_ratio
-
- tomesd.apply_patch(sd_model, ratio=ratio)
- print(f"Model accelerated using {(ratio * 100)}% token merging via tomesd.")
- timer.record("token merging")
except Exception as e:
pass
diff --git a/modules/shared.py b/modules/shared.py
index 5fd0eecb..d7379e24 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -427,6 +427,50 @@ options_templates.update(options_section((None, "Hidden options"), {
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
+# Settings section for tomesd token merging. Every key below is read back as
+# `opts.<key>` by the processing code and forwarded to tomesd.apply_patch().
+options_templates.update(options_section(('token_merging', 'Token Merging'), {
+    "token_merging": OptionInfo(
+        False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)",
+        gr.Checkbox
+    ),
+    "token_merging_ratio": OptionInfo(
+        0.5, "Merging Ratio",
+        gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
+    ),
+    "token_merging_hr_only": OptionInfo(
+        True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
+        gr.Checkbox
+    ),
+    # More advanced/niche settings:
+    "token_merging_random": OptionInfo(
+        True, "Use random perturbations - Disabling might help with certain samplers",
+        gr.Checkbox
+    ),
+    "token_merging_merge_attention": OptionInfo(
+        True, "Merge attention",
+        gr.Checkbox
+    ),
+    "token_merging_merge_cross_attention": OptionInfo(
+        False, "Merge cross attention",
+        gr.Checkbox
+    ),
+    "token_merging_merge_mlp": OptionInfo(
+        False, "Merge mlp",
+        gr.Checkbox
+    ),
+    # Choices are integers so that a value picked in the UI has the same type
+    # as the default (1) and can be passed unchanged as tomesd's
+    # `max_downsample`; string choices would hand a str to that parameter.
+    "token_merging_maximum_down_sampling": OptionInfo(
+        1, "Maximum down sampling",
+        gr.Radio, lambda: {"choices": [1, 2, 4, 8]}
+    ),
+    "token_merging_stride_x": OptionInfo(
+        2, "Stride - X",
+        gr.Slider, {"minimum": 2, "maximum": 8, "step": 2}
+    ),
+    "token_merging_stride_y": OptionInfo(
+        2, "Stride - Y",
+        gr.Slider, {"minimum": 2, "maximum": 8, "step": 2}
+    ),
+}))
+
options_templates.update()