aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py40
-rw-r--r--modules/api/models.py6
-rw-r--r--modules/extensions.py10
-rw-r--r--modules/launch_utils.py8
-rw-r--r--modules/processing.py41
-rw-r--r--modules/prompt_parser.py2
-rw-r--r--modules/scripts.py34
-rw-r--r--modules/sd_hijack.py2
-rw-r--r--modules/sd_hijack_clip.py9
-rw-r--r--modules/ui_extra_networks.py2
10 files changed, 103 insertions, 51 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 2a4cd8a2..606db179 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -333,14 +333,16 @@ class Api:
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
- shared.state.begin(job="scripts_txt2img")
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ try:
+ shared.state.begin(job="scripts_txt2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ finally:
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -390,14 +392,16 @@ class Api:
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
- shared.state.begin(job="scripts_img2img")
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ try:
+ shared.state.begin(job="scripts_img2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ finally:
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -720,9 +724,9 @@ class Api:
cuda = {'error': f'{err}'}
return models.MemoryResponse(ram=ram, cuda=cuda)
- def launch(self, server_name, port):
+ def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
def kill_webui(self):
restart.stop_program()
diff --git a/modules/api/models.py b/modules/api/models.py
index bf97b1a3..800c9b93 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -208,11 +208,9 @@ class PreprocessResponse(BaseModel):
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
- optType = opts.typemap.get(type(metadata.default), type(metadata.default))
+ optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any
- if metadata.default is None:
- pass
- elif metadata is not None:
+ if metadata is not None:
fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
else:
fields.update({key: (Optional[optType], Field())})
diff --git a/modules/extensions.py b/modules/extensions.py
index c561159a..3ad5ed53 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -56,10 +56,12 @@ class Extension:
self.do_read_info_from_repo()
return self.to_dict()
-
- d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
- self.from_dict(d)
- self.status = 'unknown'
+ try:
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
+ self.from_dict(d)
+ except FileNotFoundError:
+ pass
+ self.status = 'unknown' if self.status == '' else self.status
def do_read_info_from_repo(self):
repo = None
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 03552bc2..e1c9cfbe 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -196,7 +196,7 @@ def run_extension_installer(extension_dir):
try:
env = os.environ.copy()
- env['PYTHONPATH'] = os.path.abspath(".")
+ env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
@@ -233,7 +233,7 @@ def run_extensions_installers(settings_file):
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
-def requrements_met(requirements_file):
+def requirements_met(requirements_file):
"""
Does a simple parse of a requirements.txt file to determine if all rerqirements in it
are already installed. Returns True if so, False if not installed or parsing fails.
@@ -293,7 +293,7 @@ def prepare_environment():
try:
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
os.remove(os.path.join(script_path, "tmp", "restart"))
- os.environ.setdefault('SD_WEBUI_RESTARTING ', '1')
+ os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
except OSError:
pass
@@ -354,7 +354,7 @@ def prepare_environment():
if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
- if not requrements_met(requirements_file):
+ if not requirements_met(requirements_file):
run_pip(f"install -r \"{requirements_file}\"", "requirements")
run_extensions_installers(settings_file=args.ui_settings_file)
diff --git a/modules/processing.py b/modules/processing.py
index a74a5302..b0992ee1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -600,8 +600,12 @@ def program_version():
return res
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
- index = position_in_batch + iteration * p.batch_size
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
+ if index is None:
+ index = position_in_batch + iteration * p.batch_size
+
+ if all_negative_prompts is None:
+ all_negative_prompts = p.all_negative_prompts
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
enable_hr = getattr(p, 'enable_hr', False)
@@ -617,12 +621,12 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Sampler": p.sampler_name,
"CFG scale": p.cfg_scale,
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
- "Seed": all_seeds[index],
+ "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
- "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
+ "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
@@ -642,7 +646,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
- negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
+ negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -716,9 +720,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
- def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
-
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -806,6 +807,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
+ batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+ p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
+ x_samples_ddim = batch_params.images
+
+ def infotext(index=0, use_main_prompt=False):
+ return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+
for i, x_sample in enumerate(x_samples_ddim):
p.batch_index = i
@@ -814,7 +825,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
@@ -831,15 +842,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
- images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
- text = infotext(n, i)
+ text = infotext(i)
infotexts.append(text)
if opts.enable_pnginfo:
image.info["parameters"] = text
@@ -850,10 +861,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
- images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.save_mask_composite:
- images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask:
output_images.append(image_mask)
@@ -894,7 +905,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p,
images_list=output_images,
seed=p.all_seeds[0],
- info=infotext(),
+ info=infotexts[0],
comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index b29d079d..203ae1ac 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -178,7 +178,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
re_AND = re.compile(r"\bAND\b")
-re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
+re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
diff --git a/modules/scripts.py b/modules/scripts.py
index 7d9dd59f..5b4edcac 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -16,6 +16,11 @@ class PostprocessImageArgs:
self.image = image
+class PostprocessBatchListArgs:
+ def __init__(self, images):
+ self.images = images
+
+
class Script:
name = None
"""script's internal name derived from title"""
@@ -119,7 +124,7 @@ class Script:
def after_extra_networks_activate(self, p, *args, **kwargs):
"""
- Calledafter extra networks activation, before conds calculation
+ Called after extra networks activation, before conds calculation
allow modification of the network after extra networks activation been applied
won't be call if p.disable_extra_networks
@@ -156,6 +161,25 @@ class Script:
pass
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
+ """
+ Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
+ This is useful when you want to update the entire batch instead of individual images.
+
+ You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
+ If the number of images is different from the batch size when returning,
+ then the script has the responsibility to also update the following attributes in the processing object (p):
+ - p.prompts
+ - p.negative_prompts
+ - p.seeds
+ - p.subseeds
+
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ """
+
+ pass
+
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
@@ -536,6 +560,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_batch_list(p, pp, *script_args, **kwargs)
+ except Exception:
+ errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f5615967..c8fdd4f1 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -243,7 +243,7 @@ class StableDiffusionModelHijack:
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 5443e609..16a5500e 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -270,12 +270,17 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens)
+ pooled = getattr(z, 'pooled', None)
+
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z *= (original_mean / new_mean)
+ z = z * (original_mean / new_mean)
+
+ if pooled is not None:
+ z.pooled = pooled
return z
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 49612298..f2752f10 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -253,7 +253,7 @@ class ExtraNetworksPage:
"prompt": item.get("prompt", None),
"tabname": quote_js(tabname),
"local_preview": quote_js(item["local_preview"]),
- "name": item["name"],
+ "name": html.escape(item["name"]),
"description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
"card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',