Diffstat (limited to 'modules')
-rw-r--r--  modules/api/api.py                               |   9
-rw-r--r--  modules/api/models.py                            |   3
-rw-r--r--  modules/errors.py                                |  26
-rw-r--r--  modules/extra_networks.py                        |   3
-rw-r--r--  modules/generation_parameters_copypaste.py       |   8
-rw-r--r--  modules/img2img.py                               |  32
-rw-r--r--  modules/launch_utils.py                          |  14
-rw-r--r--  modules/lowvram.py                               |   6
-rw-r--r--  modules/processing.py                            |  64
-rw-r--r--  modules/sd_hijack_optimizations.py               |   2
-rw-r--r--  modules/shared.py                                |  23
-rw-r--r--  modules/styles.py                                |  67
-rw-r--r--  modules/sysinfo.py                               | 162
-rw-r--r--  modules/ui.py                                    |  18
-rw-r--r--  modules/ui_extensions.py                         |   2
-rw-r--r--  modules/ui_extra_networks.py                     |  18
-rw-r--r--  modules/ui_extra_networks_checkpoints.py         |   4
-rw-r--r--  modules/ui_extra_networks_hypernets.py           |   4
-rw-r--r--  modules/ui_extra_networks_textual_inversion.py   |   4
-rw-r--r--  modules/ui_settings.py                           |  28
-rw-r--r--  modules/ui_tempdir.py                            |   2
-rw-r--r--  modules/upscaler.py                              |   6
22 files changed, 466 insertions, 39 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 555eefdb..2e49526e 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -189,6 +189,7 @@ class Api:
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+ self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
@@ -540,6 +541,14 @@ class Api:
for upscaler in shared.sd_upscalers
]
+ def get_latent_upscale_modes(self):
+ return [
+ {
+ "name": upscale_mode,
+ }
+ for upscale_mode in [*(shared.latent_upscale_modes or {})]
+ ]
+
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
diff --git a/modules/api/models.py b/modules/api/models.py
index 47fdede2..b3a745f0 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -241,6 +241,9 @@ class UpscalerItem(BaseModel):
model_url: Optional[str] = Field(title="URL")
scale: Optional[float] = Field(title="Scale")
+class LatentUpscalerModeItem(BaseModel):
+ name: str = Field(title="Name")
+
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
diff --git a/modules/errors.py b/modules/errors.py
index e408f500..5271a9fe 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -3,10 +3,30 @@ import textwrap
import traceback
+exception_records = []
+
+
+def record_exception():
+ _, e, tb = sys.exc_info()
+ if e is None:
+ return
+
+ if exception_records and exception_records[-1][0] == e:
+ return
+
+ exception_records.append((e, tb))
+
+ if len(exception_records) > 5:
+ exception_records.pop(0)
+
+
def report(message: str, *, exc_info: bool = False) -> None:
"""
Print an error message to stderr, with optional traceback.
"""
+
+ record_exception()
+
for line in message.splitlines():
print("***", line, file=sys.stderr)
if exc_info:
@@ -15,6 +35,8 @@ def report(message: str, *, exc_info: bool = False) -> None:
def print_error_explanation(message):
+ record_exception()
+
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
@@ -25,6 +47,8 @@ def print_error_explanation(message):
def display(e: Exception, task, *, full_traceback=False):
+ record_exception()
+
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
te = traceback.TracebackException.from_exception(e)
if full_traceback:
@@ -44,6 +68,8 @@ already_displayed = {}
def display_once(e: Exception, task):
+ record_exception()
+
if task in already_displayed:
return
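Taken together, every reporting helper now stashes the active exception, and the record is capped at the five most recent entries. A short sketch of the behavior, assuming it runs inside the webui tree (and the tuple-comparison fix above):

    from modules import errors

    for i in range(7):
        try:
            raise ValueError(f"failure {i}")
        except ValueError:
            errors.report("something broke", exc_info=True)

    # only the five most recent (exception, traceback) pairs are kept
    assert len(errors.exception_records) == 5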
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index f4743928..1f093df2 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -32,6 +32,9 @@ class ExtraNetworkParams:
else:
self.positional.append(item)
+ def __eq__(self, other):
+ return self.items == other.items
+
class ExtraNetwork:
def __init__(self, name):
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 237401a1..1d02ffae 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -265,6 +265,14 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
+ if shared.opts.infotext_styles != "Ignore":
+ found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
+
+ if shared.opts.infotext_styles == "Apply":
+ res["Styles array"] = found_styles
+ elif shared.opts.infotext_styles == "Apply if any" and found_styles:
+ res["Styles array"] = found_styles
+
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
diff --git a/modules/img2img.py b/modules/img2img.py
index 35c4facc..3981b783 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,4 +1,5 @@
import os
+from pathlib import Path
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
@@ -13,7 +14,7 @@ from modules.ui import plaintext_to_html
import modules.scripts
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
processing.fix_seed(p)
images = shared.listfiles(input_dir)
@@ -48,14 +49,31 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
+
+ if to_scale:
+ p.width = int(img.width * scale_by)
+ p.height = int(img.height * scale_by)
+
p.init_images = [img] * p.batch_size
+ image_path = Path(image)
if is_inpaint_batch:
- # try to find corresponding mask for an image using simple filename matching
- mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
- # if not found use first one ("same mask for all images" use-case)
- if mask_image_path not in inpaint_masks:
+ if len(inpaint_masks) == 1:
mask_image_path = inpaint_masks[0]
+ else:
+ # try to find corresponding mask for an image using simple filename matching
+ mask_image_dir = Path(inpaint_mask_dir)
+ masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
+
+ if len(masks_found) == 0:
+ print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
+ continue
+
+ # there should be exactly one matching mask;
+ # more than one means the user has several masks with the same name but different extensions
+ mask_image_path = masks_found[0]
+
mask_image = Image.open(mask_image_path)
p.image_mask = mask_image
@@ -64,7 +82,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
proc = process_images(p)
for n, processed_image in enumerate(proc.images):
- filename = os.path.basename(image)
+ filename = image_path.name
if n > 0:
left, right = os.path.splitext(filename)
@@ -114,7 +132,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if image is not None:
image = ImageOps.exif_transpose(image)
- if selected_scale_tab == 1:
+ if selected_scale_tab == 1 and not is_batch:
assert image, "Can't scale by because no image is selected"
width = int(image.width * scale_by)
@@ -169,7 +187,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
processed = Processed(p, [], p.seed, "")
else:
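The inpaint-batch lookup now pairs each image with a mask of the same filename stem, regardless of extension, falling back to a single shared mask. A rough standalone rendition of that logic (directory layout hypothetical):

    from pathlib import Path

    def find_mask(image: Path, mask_dir: Path, masks: list):
        # a single mask serves the "same mask for all images" use-case
        if len(masks) == 1:
            return masks[0]
        # otherwise match by stem: photo.png pairs with photo.jpg, photo.webp, ...
        found = list(mask_dir.glob(f"{image.stem}.*"))
        return found[0] if found else None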
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 6e9bb770..af8d8b37 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -68,7 +68,13 @@ def git_tag():
try:
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
- return "<none>"
+ try:
+ from pathlib import Path
+ changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
+ with changelog_md.open(encoding="utf-8") as file:
+ return next((line.strip() for line in file if line.strip()), "<none>")
+ except Exception:
+ return "<none>"
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
@@ -238,6 +244,12 @@ def prepare_environment():
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
+ try:
+ # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops
+ os.remove(os.path.join(script_path, "tmp", "restart"))
+ except OSError:
+ pass
+
if not args.skip_python_version_check:
check_python_version()
diff --git a/modules/lowvram.py b/modules/lowvram.py
index e254cc13..d95bcfbf 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -15,6 +15,8 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram):
+ sd_model.lowvram = True
+
parents = {}
def send_me_to_gpu(module, _):
@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram):
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
+
+
+def is_enabled(sd_model):
+ return getattr(sd_model, 'lowvram', False)
diff --git a/modules/processing.py b/modules/processing.py
index 9ebdb549..f8225b83 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -171,6 +171,7 @@ class StableDiffusionProcessing:
self.prompts = None
self.negative_prompts = None
+ self.extra_network_data = None
self.seeds = None
self.subseeds = None
@@ -311,7 +312,7 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
- def get_conds_with_caching(self, function, required_prompts, steps, cache):
+ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before.
@@ -320,27 +321,31 @@ class StableDiffusionProcessing:
representing the previously used arguments, or None if no arguments
have been used before. The second element is where the previously
computed result is stored.
+
+ caches is a list with items described above.
"""
- if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
- return cache[1]
+
+ for cache in caches:
+ if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+ return cache[1]
+
+ cache = caches[0]
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
+ cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
return cache[1]
def setup_conds(self):
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c)
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
def parse_extra_network_prompts(self):
- self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
-
- return extra_network_data
+ self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
class Processed:
@@ -681,7 +686,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1:
state.job_count = p.n_iter
- extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@@ -702,11 +706,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(p.prompts) == 0:
break
- extra_network_data = p.parse_extra_network_prompts()
+ p.parse_extra_network_prompts()
if not p.disable_extra_networks:
with devices.autocast():
- extra_networks.activate(p, extra_network_data)
+ extra_networks.activate(p, p.extra_network_data)
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@@ -741,7 +745,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
del samples_ddim
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -828,8 +832,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- if not p.disable_extra_networks and extra_network_data:
- extra_networks.deactivate(p, extra_network_data)
+ if not p.disable_extra_networks and p.extra_network_data:
+ extra_networks.deactivate(p, p.extra_network_data)
devices.torch_gc()
@@ -896,6 +900,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_negative_prompts = None
self.hr_extra_network_data = None
+ self.cached_hr_uc = [None, None]
+ self.cached_hr_c = [None, None]
self.hr_c = None
self.hr_uc = None
@@ -1059,6 +1065,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)
+ with devices.autocast():
+ self.calculate_hr_conds()
+
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
@@ -1070,6 +1079,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
def close(self):
+ self.cached_hr_uc = [None, None]
+ self.cached_hr_c = [None, None]
self.hr_c = None
self.hr_uc = None
@@ -1098,12 +1109,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+ def calculate_hr_conds(self):
+ if self.hr_c is not None:
+ return
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+
def setup_conds(self):
super().setup_conds()
+ self.hr_uc = None
+ self.hr_c = None
+
if self.enable_hr:
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c)
+ if shared.opts.hires_fix_use_firstpass_conds:
+ self.calculate_hr_conds()
+
+ elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
+ self.calculate_hr_conds()
+
+ with devices.autocast():
+ extra_networks.activate(self, self.extra_network_data)
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()
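The reworked get_conds_with_caching probes every supplied cache for a key hit and recomputes into the first one on a miss, which is what lets the hires pass reuse firstpass conds. A stripped-down sketch (names hypothetical):

    def conds_with_caching(function, key, caches):
        """Each cache is a two-item list: [key, value]."""
        for cache in caches:
            if cache[0] is not None and cache[0] == key:
                return cache[1]          # hit: reuse previously computed conds
        cache = caches[0]                # miss: recompute into the primary cache
        cache[1] = function(key)
        cache[0] = key
        return cache[1]

    # the hires pass probes its own cache first, then the firstpass cache:
    # conds_with_caching(fn, hr_key, [cached_hr_c, cached_c])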
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index b41aa419..3c71e6b5 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -46,7 +46,7 @@ class SdOptimizationXformers(SdOptimization):
priority = 100
def is_available(self):
- return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+ return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
diff --git a/modules/shared.py b/modules/shared.py
index 7025a754..2bd7c6ec 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -260,6 +260,10 @@ class OptionInfo:
self.comment_after += f"<span class='info'>({info})</span>"
return self
+ def html(self, html):
+ self.comment_after += html
+ return self
+
def needs_restart(self):
self.comment_after += " <span class='info'>(requires restart)</span>"
return self
@@ -425,6 +429,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
+ "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
@@ -488,7 +493,14 @@ options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
- "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+ "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
+ "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
+<li>Ignore: keep prompt and styles dropdown as it is.</li>
+<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
+<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
+<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
+</ul>"""),
+
}))
options_templates.update(options_section(('ui', "Live previews"), {
@@ -841,3 +853,12 @@ def walk_files(path, allowed_extensions=None):
continue
yield os.path.join(root, filename)
+
+
+def restart_program():
+ """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
+
+ with open(os.path.join(script_path, "tmp", "restart"), "w"):
+ pass
+
+ os._exit(0)
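restart_program pairs with the tmp/restart cleanup added to prepare_environment: webui creates the flag file and exits, and the launcher relaunches it while the flag exists. A Python rendition of what the webui.sh/webui.bat loop amounts to, with paths assumed:

    import os
    import subprocess
    import sys

    while True:
        subprocess.call([sys.executable, "launch.py"])     # webui exits via os._exit(0)
        if not os.path.exists(os.path.join("tmp", "restart")):
            break  # no flag file: normal shutdown, stop relaunching
        # flag present: relaunch; prepare_environment() removes the file on startup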
diff --git a/modules/styles.py b/modules/styles.py
index 34e1b5e1..ec0e1bc5 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,6 +1,7 @@
import csv
import os
import os.path
+import re
import typing
import shutil
@@ -28,6 +29,44 @@ def apply_styles_to_prompt(prompt, styles):
return prompt
+re_spaces = re.compile(" +")
+
+
+def extract_style_text_from_prompt(style_text, prompt):
+ stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
+ stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
+ if "{prompt}" in stripped_style_text:
+ left, right = stripped_style_text.split("{prompt}", 1)
+ if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
+ prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
+ return True, prompt
+ else:
+ if stripped_prompt.endswith(stripped_style_text):
+ prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
+
+ if prompt.endswith(', '):
+ prompt = prompt[:-2]
+
+ return True, prompt
+
+ return False, prompt
+
+
+def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
+ if not style.prompt and not style.negative_prompt:
+ return False, prompt, negative_prompt
+
+ match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
+ if not match_positive:
+ return False, prompt, negative_prompt
+
+ match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
+ if not match_negative:
+ return False, prompt, negative_prompt
+
+ return True, extracted_positive, extracted_negative
+
+
class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
@@ -67,10 +106,34 @@ class StyleDatabase:
if os.path.exists(path):
shutil.copy(path, f"{path}.bak")
- fd = os.open(path, os.O_RDWR|os.O_CREAT)
+ fd = os.open(path, os.O_RDWR | os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
- writer.writerows(style._asdict() for k, style in self.styles.items())
+ writer.writerows(style._asdict() for k, style in self.styles.items())
+
+ def extract_styles_from_prompt(self, prompt, negative_prompt):
+ extracted = []
+
+ applicable_styles = list(self.styles.values())
+
+ while True:
+ found_style = None
+
+ for style in applicable_styles:
+ is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
+ if is_match:
+ found_style = style
+ prompt = new_prompt
+ negative_prompt = new_neg_prompt
+ break
+
+ if not found_style:
+ break
+
+ applicable_styles.remove(found_style)
+ extracted.append(found_style.name)
+
+ return list(reversed(extracted)), prompt, negative_prompt
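A worked example of the extraction, assuming the webui tree is importable and a style whose text is "a photo of {prompt}, masterpiece":

    from modules.styles import extract_style_text_from_prompt

    # template style: the style text wraps the user's prompt
    print(extract_style_text_from_prompt("a photo of {prompt}, masterpiece",
                                         "a photo of a cat, masterpiece"))
    # -> (True, 'a cat')

    # suffix style: the style text trails the user's prompt
    print(extract_style_text_from_prompt("by Greg", "a cat, by Greg"))
    # -> (True, 'a cat')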
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
new file mode 100644
index 00000000..5f15ac4f
--- /dev/null
+++ b/modules/sysinfo.py
@@ -0,0 +1,162 @@
+import json
+import os
+import sys
+import traceback
+
+import platform
+import hashlib
+import pkg_resources
+import psutil
+import re
+
+import launch
+from modules import paths_internal, timer
+
+checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
+environment_whitelist = {
+ "GIT",
+ "INDEX_URL",
+ "WEBUI_LAUNCH_LIVE_OUTPUT",
+ "GRADIO_ANALYTICS_ENABLED",
+ "PYTHONPATH",
+ "TORCH_INDEX_URL",
+ "TORCH_COMMAND",
+ "REQS_FILE",
+ "XFORMERS_PACKAGE",
+ "GFPGAN_PACKAGE",
+ "CLIP_PACKAGE",
+ "OPENCLIP_PACKAGE",
+ "STABLE_DIFFUSION_REPO",
+ "K_DIFFUSION_REPO",
+ "CODEFORMER_REPO",
+ "BLIP_REPO",
+ "STABLE_DIFFUSION_COMMIT_HASH",
+ "K_DIFFUSION_COMMIT_HASH",
+ "CODEFORMER_COMMIT_HASH",
+ "BLIP_COMMIT_HASH",
+ "COMMANDLINE_ARGS",
+ "IGNORE_CMD_ARGS_ERRORS",
+}
+
+
+def pretty_bytes(num, suffix="B"):
+ for unit in ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]:
+ if abs(num) < 1024 or unit == 'Y':
+ return f"{num:.0f}{unit}{suffix}"
+ num /= 1024
+
+
+def get():
+ res = get_dict()
+
+ text = json.dumps(res, ensure_ascii=False, indent=4)
+
+ h = hashlib.sha256(text.encode("utf8"))
+ text = text.replace(checksum_token, h.hexdigest())
+
+ return text
+
+
+re_checksum = re.compile(r'"Checksum": "([0-9a-fA-F]{64})"')
+
+
+def check(x):
+ m = re.search(re_checksum, x)
+ if not m:
+ return False
+
+ replaced = re.sub(re_checksum, f'"Checksum": "{checksum_token}"', x)
+
+ h = hashlib.sha256(replaced.encode("utf8"))
+ return h.hexdigest() == m.group(1)
+
+
+def get_dict():
+ ram = psutil.virtual_memory()
+
+ res = {
+ "Platform": platform.platform(),
+ "Python": platform.python_version(),
+ "Version": launch.git_tag(),
+ "Commit": launch.commit_hash(),
+ "Script path": paths_internal.script_path,
+ "Data path": paths_internal.data_path,
+ "Extensions dir": paths_internal.extensions_dir,
+ "Checksum": checksum_token,
+ "Commandline": sys.argv,
+ "Torch env info": get_torch_sysinfo(),
+ "Exceptions": get_exceptions(),
+ "CPU": {
+ "model": platform.processor(),
+ "count logical": psutil.cpu_count(logical=True),
+ "count physical": psutil.cpu_count(logical=False),
+ },
+ "RAM": {
+ x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0
+ },
+ "Extensions": get_extensions(enabled=True),
+ "Inactive extensions": get_extensions(enabled=False),
+ "Environment": get_environment(),
+ "Config": get_config(),
+ "Startup": timer.startup_record,
+ "Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]),
+ }
+
+ return res
+
+
+def format_traceback(tb):
+ return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
+
+
+def get_exceptions():
+ try:
+ from modules import errors
+
+ return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
+ except Exception as e:
+ return str(e)
+
+
+def get_environment():
+ return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
+
+
+re_newline = re.compile(r"\r*\n")
+
+
+def get_torch_sysinfo():
+ try:
+ import torch.utils.collect_env
+ info = torch.utils.collect_env.get_env_info()._asdict()
+
+ return {k: re.split(re_newline, str(v)) if "\n" in str(v) else v for k, v in info.items()}
+ except Exception as e:
+ return str(e)
+
+
+def get_extensions(*, enabled):
+
+ try:
+ from modules import extensions
+
+ def to_json(x: extensions.Extension):
+ return {
+ "name": x.name,
+ "path": x.path,
+ "version": x.version,
+ "branch": x.branch,
+ "remote": x.remote,
+ }
+
+ return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
+ except Exception as e:
+ return str(e)
+
+
+def get_config():
+ try:
+ from modules import shared
+ return shared.opts.data
+ except Exception as e:
+ return str(e)
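The checksum scheme: get() serializes the report with a placeholder token, hashes the serialized text, and substitutes the digest in; check() restores the token and re-hashes. A self-contained sketch of the same round-trip (token and report are stand-ins):

    import hashlib
    import re

    TOKEN = "PLACEHOLDER_TOKEN"

    def sign(text: str) -> str:
        digest = hashlib.sha256(text.encode("utf8")).hexdigest()
        return text.replace(TOKEN, digest)

    def verify(signed: str) -> bool:
        m = re.search(r"[0-9a-f]{64}", signed)
        if not m:
            return False
        restored = signed.replace(m.group(0), TOKEN)
        return hashlib.sha256(restored.encode("utf8")).hexdigest() == m.group(0)

    report = f'{{"Checksum": "{TOKEN}", "Platform": "linux"}}'
    assert verify(sign(report))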
diff --git a/modules/ui.py b/modules/ui.py
index 9a025cca..3315cc17 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,3 +1,4 @@
+import datetime
import json
import mimetypes
import os
@@ -11,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -79,6 +80,7 @@ extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀
detect_image_size_symbol = '\U0001F4D0' # 📐
+up_down_symbol = '\u2195\ufe0f' # ↕️
def plaintext_to_html(text):
@@ -620,6 +622,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
+ (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
@@ -1035,6 +1038,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
+ (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields
@@ -1597,3 +1601,15 @@ def setup_ui_api(app):
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])
+
+ def download_sysinfo(attachment=False):
+ from fastapi.responses import PlainTextResponse
+
+ text = sysinfo.get()
+ filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
+
+ return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
+
+ app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
+ app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
+
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 65173e06..1ae516d7 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -49,7 +49,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename)
- shared.state.request_restart()
+ shared.restart_program()
def save_config_state(name):
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 19fbaae5..a7d3bc79 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -4,6 +4,7 @@ from pathlib import Path
from modules import shared
from modules.images import read_info_from_image, save_image_with_geninfo
+from modules.ui import up_down_symbol
import gradio as gr
import json
import html
@@ -185,6 +186,8 @@ class ExtraNetworksPage:
if search_only and shared.opts.extra_networks_hidden_models == "Never":
return ""
+ sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
+
args = {
"background_image": background_image,
"style": f"'display: none; {height}{width}'",
@@ -198,10 +201,23 @@ class ExtraNetworksPage:
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
"search_only": " search_only" if search_only else "",
+ "sort_keys": sort_keys,
}
return self.card_page.format(**args)
+ def get_sort_keys(self, path):
+ """
+ List of default keys used for sorting in the UI.
+ """
+ pth = Path(path)
+ stat = pth.stat()
+ return {
+ "date_created": int(stat.st_ctime or 0),
+ "date_modified": int(stat.st_mtime or 0),
+ "name": pth.name.lower(),
+ }
+
def find_preview(self, path):
"""
Find a preview PNG for a given path (without extension) and call link_preview on it.
@@ -296,6 +312,8 @@ def create_ui(container, button, tabname):
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+ gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
+ gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
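Each card now carries its sort keys as data- attributes, which the new dropdown and order button use to reorder cards client-side. The emitted fragment looks roughly like this (values hypothetical):

    import html

    sort_keys = {"default": 3, "date_created": 1686523200,
                 "date_modified": 1686609600, "name": "analog-style"}
    attrs = " ".join(html.escape(f"data-sort-{k}={v}") for k, v in sort_keys.items()).strip()
    print(attrs)
    # data-sort-default=3 data-sort-date_created=1686523200 data-sort-date_modified=1686609600 data-sort-name=analog-style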
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index a17aa9c9..8b9ab71b 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -14,7 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def list_items(self):
checkpoint: sd_models.CheckpointInfo
- for name, checkpoint in sd_models.checkpoints_list.items():
+ for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()):
path, ext = os.path.splitext(checkpoint.filename)
yield {
"name": checkpoint.name_for_extra,
@@ -24,6 +24,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
"local_preview": f"{path}.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
+
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 6187e000..7c19b532 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
shared.reload_hypernetworks()
def list_items(self):
- for name, path in shared.hypernetworks.items():
+ for index, (name, path) in enumerate(shared.hypernetworks.items()):
path, ext = os.path.splitext(path)
yield {
@@ -23,6 +23,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(path),
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
+
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index 6944d559..58a61c55 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -13,7 +13,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
def list_items(self):
- for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+ for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()):
path, ext = os.path.splitext(embedding.filename)
yield {
"name": embedding.name,
@@ -23,6 +23,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(embedding.filename),
"prompt": json.dumps(embedding.name),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
+
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 2688d8c2..0c560b30 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -1,6 +1,6 @@
import gradio as gr
-from modules import ui_common, shared, script_callbacks, scripts, sd_models
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
from modules.call_queue import wrap_gradio_call
from modules.shared import opts
from modules.ui_components import FormRow
@@ -157,6 +157,17 @@ class UiSettings:
with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
loadsave.create_ui()
+ with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
+ gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ sysinfo_check_file = gr.File(label="Check system info for validity", type='binary')
+ with gr.Column(scale=1):
+ sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity")
+ with gr.Column(scale=100):
+ pass
+
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
@@ -215,6 +226,21 @@ class UiSettings:
outputs=[],
)
+ def check_file(x):
+ if x is None:
+ return ''
+
+ if sysinfo.check(x.decode('utf8', errors='ignore')):
+ return 'Valid'
+
+ return 'Invalid'
+
+ sysinfo_check_file.change(
+ fn=check_file,
+ inputs=[sysinfo_check_file],
+ outputs=[sysinfo_check_output],
+ )
+
self.interface = settings_interface
def add_quicksettings(self):
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 9fc7d764..fb75137e 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -31,7 +31,7 @@ def check_tmp_file(gradio, filename):
return False
-def save_pil_to_file(self, pil_image, dir=None):
+def save_pil_to_file(self, pil_image, dir=None, format="png"):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 3c82861d..e682bbaa 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -53,8 +53,8 @@ class Upscaler:
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
- dest_w = round((img.width * scale - 4) / 8) * 8
- dest_h = round((img.height * scale - 4) / 8) * 8
+ dest_w = int((img.width * scale) // 8 * 8)
+ dest_h = int((img.height * scale) // 8 * 8)
for _ in range(3):
shape = (img.width, img.height)
@@ -77,7 +77,7 @@ class Upscaler:
pass
def find_models(self, ext_filter=None) -> list:
- return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
+ return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
def update_status(self, prompt):
print(f"\nextras: {prompt}", file=shared.progress_print_out)