Diffstat (limited to 'modules/processing.py')
 modules/processing.py | 46 +++++++++++++++++++++++++++++++++++++---------
 1 file changed, 37 insertions(+), 9 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index 213a2879..045c7d79 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -179,6 +179,7 @@ class StableDiffusionProcessing:
     token_merging_ratio = 0
     token_merging_ratio_hr = 0
     disable_extra_networks: bool = False
+    firstpass_image: Image = None
 
     scripts_value: scripts.ScriptRunner = field(default=None, init=False)
     script_args_value: list = field(default=None, init=False)
@@ -1238,18 +1239,45 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
-        x = self.rng.next()
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
-        del x
+        if self.firstpass_image is not None and self.enable_hr:
+            # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix
 
-        if not self.enable_hr:
-            return samples
-        devices.torch_gc()
+            if self.latent_scale_mode is None:
+                image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
+                image = np.moveaxis(image, 2, 0)
+
+                samples = None
+                decoded_samples = torch.asarray(np.expand_dims(image, 0))
+
+            else:
+                image = np.array(self.firstpass_image).astype(np.float32) / 255.0
+                image = np.moveaxis(image, 2, 0)
+                image = torch.from_numpy(np.expand_dims(image, axis=0))
+                image = image.to(shared.device, dtype=devices.dtype_vae)
+
+                if opts.sd_vae_encode_method != 'Full':
+                    self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+                samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+                decoded_samples = None
+                devices.torch_gc()
 
-        if self.latent_scale_mode is None:
-            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
         else:
-            decoded_samples = None
+            # here we generate an image normally
+
+            x = self.rng.next()
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+            del x
+
+            if not self.enable_hr:
+                return samples
+
+            devices.torch_gc()
+
+            if self.latent_scale_mode is None:
+                decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+            else:
+                decoded_samples = None
 
         with sd_models.SkipWritingToConfig():
             sd_models.reload_model_weights(info=self.hr_checkpoint_info)
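
Note: the hunk above makes StableDiffusionProcessingTxt2Img.sample() skip the first sampling pass whenever firstpass_image is set and hires fix is enabled; the supplied image is either passed through directly (no latent upscale mode) or VAE-encoded into latents before the second pass. Below is a rough usage sketch, not part of the commit: constructor arguments other than enable_hr and the new firstpass_image field belong to the existing txt2img processing class and may differ between webui versions.

from PIL import Image
from modules import processing, shared

# Image produced earlier, e.g. by a plain txt2img run or loaded from disk.
firstpass = Image.open("firstpass.png")

p = processing.StableDiffusionProcessingTxt2Img(
    sd_model=shared.sd_model,      # on some versions sd_model is taken from shared and this argument is ignored
    prompt="a photo of a cat",
    enable_hr=True,                # firstpass_image is only consulted when hires fix is enabled
    denoising_strength=0.5,
    hr_scale=2.0,
    hr_upscaler="Latent",
)
p.firstpass_image = firstpass      # new field: reuse this image instead of sampling a first pass

result = processing.process_images(p)
result.images[0].save("hires.png")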