about | summary | refs | log | tree | commit | diff
path: root/modules/processing.py
diff options
context:
space:
mode:
author: AUTOMATIC1111 <16777216c@gmail.com>, 2023-11-26 11:18:25 +0300
committer: GitHub <noreply@github.com>, 2023-11-26 11:18:25 +0300
commit: fd8674a4bcdde86854e019ec74aa3112191dfd26 (patch)
tree: b5cc91eb6658173159c08c952519e071cd14b3b6 /modules/processing.py
parent: 8aa51f682c17d85f4585b9471860224568d25e95 (diff)
parent: 97431f29feb17ffc96ca95e9b3efec87be9d8b3a (diff)
Merge pull request #13948 from aria1th/hypertile-in-sample
support HyperTile optimization
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- modules/processing.py 44
1 file changed, 25 insertions, 19 deletions
diff --git a/modules/processing.py b/modules/processing.py
index b0e240a4..36c2be5e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -24,6 +24,7 @@ from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
import modules.face_restoration
+from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
@@ -799,7 +800,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
-
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -861,7 +861,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.comment(comment)
p.extra_generation_params.update(model_hijack.extra_generation_params)
-
+ set_hypertile_seed(p.seed)
+ # add batch size + hypertile status to information to reproduce the run
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
@@ -873,8 +874,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
-
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts):
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -1140,24 +1141,25 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
-
+ aspect_ratio = self.width / self.height
x = self.rng.next()
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+ tile_size = largest_tile_size_available(self.width, self.height)
+ with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+ with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x
-
if not self.enable_hr:
return samples
+ devices.torch_gc()
if self.latent_scale_mode is None:
- decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+ with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+ decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
else:
decoded_samples = None
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
-
- devices.torch_gc()
-
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@@ -1165,7 +1167,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
self.is_hr_pass = True
-
target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y
@@ -1243,18 +1244,20 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.scripts is not None:
self.scripts.before_hr(self)
-
- samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ tile_size = largest_tile_size_available(target_width, target_height)
+ aspect_ratio = self.width / self.height
+ with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+ with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
+ samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
self.sampler = None
devices.torch_gc()
-
- decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+ with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+ decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False
-
return decoded_samples
def close(self):
@@ -1529,8 +1532,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.initial_noise_multiplier != 1.0:
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
x *= self.initial_noise_multiplier
-
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
+ aspect_ratio = self.width / self.height
+ tile_size = largest_tile_size_available(self.width, self.height)
+ with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
+ with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask