aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py27
1 files changed, 24 insertions, 3 deletions
diff --git a/modules/processing.py b/modules/processing.py
index b0e240a4..e2309534 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -799,6 +799,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
+ unet_object = p.sd_model.model
+ vae_model = p.sd_model.first_stage_model
+ try:
+ from hyper_tile import split_attention, flush
+ except (ImportError, ModuleNotFoundError): # pip install git+https://github.com/tfernd/HyperTile@2ef64b2800d007d305755c33550537410310d7df
+ split_attention = lambda *args, **kwargs: lambda x: x # return a no-op context manager
+ flush = lambda: None
+ import random
+ saved_rng_state = random.getstate()
+ random.seed(p.seed) # hyper_tile uses random, so we need to seed it
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
@@ -866,15 +876,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
- samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+ # get largest tile size available, which is 2^x which is factor of gcd of p.width and p.height
+ gcd = math.gcd(p.width, p.height)
+ largest_tile_size_available = 1
+ while gcd % (largest_tile_size_available * 2) == 0:
+ largest_tile_size_available *= 2
+ aspect_ratio = p.width / p.height
+ with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn):
+ with split_attention(unet_object, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn):
+ flush()
+ samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
-
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn):
+ flush()
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -980,6 +1000,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ random.setstate(saved_rng_state)
if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, p.extra_network_data)