author     AngelBottomless <aria1th@naver.com>  2023-11-16 18:43:16 +0900
committer  AngelBottomless <aria1th@naver.com>  2023-11-16 18:43:16 +0900
commit     bcfaf3979a9f93e37c418b58c75b02d9570b4354 (patch)
tree       07e5ff505eda5b94c63f75438460a50ae9d954fb /modules/processing.py
parent     af45872fdb8a66ffd6a405d99120e0bacbb4a170 (diff)
convert/add hypertile options
Diffstat (limited to 'modules/processing.py')
-rw-r--r--  modules/processing.py  21
1 file changed, 11 insertions(+), 10 deletions(-)
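
The patch collapses the per-call-site split_attention(...) invocations, each carrying its own tile-size cap, swap_size, and disable flag, into two context-manager entry points, hypertile_context_vae and hypertile_context_unet, which take only the model, the aspect ratio, the raw tile size, and the live opts object. The wrappers themselves live in modules/hypertile.py, which is outside this diff; the sketch below is only a plausible reconstruction from the parameters visible at the old call sites, and any default or option name beyond those is an assumption.

    # Hypothetical sketch -- the real helpers are defined in modules/hypertile.py,
    # which this diff does not include. Caps and swap sizes mirror the values
    # hard-coded at the old call sites; the option names are the ones the
    # removed lines read from shared.opts.
    from contextlib import contextmanager

    from modules.hypertile import split_attention


    @contextmanager
    def hypertile_context_vae(model, aspect_ratio, tile_size, opts):
        # VAE attention splitting; the old call sites capped the tile size
        # at 128 and used swap_size=1.
        with split_attention(
            model,
            aspect_ratio=aspect_ratio,
            tile_size=min(tile_size, 128),
            swap_size=1,
            disable=not opts.hypertile_split_vae_attn,
            is_sdxl=False,  # the new VAE call sites no longer pass the flag (assumption)
        ):
            yield


    @contextmanager
    def hypertile_context_unet(model, aspect_ratio, tile_size, is_sdxl, opts):
        # U-Net attention splitting; the old call sites capped the tile size
        # at 256 and used swap_size=2.
        with split_attention(
            model,
            aspect_ratio=aspect_ratio,
            tile_size=min(tile_size, 256),
            swap_size=2,
            disable=not opts.hypertile_split_unet_attn,
            is_sdxl=is_sdxl,
        ):
            yield

Centralizing the constants this way means the hunks below only thread the model, geometry, and opts through, and the tiling details can be exposed as user-facing options, which is what the commit subject's "add hypertile options" suggests.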
diff --git a/modules/processing.py b/modules/processing.py
index e19a09a3..c622ff33 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -24,7 +24,7 @@ from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 import modules.paths as paths
 import modules.face_restoration
-from modules.hypertile import split_attention, set_hypertile_seed, largest_tile_size_available
+from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae
 import modules.images as images
 import modules.styles
 import modules.sd_models as sd_models
@@ -874,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             else:
                 if opts.sd_vae_decode_method != 'Full':
                     p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
-                with split_attention(p.sd_model.first_stage_model, aspect_ratio = p.width / p.height, tile_size=min(largest_tile_size_available(p.width, p.height), 128), disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
+                with hypertile_context_unet(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
                     x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -1144,8 +1144,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         aspect_ratio = self.width / self.height
         x = self.rng.next()
         tile_size = largest_tile_size_available(self.width, self.height)
-        with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
-            with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
+        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+            with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
                 devices.torch_gc()
                 samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
         del x
@@ -1153,7 +1153,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             return samples
 
         if self.latent_scale_mode is None:
-            with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
+            with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
                 decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
         else:
             decoded_samples = None
@@ -1245,15 +1245,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if self.scripts is not None:
             self.scripts.before_hr(self)
         tile_size = largest_tile_size_available(target_width, target_height)
-        with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
-            with split_attention(self.sd_model.model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=3, max_depth=1,scale_depth=True, disable=not opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
+        aspect_ratio = self.width / self.height
+        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+            with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
                 samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
 
         self.sampler = None
         devices.torch_gc()
-        with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
+        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
             decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
 
         self.is_hr_pass = False
@@ -1533,8 +1534,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             x *= self.initial_noise_multiplier
 
         aspect_ratio = self.width / self.height
         tile_size = largest_tile_size_available(self.width, self.height)
-        with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
-            with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
+        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+            with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
                 devices.torch_gc()
                 samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
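
Every call site above derives its tile size from largest_tile_size_available(width, height) before entering the contexts. That function is also defined in modules/hypertile.py and is not part of this diff; a minimal sketch consistent with how it is used here, assuming it returns the largest power-of-two edge that evenly divides both image dimensions, would be:

    import math

    def largest_tile_size_available(width: int, height: int) -> int:
        # Assumed behaviour: largest power-of-two tile edge dividing both
        # dimensions, so tiles always pack the image exactly.
        gcd = math.gcd(width, height)
        tile_size = 1
        while gcd % (tile_size * 2) == 0:
            tile_size *= 2
        return tile_size

    # Example: a 1024x768 render gives gcd 256, so tile_size is 256; per the
    # old call sites, the U-Net context would then cap it at 256 and the VAE
    # context at 128.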