aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py189
1 files changed, 125 insertions, 64 deletions
diff --git a/modules/processing.py b/modules/processing.py
index c61bbfbd..2168208c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -77,9 +77,8 @@ def get_correct_sampler(p):
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
-
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
@@ -109,13 +108,14 @@ class StableDiffusionProcessing():
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
- self.denoising_strength: float = 0
+ self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
- self.ddim_discretize = opts.ddim_discretize
+ self.ddim_discretize = ddim_discretize or opts.ddim_discretize
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
+ self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
if not seed_enable_extras:
self.subseed = -1
@@ -129,13 +129,75 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
+ def txt2img_image_conditioning(self, x, width=None, height=None):
+ if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1)
+
+ height = height or self.height
+ width = width or self.width
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
+ if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+
+ # Handle the different mask inputs
+ if image_mask is not None:
+ if torch.is_tensor(image_mask):
+ conditioning_mask = image_mask
+ else:
+ conditioning_mask = np.array(image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
+ conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
+ conditioning_image = torch.lerp(
+ source_image,
+ source_image * (1.0 - conditioning_mask),
+ getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
+ )
+
+ # Encode the new masked image using first stage of network.
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
+
+ return image_conditioning
def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
+ def close(self):
+ self.sd_model = None
+ self.sampler = None
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
@@ -330,6 +392,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -351,6 +414,22 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
def process_images(p: StableDiffusionProcessing) -> Processed:
+ stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
+
+ try:
+ for k, v in p.override_settings.items():
+ opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
+
+ res = process_images_inner(p)
+
+ finally:
+ for k, v in stored_opts.items():
+ opts.data[k] = v
+
+ return res
+
+
+def process_images_inner(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if type(p.prompt) == list:
@@ -396,7 +475,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
- p.scripts.run_alwayson_scripts(p)
+ p.scripts.process(p)
infotexts = []
output_images = []
@@ -419,7 +498,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if (len(prompts) == 0):
+ if len(prompts) == 0:
break
with devices.autocast():
@@ -434,7 +513,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
@@ -508,7 +587,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ if p.scripts is not None:
+ p.scripts.postprocess(p, res)
+
+ return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@@ -556,43 +641,41 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def create_dummy_mask(self, x, width=None, height=None):
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- else:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
-
- return image_conditioning
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
+ def save_intermediate(image, index):
+ if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+ return
+
+ if not isinstance(image, Image.Image):
+ image = sd_samplers.sample_to_image(image, index)
+
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+
if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
+ # Avoid making the inpainting conditioning unless necessary as
+ # this does need some extra compute to decode / encode the image again.
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+ else:
+ image_conditioning = self.txt2img_image_conditioning(samples)
+ for i in range(samples.shape[0]):
+ save_intermediate(samples, i)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -602,6 +685,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
+
+ save_intermediate(image, i)
+
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
@@ -613,6 +699,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
@@ -623,7 +711,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
@@ -755,36 +843,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- if self.image_mask is not None:
- conditioning_mask = np.array(self.image_mask.convert("L"))
- conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
- conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+ self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
- else:
- conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
-
- # Create another latent image, this time with a masked version of the original input.
- conditioning_mask = conditioning_mask.to(image.device)
- conditioning_image = image * (1.0 - conditioning_mask)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
-
- # Create the concatenated conditioning tensor to be fed to `c_concat`
- conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
- conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
- self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
- self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
- else:
- self.image_conditioning = torch.zeros(
- self.init_latent.shape[0], 5, 1, 1,
- dtype=self.init_latent.dtype,
- device=self.init_latent.device
- )
-
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
@@ -795,4 +856,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x
devices.torch_gc()
- return samples \ No newline at end of file
+ return samples