aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/cmd_args.py1
-rw-r--r--modules/generation_parameters_copypaste.py40
-rw-r--r--modules/processing.py37
-rw-r--r--modules/sd_models.py26
-rw-r--r--modules/shared.py45
-rw-r--r--requirements_versions.txt1
6 files changed, 149 insertions(+), 1 deletion(-)
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index f4a4ab36..46043e33 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -102,4 +102,5 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
+parser.add_argument("--token-merging", action='store_true', help="Provides speed and memory improvements by merging redundant tokens. This has a more pronounced effect on higher resolutions.", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index b0e945a1..fb56254f 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -282,6 +282,33 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
res["Hires resize-1"] = 0
res["Hires resize-2"] = 0
+ # Infer additional override settings for token merging
+ token_merging_ratio = res.get("Token merging ratio", None)
+ token_merging_ratio_hr = res.get("Token merging ratio hr", None)
+
+ if token_merging_ratio is not None or token_merging_ratio_hr is not None:
+ res["Token merging"] = 'True'
+
+ if token_merging_ratio is None:
+ res["Token merging hr only"] = 'True'
+ else:
+ res["Token merging hr only"] = 'False'
+
+ if res.get("Token merging random", None) is None:
+ res["Token merging random"] = 'False'
+ if res.get("Token merging merge attention", None) is None:
+ res["Token merging merge attention"] = 'True'
+ if res.get("Token merging merge cross attention", None) is None:
+ res["Token merging merge cross attention"] = 'False'
+ if res.get("Token merging merge mlp", None) is None:
+ res["Token merging merge mlp"] = 'False'
+ if res.get("Token merging stride x", None) is None:
+ res["Token merging stride x"] = '2'
+ if res.get("Token merging stride y", None) is None:
+ res["Token merging stride y"] = '2'
+ if res.get("Token merging maximum down sampling", None) is None:
+ res["Token merging maximum down sampling"] = '1'
+
restore_old_hires_fix_params(res)
# Missing RNG means the default was set, which is GPU RNG
@@ -308,8 +335,19 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ('Token merging', 'token_merging'),
+ ('Token merging ratio', 'token_merging_ratio'),
+ ('Token merging hr only', 'token_merging_hr_only'),
+ ('Token merging ratio hr', 'token_merging_ratio_hr'),
+ ('Token merging random', 'token_merging_random'),
+ ('Token merging merge attention', 'token_merging_merge_attention'),
+ ('Token merging merge cross attention', 'token_merging_merge_cross_attention'),
+ ('Token merging merge mlp', 'token_merging_merge_mlp'),
+ ('Token merging maximum down sampling', 'token_merging_maximum_down_sampling'),
+ ('Token merging stride x', 'token_merging_stride_x'),
+ ('Token merging stride y', 'token_merging_stride_y'),
('RNG', 'randn_source'),
- ('NGMS', 's_min_uncond'),
+ ('NGMS', 's_min_uncond')
]
diff --git a/modules/processing.py b/modules/processing.py
index f902b9df..8ba3a96b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -29,6 +29,13 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
+import tomesd
+
+# add a logger for the processing module
+logger = logging.getLogger(__name__)
+# manually set output level here since there is no option to do so yet through launch options
+# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')
+
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -489,6 +496,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+ "Token merging ratio": None if not (opts.token_merging or cmd_opts.token_merging) or opts.token_merging_hr_only else opts.token_merging_ratio,
+ "Token merging ratio hr": None if not (opts.token_merging or cmd_opts.token_merging) else opts.token_merging_ratio_hr,
+ "Token merging random": None if opts.token_merging_random is False else opts.token_merging_random,
+ "Token merging merge attention": None if opts.token_merging_merge_attention is True else opts.token_merging_merge_attention,
+ "Token merging merge cross attention": None if opts.token_merging_merge_cross_attention is False else opts.token_merging_merge_cross_attention,
+ "Token merging merge mlp": None if opts.token_merging_merge_mlp is False else opts.token_merging_merge_mlp,
+ "Token merging stride x": None if opts.token_merging_stride_x == 2 else opts.token_merging_stride_x,
+ "Token merging stride y": None if opts.token_merging_stride_y == 2 else opts.token_merging_stride_y,
+ "Token merging maximum down sampling": None if opts.token_merging_maximum_down_sampling == 1 else opts.token_merging_maximum_down_sampling,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
@@ -522,9 +538,18 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
+ if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
+ sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
+ logger.debug('Token merging applied')
+
res = process_images_inner(p)
finally:
+ # undo model optimizations made by tomesd
+ if opts.token_merging or cmd_opts.token_merging:
+ tomesd.remove_patch(p.sd_model)
+ logger.debug('Token merging model optimizations removed')
+
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@@ -977,6 +1002,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
+ # apply token merging optimizations from tomesd for high-res pass
+ # check if hr_only so we are not redundantly patching
+ if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
+ # case where user wants to use separate merge ratios
+ if not opts.token_merging_hr_only:
+ # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
+ tomesd.remove_patch(self.sd_model)
+ logger.debug('Temporarily removed token merging optimizations in preparation for next pass')
+
+ sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
+ logger.debug('Applied token merging for high-res pass')
+
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
self.is_hr_pass = False
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3316d021..4787193c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -578,3 +579,28 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, hr: bool):
+ """
+ Applies speed and memory optimizations from tomesd.
+
+ Args:
+ hr (bool): True if called in the context of a high-res pass
+ """
+
+ ratio = shared.opts.token_merging_ratio
+ if hr:
+ ratio = shared.opts.token_merging_ratio_hr
+
+ tomesd.apply_patch(
+ sd_model,
+ ratio=ratio,
+ max_downsample=shared.opts.token_merging_maximum_down_sampling,
+ sx=shared.opts.token_merging_stride_x,
+ sy=shared.opts.token_merging_stride_y,
+ use_rand=shared.opts.token_merging_random,
+ merge_attn=shared.opts.token_merging_merge_attention,
+ merge_crossattn=shared.opts.token_merging_merge_cross_attention,
+ merge_mlp=shared.opts.token_merging_merge_mlp
+ )
diff --git a/modules/shared.py b/modules/shared.py
index 1df1dd45..4b346585 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -458,6 +458,51 @@ options_templates.update(options_section((None, "Hidden options"), {
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
+options_templates.update(options_section(('token_merging', 'Token Merging'), {
+ "token_merging": OptionInfo(
+ False, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
+ gr.Checkbox
+ ),
+ "token_merging_ratio": OptionInfo(
+ 0.5, "Merging Ratio",
+ gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
+ ),
+ "token_merging_hr_only": OptionInfo(
+ True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
+ gr.Checkbox
+ ),
+ "token_merging_ratio_hr": OptionInfo(
+ 0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
+ gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
+ ),
+ # More advanced/niche settings:
+ "token_merging_random": OptionInfo(
+ False, "Use random perturbations - Can improve outputs for certain samplers. For others, it may cause visual artifacting.",
+ gr.Checkbox
+ ),
+ "token_merging_merge_attention": OptionInfo(
+ True, "Merge attention",
+ gr.Checkbox
+ ),
+ "token_merging_merge_cross_attention": OptionInfo(
+ False, "Merge cross attention",
+ gr.Checkbox
+ ),
+ "token_merging_merge_mlp": OptionInfo(
+ False, "Merge mlp",
+ gr.Checkbox
+ ),
+ "token_merging_maximum_down_sampling": OptionInfo(1, "Maximum down sampling", gr.Radio, lambda: {"choices": [1, 2, 4, 8]}),
+ "token_merging_stride_x": OptionInfo(
+ 2, "Stride - X",
+ gr.Slider, {"minimum": 2, "maximum": 8, "step": 2}
+ ),
+ "token_merging_stride_y": OptionInfo(
+ 2, "Stride - Y",
+ gr.Slider, {"minimum": 2, "maximum": 8, "step": 2}
+ )
+}))
+
options_templates.update()
diff --git a/requirements_versions.txt b/requirements_versions.txt
index df8c6861..225e2319 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -26,3 +26,4 @@ torchsde==0.2.5
safetensors==0.3.1
httpcore<=0.15
fastapi==0.94.0
+tomesd>=0.1.2
\ No newline at end of file