aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
authorw-e-w <40751091+w-e-w@users.noreply.github.com>2023-08-08 11:39:34 +0900
committerGitHub <noreply@github.com>2023-08-08 11:39:34 +0900
commitf17c8c2eff63210f5e96e1e2b049b46ba9cfa389 (patch)
tree701056aec9ae11bc45df9b39b176a54fa4d34e19 /modules/processing.py
parentc75bda867be5345bf959daf23bdc19eadc90841a (diff)
parent01997f45ba089af24b03a5f614147bb0f9d8d824 (diff)
Merge branch 'dev' into auro-autolaunch
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py188
1 files changed, 116 insertions, 72 deletions
diff --git a/modules/processing.py b/modules/processing.py
index b0992ee1..31745006 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -16,6 +16,7 @@ from typing import Any, Dict, List
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack
+from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
@@ -83,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+ image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
@@ -109,7 +110,7 @@ class StableDiffusionProcessing:
cached_uc = [None, None]
cached_c = [None, None]
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -147,8 +148,8 @@ class StableDiffusionProcessing:
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
- self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
- self.s_noise = s_noise or opts.s_noise
+ self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
+ self.s_noise = s_noise if s_noise is not None else opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
@@ -202,7 +203,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
@@ -215,7 +216,7 @@ class StableDiffusionProcessing:
return conditioning
def edit_image_conditioning(self, source_image):
- conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
+ conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
return conditioning_image
@@ -294,7 +295,7 @@ class StableDiffusionProcessing:
self.sampler = None
self.c = None
self.uc = None
- if not opts.experimental_persistent_cond_cache:
+ if not opts.persistent_cond_cache:
StableDiffusionProcessing.cached_c = [None, None]
StableDiffusionProcessing.cached_uc = [None, None]
@@ -318,6 +319,21 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+ def cached_params(self, required_prompts, steps, extra_network_data):
+ """Returns parameters that invalidate the cond cache if changed"""
+
+ return (
+ required_prompts,
+ steps,
+ opts.CLIP_stop_at_last_layers,
+ shared.sd_model.sd_checkpoint_info,
+ extra_network_data,
+ opts.sdxl_crop_left,
+ opts.sdxl_crop_top,
+ self.width,
+ self.height,
+ )
+
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -331,17 +347,7 @@ class StableDiffusionProcessing:
caches is a list with items described above.
"""
- cached_params = (
- required_prompts,
- steps,
- opts.CLIP_stop_at_last_layers,
- shared.sd_model.sd_checkpoint_info,
- extra_network_data,
- opts.sdxl_crop_left,
- opts.sdxl_crop_top,
- self.width,
- self.height,
- )
+ cached_params = self.cached_params(required_prompts, steps, extra_network_data)
for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
@@ -367,6 +373,10 @@ class StableDiffusionProcessing:
def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
+ def save_samples(self) -> bool:
+ """Returns whether generated images need to be written to disk"""
+ return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -492,7 +502,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
subnoise = None
- if subseeds is not None:
+ if subseeds is not None and subseed_strength != 0:
subseed = 0 if i >= len(subseeds) else subseeds[i]
subnoise = devices.randn(subseed, noise_shape)
@@ -524,7 +534,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
cnt = p.sampler.number_of_needed_noises(p)
if eta_noise_seed_delta > 0:
- torch.manual_seed(seed + eta_noise_seed_delta)
+ devices.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -538,8 +548,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+class DecodedSamples(list):
+ already_decoded = True
+
+
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
- samples = []
+ samples = DecodedSamples()
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -572,12 +586,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
return samples
-def decode_first_stage(model, x):
- x = model.decode_first_stage(x.to(devices.dtype_vae))
-
- return x
-
-
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
@@ -636,7 +644,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
- "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+ "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
@@ -793,7 +801,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ if getattr(samples_ddim, 'already_decoded', False):
+ x_samples_ddim = samples_ddim
+ else:
+ if opts.sd_vae_decode_method != 'Full':
+ p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -817,6 +832,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
def infotext(index=0, use_main_prompt=False):
return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+ save_samples = p.save_samples()
+
for i, x_sample in enumerate(x_samples_ddim):
p.batch_index = i
@@ -824,7 +841,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
x_sample = x_sample.astype(np.uint8)
if p.restore_faces:
- if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
+ if save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
@@ -838,16 +855,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image
-
if p.color_corrections is not None and i < len(p.color_corrections):
- if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
+ if save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
- if opts.samples_save and not p.do_not_save_samples:
+ if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
text = infotext(i)
@@ -855,8 +871,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
-
- if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+ if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
image_mask = p.mask_for_overlay.convert('RGB')
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
@@ -892,7 +907,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid.info["parameters"] = text
output_images.insert(0, grid)
index_of_first_image = 1
-
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
@@ -935,7 +949,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -946,11 +960,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
+ self.hr_checkpoint_name = hr_checkpoint_name
+ self.hr_checkpoint_info = None
self.hr_sampler_name = hr_sampler_name
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
+ self.latent_scale_mode = None
if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -973,6 +990,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if self.hr_checkpoint_name:
+ self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+ if self.hr_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+ self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
@@ -982,6 +1007,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
+ self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and self.latent_scale_mode is None:
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
self.hr_resize_y = self.height
@@ -1020,14 +1050,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
- # special case: the user has chosen to do nothing
- if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
- self.enable_hr = False
- self.denoising_strength = None
- self.extra_generation_params.pop("Hires upscale", None)
- self.extra_generation_params.pop("Hires resize", None)
- return
-
if not state.processing_has_refined_job_count:
if state.job_count == -1:
state.job_count = self.n_iter
@@ -1045,17 +1067,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
- if self.enable_hr and latent_scale_mode is None:
- if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
- raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+ del x
if not self.enable_hr:
return samples
+ if self.latent_scale_mode is None:
+ decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+ else:
+ decoded_samples = None
+
+ current = shared.sd_model.sd_checkpoint_info
+ try:
+ if self.hr_checkpoint_info is not None:
+ self.sampler = None
+ sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+ devices.torch_gc()
+
+ return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+ finally:
+ self.sampler = None
+ sd_models.reload_model_weights(info=current)
+ devices.torch_gc()
+
+ def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
self.is_hr_pass = True
target_width = self.hr_upscale_to_x
@@ -1064,7 +1101,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
- if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+ if not self.save_samples() or not opts.save_images_before_highres_fix:
return
if not isinstance(image, Image.Image):
@@ -1073,11 +1110,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
- if latent_scale_mode is not None:
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+ if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+ img2img_sampler_name = 'DDIM'
+
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+ if self.latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
@@ -1086,7 +1130,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
image_conditioning = self.txt2img_image_conditioning(samples)
else:
- decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
batch_images = []
@@ -1103,28 +1146,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
batch_images.append(image)
decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
+ decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ if opts.sd_vae_encode_method != 'Full':
+ self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+ samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
shared.state.nextjob()
- img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
- if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
- img2img_sampler_name = 'DDIM'
-
- self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
- x = None
devices.torch_gc()
if not self.disable_extra_networks:
@@ -1143,15 +1179,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+ decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
self.is_hr_pass = False
- return samples
+ return decoded_samples
def close(self):
super().close()
self.hr_c = None
self.hr_uc = None
- if not opts.experimental_persistent_cond_cache:
+ if not opts.persistent_cond_cache:
StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
@@ -1184,8 +1222,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_c is not None:
return
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+ hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+ hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
def setup_conds(self):
super().setup_conds()
@@ -1193,7 +1234,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_uc = None
self.hr_c = None
- if self.enable_hr:
+ if self.enable_hr and self.hr_checkpoint_info is None:
if shared.opts.hires_fix_use_firstpass_conds:
self.calculate_hr_conds()
@@ -1344,10 +1385,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
image = torch.from_numpy(batch_images)
- image = 2. * image - 1.
image = image.to(shared.device, dtype=devices.dtype_vae)
- self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ if opts.sd_vae_encode_method != 'Full':
+ self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+ self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+ devices.torch_gc()
if self.resize_mode == 3:
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")