aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/errors.py26
-rw-r--r--modules/generation_parameters_copypaste.py8
-rw-r--r--modules/img2img.py32
-rw-r--r--modules/shared.py13
-rw-r--r--modules/styles.py67
-rw-r--r--modules/sysinfo.py162
-rw-r--r--modules/ui.py17
-rw-r--r--modules/ui_settings.py28
-rw-r--r--modules/ui_tempdir.py2
-rw-r--r--modules/upscaler.py6
10 files changed, 345 insertions, 16 deletions
diff --git a/modules/errors.py b/modules/errors.py
index e408f500..5271a9fe 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -3,10 +3,30 @@ import textwrap
import traceback
+exception_records = []
+
+
+def record_exception():
+ _, e, tb = sys.exc_info()
+ if e is None:
+ return
+
+ if exception_records and exception_records[-1] == e:
+ return
+
+ exception_records.append((e, tb))
+
+ if len(exception_records) > 5:
+ exception_records.pop(0)
+
+
def report(message: str, *, exc_info: bool = False) -> None:
"""
Print an error message to stderr, with optional traceback.
"""
+
+ record_exception()
+
for line in message.splitlines():
print("***", line, file=sys.stderr)
if exc_info:
@@ -15,6 +35,8 @@ def report(message: str, *, exc_info: bool = False) -> None:
def print_error_explanation(message):
+ record_exception()
+
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
@@ -25,6 +47,8 @@ def print_error_explanation(message):
def display(e: Exception, task, *, full_traceback=False):
+ record_exception()
+
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
te = traceback.TracebackException.from_exception(e)
if full_traceback:
@@ -44,6 +68,8 @@ already_displayed = {}
def display_once(e: Exception, task):
+ record_exception()
+
if task in already_displayed:
return
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 071bd9ea..4c420e5f 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -265,6 +265,14 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
+ if shared.opts.infotext_styles != "Ignore":
+ found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
+
+ if shared.opts.infotext_styles == "Apply":
+ res["Styles array"] = found_styles
+ elif shared.opts.infotext_styles == "Apply if any" and found_styles:
+ res["Styles array"] = found_styles
+
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
diff --git a/modules/img2img.py b/modules/img2img.py
index 4c12c2c5..b240e593 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,4 +1,5 @@
import os
+from pathlib import Path
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
@@ -13,7 +14,7 @@ from modules.ui import plaintext_to_html
import modules.scripts
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
processing.fix_seed(p)
images = shared.listfiles(input_dir)
@@ -49,14 +50,31 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
+
+ if to_scale:
+ p.width = int(img.width * scale_by)
+ p.height = int(img.height * scale_by)
+
p.init_images = [img] * p.batch_size
+ image_path = Path(image)
if is_inpaint_batch:
# try to find corresponding mask for an image using simple filename matching
- mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
- # if not found use first one ("same mask for all images" use-case)
- if mask_image_path not in inpaint_masks:
+ if len(inpaint_masks) == 1:
mask_image_path = inpaint_masks[0]
+ else:
+ # try to find corresponding mask for an image using simple filename matching
+ mask_image_dir = Path(inpaint_mask_dir)
+ masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
+
+ if len(masks_found) == 0:
+ print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
+ continue
+
+ # it should contain only 1 matching mask
+ # otherwise user has many masks with the same name but different extensions
+ mask_image_path = masks_found[0]
+
mask_image = Image.open(mask_image_path)
p.image_mask = mask_image
@@ -65,7 +83,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
proc = process_images(p)
for n, processed_image in enumerate(proc.images):
- filename = os.path.basename(image)
+ filename = image_path.name
if n > 0:
left, right = os.path.splitext(filename)
@@ -115,7 +133,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if image is not None:
image = ImageOps.exif_transpose(image)
- if selected_scale_tab == 1:
+ if selected_scale_tab == 1 and not is_batch:
assert image, "Can't scale by because no image is selected"
width = int(image.width * scale_by)
@@ -170,7 +188,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
processed = Processed(p, [], p.seed, "")
else:
diff --git a/modules/shared.py b/modules/shared.py
index c4c719ad..7d056a4d 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -260,6 +260,10 @@ class OptionInfo:
self.comment_after += f"<span class='info'>({info})</span>"
return self
+ def html(self, html):
+ self.comment_after += html
+ return self
+
def needs_restart(self):
self.comment_after += " <span class='info'>(requires restart)</span>"
return self
@@ -488,7 +492,14 @@ options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
- "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+ "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
+ "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
+<li>Ignore: keep prompt and styles dropdown as it is.</li>
+<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
+<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
+<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
+</ul>"""),
+
}))
options_templates.update(options_section(('ui', "Live previews"), {
diff --git a/modules/styles.py b/modules/styles.py
index 34e1b5e1..ec0e1bc5 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,6 +1,7 @@
import csv
import os
import os.path
+import re
import typing
import shutil
@@ -28,6 +29,44 @@ def apply_styles_to_prompt(prompt, styles):
return prompt
+re_spaces = re.compile(" +")
+
+
+def extract_style_text_from_prompt(style_text, prompt):
+ stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
+ stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
+ if "{prompt}" in stripped_style_text:
+ left, right = stripped_style_text.split("{prompt}", 2)
+ if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
+ prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
+ return True, prompt
+ else:
+ if stripped_prompt.endswith(stripped_style_text):
+ prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
+
+ if prompt.endswith(', '):
+ prompt = prompt[:-2]
+
+ return True, prompt
+
+ return False, prompt
+
+
+def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
+ if not style.prompt and not style.negative_prompt:
+ return False, prompt, negative_prompt
+
+ match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
+ if not match_positive:
+ return False, prompt, negative_prompt
+
+ match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
+ if not match_negative:
+ return False, prompt, negative_prompt
+
+ return True, extracted_positive, extracted_negative
+
+
class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
@@ -67,10 +106,34 @@ class StyleDatabase:
if os.path.exists(path):
shutil.copy(path, f"{path}.bak")
- fd = os.open(path, os.O_RDWR|os.O_CREAT)
+ fd = os.open(path, os.O_RDWR | os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
- writer.writerows(style._asdict() for k, style in self.styles.items())
+ writer.writerows(style._asdict() for k, style in self.styles.items())
+
+ def extract_styles_from_prompt(self, prompt, negative_prompt):
+ extracted = []
+
+ applicable_styles = list(self.styles.values())
+
+ while True:
+ found_style = None
+
+ for style in applicable_styles:
+ is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
+ if is_match:
+ found_style = style
+ prompt = new_prompt
+ negative_prompt = new_neg_prompt
+ break
+
+ if not found_style:
+ break
+
+ applicable_styles.remove(found_style)
+ extracted.append(found_style.name)
+
+ return list(reversed(extracted)), prompt, negative_prompt
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
new file mode 100644
index 00000000..5f15ac4f
--- /dev/null
+++ b/modules/sysinfo.py
@@ -0,0 +1,162 @@
+import json
+import os
+import sys
+import traceback
+
+import platform
+import hashlib
+import pkg_resources
+import psutil
+import re
+
+import launch
+from modules import paths_internal, timer
+
+checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
+environment_whitelist = {
+ "GIT",
+ "INDEX_URL",
+ "WEBUI_LAUNCH_LIVE_OUTPUT",
+ "GRADIO_ANALYTICS_ENABLED",
+ "PYTHONPATH",
+ "TORCH_INDEX_URL",
+ "TORCH_COMMAND",
+ "REQS_FILE",
+ "XFORMERS_PACKAGE",
+ "GFPGAN_PACKAGE",
+ "CLIP_PACKAGE",
+ "OPENCLIP_PACKAGE",
+ "STABLE_DIFFUSION_REPO",
+ "K_DIFFUSION_REPO",
+ "CODEFORMER_REPO",
+ "BLIP_REPO",
+ "STABLE_DIFFUSION_COMMIT_HASH",
+ "K_DIFFUSION_COMMIT_HASH",
+ "CODEFORMER_COMMIT_HASH",
+ "BLIP_COMMIT_HASH",
+ "COMMANDLINE_ARGS",
+ "IGNORE_CMD_ARGS_ERRORS",
+}
+
+
+def pretty_bytes(num, suffix="B"):
+ for unit in ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]:
+ if abs(num) < 1024 or unit == 'Y':
+ return f"{num:.0f}{unit}{suffix}"
+ num /= 1024
+
+
+def get():
+ res = get_dict()
+
+ text = json.dumps(res, ensure_ascii=False, indent=4)
+
+ h = hashlib.sha256(text.encode("utf8"))
+ text = text.replace(checksum_token, h.hexdigest())
+
+ return text
+
+
+re_checksum = re.compile(r'"Checksum": "([0-9a-fA-F]{64})"')
+
+
+def check(x):
+ m = re.search(re_checksum, x)
+ if not m:
+ return False
+
+ replaced = re.sub(re_checksum, f'"Checksum": "{checksum_token}"', x)
+
+ h = hashlib.sha256(replaced.encode("utf8"))
+ return h.hexdigest() == m.group(1)
+
+
+def get_dict():
+ ram = psutil.virtual_memory()
+
+ res = {
+ "Platform": platform.platform(),
+ "Python": platform.python_version(),
+ "Version": launch.git_tag(),
+ "Commit": launch.commit_hash(),
+ "Script path": paths_internal.script_path,
+ "Data path": paths_internal.data_path,
+ "Extensions dir": paths_internal.extensions_dir,
+ "Checksum": checksum_token,
+ "Commandline": sys.argv,
+ "Torch env info": get_torch_sysinfo(),
+ "Exceptions": get_exceptions(),
+ "CPU": {
+ "model": platform.processor(),
+ "count logical": psutil.cpu_count(logical=True),
+ "count physical": psutil.cpu_count(logical=False),
+ },
+ "RAM": {
+ x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0
+ },
+ "Extensions": get_extensions(enabled=True),
+ "Inactive extensions": get_extensions(enabled=False),
+ "Environment": get_environment(),
+ "Config": get_config(),
+ "Startup": timer.startup_record,
+ "Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]),
+ }
+
+ return res
+
+
+def format_traceback(tb):
+ return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
+
+
+def get_exceptions():
+ try:
+ from modules import errors
+
+ return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
+ except Exception as e:
+ return str(e)
+
+
+def get_environment():
+ return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
+
+
+re_newline = re.compile(r"\r*\n")
+
+
+def get_torch_sysinfo():
+ try:
+ import torch.utils.collect_env
+ info = torch.utils.collect_env.get_env_info()._asdict()
+
+ return {k: re.split(re_newline, str(v)) if "\n" in str(v) else v for k, v in info.items()}
+ except Exception as e:
+ return str(e)
+
+
+def get_extensions(*, enabled):
+
+ try:
+ from modules import extensions
+
+ def to_json(x: extensions.Extension):
+ return {
+ "name": x.name,
+ "path": x.path,
+ "version": x.version,
+ "branch": x.branch,
+ "remote": x.remote,
+ }
+
+ return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
+ except Exception as e:
+ return str(e)
+
+
+def get_config():
+ try:
+ from modules import shared
+ return shared.opts.data
+ except Exception as e:
+ return str(e)
diff --git a/modules/ui.py b/modules/ui.py
index 988b2003..4c0fd4d5 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,3 +1,4 @@
+import datetime
import json
import mimetypes
import os
@@ -11,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -621,6 +622,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
+ (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
@@ -1036,6 +1038,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
+ (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields
@@ -1598,3 +1601,15 @@ def setup_ui_api(app):
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])
+
+ def download_sysinfo(attachment=False):
+ from fastapi.responses import PlainTextResponse
+
+ text = sysinfo.get()
+ filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
+
+ return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
+
+ app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
+ app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
+
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 7874298e..892c5e1a 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -1,6 +1,6 @@
import gradio as gr
-from modules import ui_common, shared, script_callbacks, scripts, sd_models
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
from modules.call_queue import wrap_gradio_call
from modules.shared import opts
from modules.ui_components import FormRow
@@ -157,6 +157,17 @@ class UiSettings:
with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
loadsave.create_ui()
+ with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
+ gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ sysinfo_check_file = gr.File(label="Check system info for validity", type='binary')
+ with gr.Column(scale=1):
+ sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity")
+ with gr.Column(scale=100):
+ pass
+
with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
@@ -215,6 +226,21 @@ class UiSettings:
outputs=[],
)
+ def check_file(x):
+ if x is None:
+ return ''
+
+ if sysinfo.check(x.decode('utf8', errors='ignore')):
+ return 'Valid'
+
+ return 'Invalid'
+
+ sysinfo_check_file.change(
+ fn=check_file,
+ inputs=[sysinfo_check_file],
+ outputs=[sysinfo_check_output],
+ )
+
self.interface = settings_interface
def add_quicksettings(self):
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 9fc7d764..fb75137e 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -31,7 +31,7 @@ def check_tmp_file(gradio, filename):
return False
-def save_pil_to_file(self, pil_image, dir=None):
+def save_pil_to_file(self, pil_image, dir=None, format="png"):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 3c82861d..e682bbaa 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -53,8 +53,8 @@ class Upscaler:
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
- dest_w = round((img.width * scale - 4) / 8) * 8
- dest_h = round((img.height * scale - 4) / 8) * 8
+ dest_w = int((img.width * scale) // 8 * 8)
+ dest_h = int((img.height * scale) // 8 * 8)
for _ in range(3):
shape = (img.width, img.height)
@@ -77,7 +77,7 @@ class Upscaler:
pass
def find_models(self, ext_filter=None) -> list:
- return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
+ return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
def update_status(self, prompt):
print(f"\nextras: {prompt}", file=shared.progress_print_out)