author     papuSpartan <macabeg@icloud.com>    2023-04-04 02:26:44 -0500
committer  papuSpartan <macabeg@icloud.com>    2023-04-04 02:26:44 -0500
commit     5c8e53d5e98da0eabf384318955c57842d612c07 (patch)
tree       62686f3d064381bb606624f0fc53ea97b5f4e9b4 /modules
parent     c707b7df95a61b66a05be94e805e1be9a432e294 (diff)
Allow different merge ratios to be used for each pass. Make toggle cmd flag work again. Remove ratio flag. Remove warning about controlnet being incompatible
Diffstat (limited to 'modules')
-rw-r--r--  modules/cmd_args.py     3
-rw-r--r--  modules/processing.py  44
-rw-r--r--  modules/sd_models.py   29
-rw-r--r--  modules/shared.py       6
4 files changed, 49 insertions, 33 deletions
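
A minimal sketch (not part of the patch) of how the per-pass ratio selection described in the commit message behaves; the SimpleNamespace below is only a stand-in for shared.opts / the --token-merging cmd flag, and the values are illustrative:

from types import SimpleNamespace

# Stand-in for shared.opts / shared.cmd_opts; values here are illustrative only.
opts = SimpleNamespace(
    token_merging=True,           # UI toggle (or the --token-merging cmd flag)
    token_merging_hr_only=False,  # False: patch the first pass as well
    token_merging_ratio=0.3,      # ratio used for the first pass
    token_merging_ratio_hr=0.5,   # ratio used for the high-res pass
)

def ratio_for_pass(hr: bool) -> float:
    # Mirrors the ratio choice inside the new sd_models.apply_token_merging helper.
    return opts.token_merging_ratio_hr if hr else opts.token_merging_ratio

if opts.token_merging and not opts.token_merging_hr_only:
    print(f"first pass merge ratio: {ratio_for_pass(hr=False)}")

# The high-res pass (re)patches when hr_only is set or the two ratios differ.
if opts.token_merging and (opts.token_merging_hr_only
                           or opts.token_merging_ratio_hr != opts.token_merging_ratio):
    print(f"high-res pass merge ratio: {ratio_for_pass(hr=True)}")
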
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 4314f97b..8e5a7fab 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -103,5 +103,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
# token merging / tomesd
-parser.add_argument("--token-merging", action='store_true', help="Provides generation speedup by merging redundant tokens. (compatible with --xformers)", default=False)
-parser.add_argument("--token-merging-ratio", type=float, help="Adjusts ratio of merged to untouched tokens. Range: (0.0-1.0], Defaults to 0.5", default=0.5)
+parser.add_argument("--token-merging", action='store_true', help="Provides speed and memory improvements by merging redundant tokens. This has a more pronounced effect on higher resolutions.", default=False)
diff --git a/modules/processing.py b/modules/processing.py
index 55735572..670a7a28 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -501,26 +501,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
- if opts.token_merging and not opts.token_merging_hr_only:
- print("applying token merging to all passes")
- tomesd.apply_patch(
- p.sd_model,
- ratio=opts.token_merging_ratio,
- max_downsample=opts.token_merging_maximum_down_sampling,
- sx=opts.token_merging_stride_x,
- sy=opts.token_merging_stride_y,
- use_rand=opts.token_merging_random,
- merge_attn=opts.token_merging_merge_attention,
- merge_crossattn=opts.token_merging_merge_cross_attention,
- merge_mlp=opts.token_merging_merge_mlp
- )
+ if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
+ print("\nApplying token merging\n")
+ sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
res = process_images_inner(p)
finally:
# undo model optimizations made by tomesd
- if opts.token_merging:
- print('removing token merging model optimizations')
+ if opts.token_merging or cmd_opts.token_merging:
+ print('\nRemoving token merging model optimizations\n')
tomesd.remove_patch(p.sd_model)
# restore opts to original state
@@ -959,20 +949,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
devices.torch_gc()
# apply token merging optimizations from tomesd for high-res pass
- # check if hr_only so we don't redundantly apply patch
- if opts.token_merging and opts.token_merging_hr_only:
- print("applying token merging for high-res pass")
- tomesd.apply_patch(
- self.sd_model,
- ratio=opts.token_merging_ratio,
- max_downsample=opts.token_merging_maximum_down_sampling,
- sx=opts.token_merging_stride_x,
- sy=opts.token_merging_stride_y,
- use_rand=opts.token_merging_random,
- merge_attn=opts.token_merging_merge_attention,
- merge_crossattn=opts.token_merging_merge_cross_attention,
- merge_mlp=opts.token_merging_merge_mlp
- )
+ # check if hr_only so we are not redundantly patching
+ if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
+ # case where user wants to use separate merge ratios
+ if not opts.token_merging_hr_only:
+ # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
+ print('Temporarily reverting token merging optimizations in preparation for next pass')
+ tomesd.remove_patch(self.sd_model)
+
+ print("\nApplying token merging for high-res pass\n")
+ sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 87c49b83..696a2333 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -16,6 +16,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -545,4 +546,30 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
- return sd_model
\ No newline at end of file
+ return sd_model
+
+
+def apply_token_merging(sd_model, hr: bool):
+ """
+ Applies speed and memory optimizations from tomesd.
+
+ Args:
+ hr (bool): True if called in the context of a high-res pass
+ """
+
+ ratio = shared.opts.token_merging_ratio
+ if hr:
+ ratio = shared.opts.token_merging_ratio_hr
+ print("effective hr pass merge ratio is "+str(ratio))
+
+ tomesd.apply_patch(
+ sd_model,
+ ratio=ratio,
+ max_downsample=shared.opts.token_merging_maximum_down_sampling,
+ sx=shared.opts.token_merging_stride_x,
+ sy=shared.opts.token_merging_stride_y,
+ use_rand=shared.opts.token_merging_random,
+ merge_attn=shared.opts.token_merging_merge_attention,
+ merge_crossattn=shared.opts.token_merging_merge_cross_attention,
+ merge_mlp=shared.opts.token_merging_merge_mlp
+ )
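
For context, a hedged sketch (not part of the patch) of how the new helper is paired with tomesd.remove_patch by the callers in processing.py above; the wrapper function and its parameters are hypothetical, introduced only to show the patch/unpatch lifecycle:

from modules import sd_models, shared
import tomesd

def sample_with_token_merging(p, sample_fn):
    # Hypothetical wrapper mirroring the gating used in process_images above.
    try:
        if (shared.opts.token_merging or shared.cmd_opts.token_merging) and not shared.opts.token_merging_hr_only:
            sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
        return sample_fn(p)
    finally:
        if shared.opts.token_merging or shared.cmd_opts.token_merging:
            # Undo the model modifications made by tomesd so later runs start clean.
            tomesd.remove_patch(p.sd_model)
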
diff --git a/modules/shared.py b/modules/shared.py
index d7379e24..c7572e98 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -429,7 +429,7 @@ options_templates.update(options_section((None, "Hidden options"), {
options_templates.update(options_section(('token_merging', 'Token Merging'), {
"token_merging": OptionInfo(
- False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)",
+ 0.5, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
gr.Checkbox
),
"token_merging_ratio": OptionInfo(
@@ -440,6 +440,10 @@ options_templates.update(options_section(('token_merging', 'Token Merging'), {
True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
gr.Checkbox
),
+ "token_merging_ratio_hr": OptionInfo(
+ 0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
+ gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
+ ),
# More advanced/niche settings:
"token_merging_random": OptionInfo(
True, "Use random perturbations - Disabling might help with certain samplers",