author    AUTOMATIC1111 <16777216c@gmail.com>  2022-10-15 10:47:26 +0300
committer GitHub <noreply@github.com>          2022-10-15 10:47:26 +0300
commit    f42e0aae6de6b9a7f8da4eaf13594a13502b4fa9 (patch)
tree      472025101577ff5cbd45a3bcb524e6e4accb75ec /modules/processing.py
parent    0e77ee24b0b651d6a564245243850e4fb9831e31 (diff)
parent    d13ce89e203d76ab2b54a3406a93a5e4304f529e (diff)
Merge branch 'master' into master
Diffstat (limited to 'modules/processing.py')
-rw-r--r--  modules/processing.py  233
1 file changed, 157 insertions(+), 76 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index 1da753a2..7e2a416d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,3 @@
-import contextlib
import json
import math
import os
@@ -12,9 +11,8 @@ import cv2
from skimage import exposure
import modules.sd_hijack
-from modules import devices, prompt_parser, masking
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram
from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
@@ -48,6 +46,12 @@ def apply_color_correction(correction, image):
return image
+def get_correct_sampler(p):
+ if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
+ return sd_samplers.samplers
+ elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
+ return sd_samplers.samplers_for_img2img
+
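The new helper gives shared code such as create_infotext a single way to pick the sampler list that matches the processing type. A hedged sketch of the intended use (the "Unknown" fallback is an assumption, not part of this diff):

def sampler_name_for(p):
    samplers = get_correct_sampler(p)      # txt2img or img2img sampler list
    if samplers is None:                   # unrecognized processing class
        return "Unknown"
    return samplers[p.sampler_index].name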
class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model
@@ -56,7 +60,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
- self.styles: str = styles
+ self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@@ -85,7 +89,7 @@ class StableDiffusionProcessing:
self.s_tmin = opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise
-
+
if not seed_enable_extras:
self.subseed = -1
self.subseed_strength = 0
@@ -111,7 +115,7 @@ class Processed:
self.width = p.width
self.height = p.height
self.sampler_index = p.sampler_index
- self.sampler = samplers[p.sampler_index].name
+ self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale
self.steps = p.steps
self.batch_size = p.batch_size
@@ -123,6 +127,9 @@ class Processed:
self.denoising_strength = getattr(p, 'denoising_strength', None)
self.extra_generation_params = p.extra_generation_params
self.index_of_first_image = index_of_first_image
+ self.styles = p.styles
+ self.job_timestamp = state.job_timestamp
+ self.clip_skip = opts.CLIP_stop_at_last_layers
self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
@@ -133,7 +140,7 @@ class Processed:
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
@@ -167,6 +174,9 @@ class Processed:
"extra_generation_params": self.extra_generation_params,
"index_of_first_image": self.index_of_first_image,
"infotexts": self.infotexts,
+ "styles": self.styles,
+ "job_timestamp": self.job_timestamp,
+ "clip_skip": self.clip_skip,
}
return json.dumps(obj)
@@ -197,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
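The pre-generated per-step noises are what let a batch with seeds [100, 101] reproduce two single-image runs exactly. A toy version of that bookkeeping, assuming the sampler requests one noise tensor per step:

import torch

def per_seed_noises(seeds, steps, shape=(4, 8, 8)):
    # draw each seed's noise from its own RNG stream so batched and
    # single runs see identical values
    noises = [[] for _ in range(steps)]
    for seed in seeds:
        torch.manual_seed(seed)
        for step in range(steps):
            noises[step].append(torch.randn(shape))
    return [torch.stack(n) for n in noises]   # steps x batch x C x H x W

batched = per_seed_noises([100, 101], steps=2)
single = per_seed_noises([101], steps=2)
assert torch.equal(batched[0][1], single[0][0])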
@@ -237,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
+ if opts.eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + opts.eta_noise_seed_delta)
+
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
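Eta Noise Seed Delta (ENSD) simply offsets the seed used for this extra sampler noise, so the same image seed with a different delta produces different ancestral-sampler noise. A standalone illustration, not the webui API:

import torch

def ensd_noise(seed, ensd, shape=(4, 64, 64)):
    torch.manual_seed(seed + ensd)   # the delta shifts the RNG seed only
    return torch.randn(shape)

torch.equal(ensd_noise(100, 0), ensd_noise(100, 0))       # True
torch.equal(ensd_noise(100, 0), ensd_noise(100, 31337))   # False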
@@ -249,29 +262,49 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def decode_first_stage(model, x):
+ with devices.autocast(disable=x.dtype == devices.dtype_vae):
+ x = model.decode_first_stage(x)
+
+ return x
+
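The wrapper disables autocast whenever the latent already matches the VAE dtype, so a half-precision VAE decodes directly in fp16 instead of being re-wrapped. A simplified stand-in for what devices.autocast is assumed to do here:

import contextlib
import torch

def autocast(disable=False):
    # assumption: CUDA autocast when enabled, a plain no-op context otherwise
    return contextlib.nullcontext() if disable else torch.autocast("cuda")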
+
+def get_fixed_seed(seed):
+ if seed is None or seed == '' or seed == -1:
+ return int(random.randrange(4294967294))
+
+ return seed
+
+
def fix_seed(p):
- p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed
- p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed
+ p.seed = get_fixed_seed(p.seed)
+ p.subseed = get_fixed_seed(p.subseed)
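fix_seed now delegates to get_fixed_seed, which process_images can also call without mutating p. Behaviour in brief (the random value shown is illustrative):

get_fixed_seed(-1)      # e.g. 2787744328 -- a fresh random 32-bit seed
get_fixed_seed('')      # empty string is treated the same way
get_fixed_seed(1234)    # explicit seeds pass through unchanged -> 1234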
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
+ clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+
generation_params = {
"Steps": p.steps,
- "Sampler": samplers[p.sampler_index].name,
+ "Sampler": get_correct_sampler(p)[p.sampler_index].name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
+ "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
- "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+ "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+ "Clip skip": None if clip_skip <= 1 else clip_skip,
+ "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
}
generation_params.update(p.extra_generation_params)
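Entries whose value is None (e.g. Clip skip at its default of 1) drop out of the final parameters line. A hedged sketch of the serialization, ignoring any quoting the real code may apply to values containing commas:

def format_params(params):
    return ", ".join(f"{k}: {v}" for k, v in params.items() if v is not None)

format_params({"Steps": 20, "Sampler": "Euler a", "Clip skip": None, "ENSD": 31337})
# -> 'Steps: 20, Sampler: Euler a, ENSD: 31337'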
@@ -290,15 +323,24 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
assert(len(p.prompt) > 0)
else:
assert p.prompt is not None
-
+
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
devices.torch_gc()
- fix_seed(p)
+ seed = get_fixed_seed(p.seed)
+ subseed = get_fixed_seed(p.subseed)
- os.makedirs(p.outpath_samples, exist_ok=True)
- os.makedirs(p.outpath_grids, exist_ok=True)
+ if p.outpath_samples is not None:
+ os.makedirs(p.outpath_samples, exist_ok=True)
+
+ if p.outpath_grids is not None:
+ os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
+ modules.sd_hijack.model_hijack.clear_comments()
comments = {}
@@ -309,33 +351,36 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else:
all_prompts = p.batch_size * p.n_iter * [p.prompt]
- if type(p.seed) == list:
- all_seeds = p.seed
+ if type(seed) == list:
+ all_seeds = seed
else:
- all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
- if type(p.subseed) == list:
- all_subseeds = p.subseed
+ if type(subseed) == list:
+ all_subseeds = subseed
else:
- all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))]
+ all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
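The expansion rule mirrored here: without variation strength every image in the run gets an incremented seed, while with it the base seed stays fixed and only the subseed walks. A minimal sketch of exactly that logic:

def expand_seeds(seed, subseed, subseed_strength, n):
    if subseed_strength == 0:
        return [seed + i for i in range(n)], [subseed + i for i in range(n)]
    return [seed] * n, [subseed + i for i in range(n)]

expand_seeds(100, 200, 0.0, 3)   # ([100, 101, 102], [200, 201, 202])
expand_seeds(100, 200, 0.5, 3)   # ([100, 100, 100], [200, 201, 202])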
def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
- model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []
- precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
- ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
- with torch.no_grad(), precision_scope("cuda"), ema_scope():
- p.init(all_prompts, all_seeds, all_subseeds)
+
+ with torch.no_grad(), p.sd_model.ema_scope():
+ with devices.autocast():
+ p.init(all_prompts, all_seeds, all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ if state.skipped:
+ state.skipped = False
+
if state.interrupted:
break
@@ -348,8 +393,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts)
- uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
- c = prompt_parser.get_learned_conditioning(prompts, p.steps)
+ with devices.autocast():
+ uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
+ c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -358,16 +404,26 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
- if state.interrupted:
+ with devices.autocast():
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+
+ if state.interrupted or state.skipped:
- # if we are interruped, sample returns just noise
+ # if we are interrupted, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
- x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+ samples_ddim = samples_ddim.to(devices.dtype_vae)
+ x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ del samples_ddim
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+
+ devices.torch_gc()
+
if opts.filter_nsfw:
import modules.safety as safety
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
@@ -383,6 +439,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample)
+ devices.torch_gc()
image = Image.fromarray(x_sample)
@@ -408,9 +465,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
- infotexts.append(infotext(n, i))
+ text = infotext(n, i)
+ infotexts.append(text)
+ if opts.enable_pnginfo:
+ image.info["parameters"] = text
output_images.append(image)
+ del x_samples_ddim
+
+ devices.torch_gc()
+
state.nextjob()
p.color_corrections = None
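With enable_pnginfo on, the parameters string saved above travels inside the PNG text chunk and can be read back with PIL (the filename is illustrative):

from PIL import Image

img = Image.open("00000-2787744328.png")   # any image saved by the webui
print(img.info.get("parameters"))          # the same infotext string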
@@ -421,7 +485,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
- infotexts.insert(0, infotext())
+ text = infotext()
+ infotexts.insert(0, text)
+ if opts.enable_pnginfo:
+ grid.info["parameters"] = text
output_images.insert(0, grid)
index_of_first_image = 1
@@ -434,16 +501,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- firstphase_width = 0
- firstphase_height = 0
- firstphase_width_truncated = 0
- firstphase_height_truncated = 0
- def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
+ def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
- self.scale_latent = scale_latent
self.denoising_strength = denoising_strength
+ self.firstphase_width = firstphase_width
+ self.firstphase_height = firstphase_height
+ self.truncate_x = 0
+ self.truncate_y = 0
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -452,17 +518,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ if self.firstphase_width == 0 or self.firstphase_height == 0:
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = self.width * self.height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ self.firstphase_width = math.ceil(scale * self.width / 64) * 64
+ self.firstphase_height = math.ceil(scale * self.height / 64) * 64
+ firstphase_width_truncated = int(scale * self.width)
+ firstphase_height_truncated = int(scale * self.height)
+
+ else:
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
+ width_ratio = self.width / self.firstphase_width
+ height_ratio = self.height / self.firstphase_height
+
+ if width_ratio > height_ratio:
+ firstphase_width_truncated = self.firstphase_width
+ firstphase_height_truncated = self.firstphase_width * self.height / self.width
+ else:
+ firstphase_width_truncated = self.firstphase_height * self.width / self.height
+ firstphase_height_truncated = self.firstphase_height
+
+ self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
+ self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- self.firstphase_width_truncated = int(scale * self.width)
- self.firstphase_height_truncated = int(scale * self.height)
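Worked example of the new truncation math, assuming opt_f == 8 (the latent downscale factor used throughout this file): a 1024x512 target with a 512x512 first pass gives width_ratio 2.0 > height_ratio 1.0, so the first-phase latent is center-cropped vertically.

width, height = 1024, 512                 # final size
fw, fh = 512, 512                         # first pass size
if width / fw > height / fh:
    fw_trunc, fh_trunc = fw, fw * height / width   # 512, 256.0
else:
    fw_trunc, fh_trunc = fh * width / height, fh
truncate_x = int(fw - fw_trunc) // 8      # 0
truncate_y = int(fh - fh_trunc) // 8      # 32 latent rows cropped, 16 per edge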
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -472,46 +555,41 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
- truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
- truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
- samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
+ decoded_samples = decode_first_stage(self.sd_model, samples)
- if self.scale_latent:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
else:
- decoded_samples = self.sd_model.decode_first_stage(samples)
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
- else:
- lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- batch_images = []
- for i, x_sample in enumerate(lowres_samples):
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- image = Image.fromarray(x_sample)
- image = images.resize_image(0, image, self.width, self.height)
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
- batch_images.append(image)
+ batch_images = []
+ for i, x_sample in enumerate(lowres_samples):
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ image = images.resize_image(0, image, self.width, self.height)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ batch_images.append(image)
- decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
+ decoded_samples = torch.from_numpy(np.array(batch_images))
+ decoded_samples = decoded_samples.to(shared.device)
+ decoded_samples = 2. * decoded_samples - 1.
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
shared.state.nextjob()
- self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
-
+
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
return samples
@@ -540,7 +618,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds):
- self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
crop_region = None
if self.image_mask is not None:
@@ -647,4 +725,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
+ del x
+ devices.torch_gc()
+
return samples
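Taken together, a hedged end-to-end sketch of driving the new highres txt2img path (model and options setup omitted; the keyword arguments are the ones introduced in this diff):

p = StableDiffusionProcessingTxt2Img(
    sd_model=shared.sd_model,
    prompt="a photo of a cat",
    steps=20, cfg_scale=7.0,
    width=1024, height=576,
    enable_hr=True, firstphase_width=512, firstphase_height=320,
)
res = process_images(p)              # returns a Processed
res.images[0].save("txt2img-hr.png")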