aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorwangshuai09 <391746016@qq.com>2024-01-30 19:15:41 +0800
committerGitHub <noreply@github.com>2024-01-30 19:15:41 +0800
commit74ff85a1a1ee4cce432b1c7d33c1eda831f68d48 (patch)
tree99b70e0fef8422c8f603bf7faa1a393091cb2a8b /modules
parentec124607f47371a6cfd61a795f86a7f1cbd44651 (diff)
parentce168ab5dbc8b54b7245f352a2eaa55a37019b91 (diff)
Merge branch 'dev' into npu_support
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py153
-rw-r--r--modules/api/models.py4
-rw-r--r--modules/cache.py17
-rw-r--r--modules/call_queue.py1
-rw-r--r--modules/cmd_args.py6
-rw-r--r--modules/codeformer/codeformer_arch.py276
-rw-r--r--modules/codeformer/vqgan_arch.py435
-rw-r--r--modules/codeformer_model.py158
-rw-r--r--modules/dat_model.py79
-rw-r--r--modules/devices.py98
-rw-r--r--modules/errors.py4
-rw-r--r--modules/esrgan_model.py199
-rw-r--r--modules/esrgan_model_arch.py465
-rw-r--r--modules/extensions.py16
-rw-r--r--modules/extra_networks.py5
-rw-r--r--modules/face_restoration_utils.py180
-rw-r--r--modules/gfpgan_model.py166
-rw-r--r--modules/hat_model.py43
-rw-r--r--modules/images.py16
-rw-r--r--modules/img2img.py7
-rw-r--r--modules/infotext_utils.py (renamed from modules/generation_parameters_copypaste.py)124
-rw-r--r--modules/infotext_versions.py42
-rw-r--r--modules/initialize.py5
-rw-r--r--modules/initialize_util.py2
-rw-r--r--modules/interrogate.py6
-rw-r--r--modules/launch_utils.py29
-rw-r--r--modules/logging_config.py63
-rw-r--r--modules/masking.py43
-rw-r--r--modules/modelloader.py92
-rw-r--r--modules/options.py35
-rw-r--r--modules/paths.py1
-rw-r--r--modules/paths_internal.py1
-rw-r--r--modules/postprocessing.py15
-rw-r--r--modules/processing.py265
-rw-r--r--modules/processing_scripts/refiner.py7
-rw-r--r--modules/processing_scripts/seed.py32
-rw-r--r--modules/progress.py22
-rw-r--r--modules/realesrgan_model.py158
-rw-r--r--modules/script_callbacks.py5
-rw-r--r--modules/scripts.py131
-rw-r--r--modules/sd_hijack_utils.py12
-rw-r--r--modules/sd_models.py61
-rw-r--r--modules/sd_models_config.py6
-rw-r--r--modules/sd_models_xl.py11
-rw-r--r--modules/sd_samplers.py3
-rw-r--r--modules/sd_samplers_cfg_denoiser.py93
-rw-r--r--modules/sd_samplers_common.py7
-rw-r--r--modules/sd_samplers_kdiffusion.py6
-rw-r--r--modules/sd_samplers_lcm.py104
-rw-r--r--modules/sd_samplers_timesteps.py9
-rw-r--r--modules/sd_vae.py3
-rw-r--r--modules/shared.py3
-rw-r--r--modules/shared_gradio_themes.py4
-rw-r--r--modules/shared_init.py5
-rw-r--r--modules/shared_items.py9
-rw-r--r--modules/shared_options.py37
-rw-r--r--modules/shared_state.py7
-rw-r--r--modules/styles.py132
-rw-r--r--modules/sysinfo.py4
-rw-r--r--modules/textual_inversion/textual_inversion.py10
-rw-r--r--modules/torch_utils.py17
-rw-r--r--modules/txt2img.py65
-rw-r--r--modules/ui.py177
-rw-r--r--modules/ui_common.py132
-rw-r--r--modules/ui_extra_networks.py582
-rw-r--r--modules/ui_extra_networks_checkpoints.py8
-rw-r--r--modules/ui_extra_networks_hypernets.py6
-rw-r--r--modules/ui_extra_networks_textual_inversion.py5
-rw-r--r--modules/ui_extra_networks_user_metadata.py6
-rw-r--r--modules/ui_gradio_extensions.py19
-rw-r--r--modules/ui_loadsave.py7
-rw-r--r--modules/ui_postprocessing.py15
-rw-r--r--modules/ui_prompt_styles.py9
-rw-r--r--modules/ui_toprow.py13
-rw-r--r--modules/upscaler.py3
-rw-r--r--modules/upscaler_utils.py189
-rw-r--r--modules/util.py88
-rw-r--r--modules/xlmr.py5
-rw-r--r--modules/xlmr_m18.py4
-rw-r--r--modules/xpu_specific.py91
80 files changed, 2944 insertions, 2429 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index b3d74e51..4e656082 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -31,7 +31,7 @@ from typing import Any
import piexif
import piexif.helper
from contextlib import closing
-
+from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
def script_name_to_index(name, scripts):
try:
@@ -230,6 +230,7 @@ class Api:
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
+ self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
@@ -251,6 +252,24 @@ class Api:
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
+ txt2img_script_runner = scripts.scripts_txt2img
+ img2img_script_runner = scripts.scripts_img2img
+
+ if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
+ ui.create_ui()
+
+ if not txt2img_script_runner.scripts:
+ txt2img_script_runner.initialize_scripts(False)
+ if not self.default_script_arg_txt2img:
+ self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
+
+ if not img2img_script_runner.scripts:
+ img2img_script_runner.initialize_scripts(True)
+ if not self.default_script_arg_img2img:
+ self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
+
+
+
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@@ -312,8 +331,13 @@ class Api:
script_args[script.args_from:script.args_to] = ui_default_values
return script_args
- def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
+ def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
script_args = default_script_args.copy()
+
+ if input_script_args is not None:
+ for index, value in input_script_args.items():
+ script_args[index] = value
+
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
@@ -335,13 +359,83 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
+ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
+ """Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext.
+
+ If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
+
+ Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
+ """
+
+ if not request.infotext:
+ return {}
+
+ possible_fields = infotext_utils.paste_fields[tabname]["fields"]
+ set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this
+ params = infotext_utils.parse_generation_parameters(request.infotext)
+
+ def get_field_value(field, params):
+ value = field.function(params) if field.function else params.get(field.label)
+ if value is None:
+ return None
+
+ if field.api in request.__fields__:
+ target_type = request.__fields__[field.api].type_
+ else:
+ target_type = type(field.component.value)
+
+ if target_type == type(None):
+ return None
+
+ if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
+ value = value.get('value')
+
+ if value is not None and not isinstance(value, target_type):
+ value = target_type(value)
+
+ return value
+
+ for field in possible_fields:
+ if not field.api:
+ continue
+
+ if field.api in set_fields:
+ continue
+
+ value = get_field_value(field, params)
+ if value is not None:
+ setattr(request, field.api, value)
+
+ if request.override_settings is None:
+ request.override_settings = {}
+
+ overriden_settings = infotext_utils.get_override_settings(params)
+ for _, setting_name, value in overriden_settings:
+ if setting_name not in request.override_settings:
+ request.override_settings[setting_name] = value
+
+ if script_runner is not None and mentioned_script_args is not None:
+ indexes = {v: i for i, v in enumerate(script_runner.inputs)}
+ script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
+
+ for field, index in script_fields:
+ value = get_field_value(field, params)
+
+ if value is None:
+ continue
+
+ mentioned_script_args[index] = value
+
+ return params
+
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
+ task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
+
script_runner = scripts.scripts_txt2img
- if not script_runner.scripts:
- script_runner.initialize_scripts(False)
- ui.create_ui()
- if not self.default_script_arg_txt2img:
- self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
+
+ infotext_script_args = {}
+ self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
+
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
@@ -356,12 +450,15 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
+ args.pop('infotext', None)
- script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
+ script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
+ add_task_to_queue(task_id)
+
with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.is_api = True
@@ -371,12 +468,14 @@ class Api:
try:
shared.state.begin(job="scripts_txt2img")
+ start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
+ finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
@@ -386,6 +485,8 @@ class Api:
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
+ task_id = img2imgreq.force_task_id or create_task_id("img2img")
+
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -395,11 +496,10 @@ class Api:
mask = decode_base64_to_image(mask)
script_runner = scripts.scripts_img2img
- if not script_runner.scripts:
- script_runner.initialize_scripts(True)
- ui.create_ui()
- if not self.default_script_arg_img2img:
- self.default_script_arg_img2img = self.init_default_script_args(script_runner)
+
+ infotext_script_args = {}
+ self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
+
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params
@@ -416,12 +516,15 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
+ args.pop('infotext', None)
- script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
+ script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
+ add_task_to_queue(task_id)
+
with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
@@ -432,12 +535,14 @@ class Api:
try:
shared.state.begin(job="scripts_img2img")
+ start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
+ finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
@@ -480,7 +585,7 @@ class Api:
if geninfo is None:
geninfo = ""
- params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
+ params = infotext_utils.parse_generation_parameters(geninfo)
script_callbacks.infotext_pasted_callback(geninfo, params)
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
@@ -511,7 +616,7 @@ class Api:
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
- return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image
@@ -643,6 +748,10 @@ class Api:
"skipped": convert_embeddings(db.skipped_embeddings),
}
+ def refresh_embeddings(self):
+ with self.queue_lock:
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
def refresh_checkpoints(self):
with self.queue_lock:
shared.refresh_checkpoints()
@@ -775,7 +884,15 @@ class Api:
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
+ uvicorn.run(
+ self.app,
+ host=server_name,
+ port=port,
+ timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
+ root_path=root_path,
+ ssl_keyfile=shared.cmd_opts.tls_keyfile,
+ ssl_certfile=shared.cmd_opts.tls_certfile
+ )
def kill_webui(self):
restart.stop_program()
diff --git a/modules/api/models.py b/modules/api/models.py
index 33894b3e..16edf11c 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
+ {"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
@@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
+ {"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
diff --git a/modules/cache.py b/modules/cache.py
index 2d37e7b9..a9822a0e 100644
--- a/modules/cache.py
+++ b/modules/cache.py
@@ -62,16 +62,15 @@ def cache(subsection):
if cache_data is None:
with cache_lock:
if cache_data is None:
- if not os.path.isfile(cache_filename):
+ try:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+ except FileNotFoundError:
+ cache_data = {}
+ except Exception:
+ os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
+ print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
cache_data = {}
- else:
- try:
- with open(cache_filename, "r", encoding="utf8") as file:
- cache_data = json.load(file)
- except Exception:
- os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
- print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
- cache_data = {}
s = cache_data.get(subsection, {})
cache_data[subsection] = s
diff --git a/modules/call_queue.py b/modules/call_queue.py
index ddf0d573..bcd7c546 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
shared.state.skipped = False
shared.state.interrupted = False
+ shared.state.stopping_generation = False
shared.state.job_count = 0
if not add_stats:
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index da93eb26..f1251b6c 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -77,7 +77,9 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
-parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
+parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
+parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -86,7 +88,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anythin
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
-parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
+parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
deleted file mode 100644
index 12db6814..00000000
--- a/modules/codeformer/codeformer_arch.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
-
-import math
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-from typing import Optional
-
-from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
-from basicsr.utils.registry import ARCH_REGISTRY
-
-def calc_mean_std(feat, eps=1e-5):
- """Calculate mean and std for adaptive_instance_normalization.
-
- Args:
- feat (Tensor): 4D tensor.
- eps (float): A small value added to the variance to avoid
- divide-by-zero. Default: 1e-5.
- """
- size = feat.size()
- assert len(size) == 4, 'The input feature should be 4D tensor.'
- b, c = size[:2]
- feat_var = feat.view(b, c, -1).var(dim=2) + eps
- feat_std = feat_var.sqrt().view(b, c, 1, 1)
- feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
- return feat_mean, feat_std
-
-
-def adaptive_instance_normalization(content_feat, style_feat):
- """Adaptive instance normalization.
-
- Adjust the reference features to have the similar color and illuminations
- as those in the degradate features.
-
- Args:
- content_feat (Tensor): The reference feature.
- style_feat (Tensor): The degradate features.
- """
- size = content_feat.size()
- style_mean, style_std = calc_mean_std(style_feat)
- content_mean, content_std = calc_mean_std(content_feat)
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
-
-
-class PositionEmbeddingSine(nn.Module):
- """
- This is a more standard version of the position embedding, very similar to the one
- used by the Attention is all you need paper, generalized to work on images.
- """
-
- def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
- super().__init__()
- self.num_pos_feats = num_pos_feats
- self.temperature = temperature
- self.normalize = normalize
- if scale is not None and normalize is False:
- raise ValueError("normalize should be True if scale is passed")
- if scale is None:
- scale = 2 * math.pi
- self.scale = scale
-
- def forward(self, x, mask=None):
- if mask is None:
- mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
- not_mask = ~mask
- y_embed = not_mask.cumsum(1, dtype=torch.float32)
- x_embed = not_mask.cumsum(2, dtype=torch.float32)
- if self.normalize:
- eps = 1e-6
- y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
- x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
-
- pos_x = x_embed[:, :, :, None] / dim_t
- pos_y = y_embed[:, :, :, None] / dim_t
- pos_x = torch.stack(
- (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos_y = torch.stack(
- (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
- return pos
-
-def _get_activation_fn(activation):
- """Return an activation function given a string"""
- if activation == "relu":
- return F.relu
- if activation == "gelu":
- return F.gelu
- if activation == "glu":
- return F.glu
- raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
-
-
-class TransformerSALayer(nn.Module):
- def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
- super().__init__()
- self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
- # Implementation of Feedforward model - MLP
- self.linear1 = nn.Linear(embed_dim, dim_mlp)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_mlp, embed_dim)
-
- self.norm1 = nn.LayerNorm(embed_dim)
- self.norm2 = nn.LayerNorm(embed_dim)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
-
- self.activation = _get_activation_fn(activation)
-
- def with_pos_embed(self, tensor, pos: Optional[Tensor]):
- return tensor if pos is None else tensor + pos
-
- def forward(self, tgt,
- tgt_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None):
-
- # self attention
- tgt2 = self.norm1(tgt)
- q = k = self.with_pos_embed(tgt2, query_pos)
- tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
- key_padding_mask=tgt_key_padding_mask)[0]
- tgt = tgt + self.dropout1(tgt2)
-
- # ffn
- tgt2 = self.norm2(tgt)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
- tgt = tgt + self.dropout2(tgt2)
- return tgt
-
-class Fuse_sft_block(nn.Module):
- def __init__(self, in_ch, out_ch):
- super().__init__()
- self.encode_enc = ResBlock(2*in_ch, out_ch)
-
- self.scale = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
-
- self.shift = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
-
- def forward(self, enc_feat, dec_feat, w=1):
- enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
- scale = self.scale(enc_feat)
- shift = self.shift(enc_feat)
- residual = w * (dec_feat * scale + shift)
- out = dec_feat + residual
- return out
-
-
-@ARCH_REGISTRY.register()
-class CodeFormer(VQAutoEncoder):
- def __init__(self, dim_embd=512, n_head=8, n_layers=9,
- codebook_size=1024, latent_size=256,
- connect_list=('32', '64', '128', '256'),
- fix_modules=('quantize', 'generator')):
- super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
-
- if fix_modules is not None:
- for module in fix_modules:
- for param in getattr(self, module).parameters():
- param.requires_grad = False
-
- self.connect_list = connect_list
- self.n_layers = n_layers
- self.dim_embd = dim_embd
- self.dim_mlp = dim_embd*2
-
- self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
- self.feat_emb = nn.Linear(256, self.dim_embd)
-
- # transformer
- self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
- for _ in range(self.n_layers)])
-
- # logits_predict head
- self.idx_pred_layer = nn.Sequential(
- nn.LayerNorm(dim_embd),
- nn.Linear(dim_embd, codebook_size, bias=False))
-
- self.channels = {
- '16': 512,
- '32': 256,
- '64': 256,
- '128': 128,
- '256': 128,
- '512': 64,
- }
-
- # after second residual block for > 16, before attn layer for ==16
- self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
- # after first residual block for > 16, before attn layer for ==16
- self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
-
- # fuse_convs_dict
- self.fuse_convs_dict = nn.ModuleDict()
- for f_size in self.connect_list:
- in_ch = self.channels[f_size]
- self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
-
- def _init_weights(self, module):
- if isinstance(module, (nn.Linear, nn.Embedding)):
- module.weight.data.normal_(mean=0.0, std=0.02)
- if isinstance(module, nn.Linear) and module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.LayerNorm):
- module.bias.data.zero_()
- module.weight.data.fill_(1.0)
-
- def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
- # ################### Encoder #####################
- enc_feat_dict = {}
- out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
- for i, block in enumerate(self.encoder.blocks):
- x = block(x)
- if i in out_list:
- enc_feat_dict[str(x.shape[-1])] = x.clone()
-
- lq_feat = x
- # ################# Transformer ###################
- # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
- pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
- # BCHW -> BC(HW) -> (HW)BC
- feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
- query_emb = feat_emb
- # Transformer encoder
- for layer in self.ft_layers:
- query_emb = layer(query_emb, query_pos=pos_emb)
-
- # output logits
- logits = self.idx_pred_layer(query_emb) # (hw)bn
- logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
-
- if code_only: # for training stage II
- # logits doesn't need softmax before cross_entropy loss
- return logits, lq_feat
-
- # ################# Quantization ###################
- # if self.training:
- # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
- # # b(hw)c -> bc(hw) -> bchw
- # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
- # ------------
- soft_one_hot = F.softmax(logits, dim=2)
- _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
- quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
- # preserve gradients
- # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
-
- if detach_16:
- quant_feat = quant_feat.detach() # for training stage III
- if adain:
- quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
-
- # ################## Generator ####################
- x = quant_feat
- fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
-
- for i, block in enumerate(self.generator.blocks):
- x = block(x)
- if i in fuse_list: # fuse after i-th block
- f_size = str(x.shape[-1])
- if w>0:
- x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
- out = x
- # logits doesn't need softmax before cross_entropy loss
- return out, logits, lq_feat
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
deleted file mode 100644
index 09ee6660..00000000
--- a/modules/codeformer/vqgan_arch.py
+++ /dev/null
@@ -1,435 +0,0 @@
-# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
-
-'''
-VQGAN code, adapted from the original created by the Unleashing Transformers authors:
-https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
-
-'''
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from basicsr.utils import get_root_logger
-from basicsr.utils.registry import ARCH_REGISTRY
-
-def normalize(in_channels):
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-@torch.jit.script
-def swish(x):
- return x*torch.sigmoid(x)
-
-
-# Define VQVAE classes
-class VectorQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, beta):
- super(VectorQuantizer, self).__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
- self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
- self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
-
- def forward(self, z):
- # reshape z -> (batch, height, width, channel) and flatten
- z = z.permute(0, 2, 3, 1).contiguous()
- z_flattened = z.view(-1, self.emb_dim)
-
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
- d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
-
- mean_distance = torch.mean(d)
- # find closest encodings
- # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
- min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
- # [0-1], higher score, higher confidence
- min_encoding_scores = torch.exp(-min_encoding_scores/10)
-
- min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
- min_encodings.scatter_(1, min_encoding_indices, 1)
-
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
- # compute loss for embedding
- loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
- # preserve gradients
- z_q = z + (z_q - z).detach()
-
- # perplexity
- e_mean = torch.mean(min_encodings, dim=0)
- perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
- # reshape back to match original input shape
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
-
- return z_q, loss, {
- "perplexity": perplexity,
- "min_encodings": min_encodings,
- "min_encoding_indices": min_encoding_indices,
- "min_encoding_scores": min_encoding_scores,
- "mean_distance": mean_distance
- }
-
- def get_codebook_feat(self, indices, shape):
- # input indices: batch*token_num -> (batch*token_num)*1
- # shape: batch, height, width, channel
- indices = indices.view(-1,1)
- min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
- min_encodings.scatter_(1, indices, 1)
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
-
- if shape is not None: # reshape back to match original input shape
- z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
-
- return z_q
-
-
-class GumbelQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
- super().__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.straight_through = straight_through
- self.temperature = temp_init
- self.kl_weight = kl_weight
- self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
- self.embed = nn.Embedding(codebook_size, emb_dim)
-
- def forward(self, z):
- hard = self.straight_through if self.training else True
-
- logits = self.proj(z)
-
- soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
-
- z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
-
- # + kl divergence to the prior loss
- qy = F.softmax(logits, dim=1)
- diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
- min_encoding_indices = soft_one_hot.argmax(dim=1)
-
- return z_q, diff, {
- "min_encoding_indices": min_encoding_indices
- }
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
-
- def forward(self, x):
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- return x
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
- def forward(self, x):
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
- x = self.conv(x)
-
- return x
-
-
-class ResBlock(nn.Module):
- def __init__(self, in_channels, out_channels=None):
- super(ResBlock, self).__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels if out_channels is None else out_channels
- self.norm1 = normalize(in_channels)
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
- self.norm2 = normalize(out_channels)
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
- if self.in_channels != self.out_channels:
- self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
- def forward(self, x_in):
- x = x_in
- x = self.norm1(x)
- x = swish(x)
- x = self.conv1(x)
- x = self.norm2(x)
- x = swish(x)
- x = self.conv2(x)
- if self.in_channels != self.out_channels:
- x_in = self.conv_out(x_in)
-
- return x + x_in
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, c, h*w)
- q = q.permute(0, 2, 1)
- k = k.reshape(b, c, h*w)
- w_ = torch.bmm(q, k)
- w_ = w_ * (int(c)**(-0.5))
- w_ = F.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h*w)
- w_ = w_.permute(0, 2, 1)
- h_ = torch.bmm(v, w_)
- h_ = h_.reshape(b, c, h, w)
-
- h_ = self.proj_out(h_)
-
- return x+h_
-
-
-class Encoder(nn.Module):
- def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
- super().__init__()
- self.nf = nf
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.attn_resolutions = attn_resolutions
-
- curr_res = self.resolution
- in_ch_mult = (1,)+tuple(ch_mult)
-
- blocks = []
- # initial convultion
- blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
-
- # residual and downsampling blocks, with attention on smaller res (16x16)
- for i in range(self.num_resolutions):
- block_in_ch = nf * in_ch_mult[i]
- block_out_ch = nf * ch_mult[i]
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
- if curr_res in attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != self.num_resolutions - 1:
- blocks.append(Downsample(block_in_ch))
- curr_res = curr_res // 2
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- # normalise and convert to latent size
- blocks.append(normalize(block_in_ch))
- blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
- self.blocks = nn.ModuleList(blocks)
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-class Generator(nn.Module):
- def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
- super().__init__()
- self.nf = nf
- self.ch_mult = ch_mult
- self.num_resolutions = len(self.ch_mult)
- self.num_res_blocks = res_blocks
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions
- self.in_channels = emb_dim
- self.out_channels = 3
- block_in_ch = self.nf * self.ch_mult[-1]
- curr_res = self.resolution // 2 ** (self.num_resolutions-1)
-
- blocks = []
- # initial conv
- blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- for i in reversed(range(self.num_resolutions)):
- block_out_ch = self.nf * self.ch_mult[i]
-
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
-
- if curr_res in self.attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != 0:
- blocks.append(Upsample(block_in_ch))
- curr_res = curr_res * 2
-
- blocks.append(normalize(block_in_ch))
- blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
-
- self.blocks = nn.ModuleList(blocks)
-
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-@ARCH_REGISTRY.register()
-class VQAutoEncoder(nn.Module):
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
- beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
- super().__init__()
- logger = get_root_logger()
- self.in_channels = 3
- self.nf = nf
- self.n_blocks = res_blocks
- self.codebook_size = codebook_size
- self.embed_dim = emb_dim
- self.ch_mult = ch_mult
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions or [16]
- self.quantizer_type = quantizer
- self.encoder = Encoder(
- self.in_channels,
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions
- )
- if self.quantizer_type == "nearest":
- self.beta = beta #0.25
- self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
- elif self.quantizer_type == "gumbel":
- self.gumbel_num_hiddens = emb_dim
- self.straight_through = gumbel_straight_through
- self.kl_weight = gumbel_kl_weight
- self.quantize = GumbelQuantizer(
- self.codebook_size,
- self.embed_dim,
- self.gumbel_num_hiddens,
- self.straight_through,
- self.kl_weight
- )
- self.generator = Generator(
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions
- )
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location='cpu')
- if 'params_ema' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
- logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
- elif 'params' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
- logger.info(f'vqgan is loaded from: {model_path} [params]')
- else:
- raise ValueError('Wrong params!')
-
-
- def forward(self, x):
- x = self.encoder(x)
- quant, codebook_loss, quant_stats = self.quantize(x)
- x = self.generator(quant)
- return x, codebook_loss, quant_stats
-
-
-
-# patch based discriminator
-@ARCH_REGISTRY.register()
-class VQGANDiscriminator(nn.Module):
- def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
- super().__init__()
-
- layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
- ndf_mult = 1
- ndf_mult_prev = 1
- for n in range(1, n_layers): # gradually increase the number of filters
- ndf_mult_prev = ndf_mult
- ndf_mult = min(2 ** n, 8)
- layers += [
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
- nn.BatchNorm2d(ndf * ndf_mult),
- nn.LeakyReLU(0.2, True)
- ]
-
- ndf_mult_prev = ndf_mult
- ndf_mult = min(2 ** n_layers, 8)
-
- layers += [
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
- nn.BatchNorm2d(ndf * ndf_mult),
- nn.LeakyReLU(0.2, True)
- ]
-
- layers += [
- nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
- self.main = nn.Sequential(*layers)
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location='cpu')
- if 'params_d' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
- elif 'params' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
- else:
- raise ValueError('Wrong params!')
-
- def forward(self, x):
- return self.main(x)
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index da42b5e9..44b84618 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,132 +1,64 @@
-import os
+from __future__ import annotations
-import cv2
-import torch
-
-import modules.face_restoration
-import modules.shared
-from modules import shared, devices, modelloader, errors
-from modules.paths import models_path
-
-# codeformer people made a choice to include modified basicsr library to their project which makes
-# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
-# I am making a choice to include some files from codeformer to work around this issue.
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
-model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-
-codeformer = None
-
-
-def setup_model(dirname):
- os.makedirs(model_path, exist_ok=True)
-
- path = modules.paths.paths.get("CodeFormer", None)
- if path is None:
- return
-
- try:
- from torchvision.transforms.functional import normalize
- from modules.codeformer.codeformer_arch import CodeFormer
- from basicsr.utils import img2tensor, tensor2img
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
- from facelib.detection.retinaface import retinaface
-
- net_class = CodeFormer
-
- class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
- def name(self):
- return "CodeFormer"
-
- def __init__(self, dirname):
- self.net = None
- self.face_helper = None
- self.cmd_dir = dirname
+import logging
- def create_models(self):
-
- if self.net is not None and self.face_helper is not None:
- self.net.to(devices.device_codeformer)
- return self.net, self.face_helper
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
- if len(model_paths) != 0:
- ckpt_path = model_paths[0]
- else:
- print("Unable to load codeformer model.")
- return None, None
- net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
- checkpoint = torch.load(ckpt_path)['params_ema']
- net.load_state_dict(checkpoint)
- net.eval()
-
- if hasattr(retinaface, 'device'):
- retinaface.device = devices.device_codeformer
- face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
-
- self.net = net
- self.face_helper = face_helper
-
- return net, face_helper
-
- def send_model_to(self, device):
- self.net.to(device)
- self.face_helper.face_det.to(device)
- self.face_helper.face_parse.to(device)
-
- def restore(self, np_image, w=None):
- np_image = np_image[:, :, ::-1]
-
- original_resolution = np_image.shape[0:2]
+import torch
- self.create_models()
- if self.net is None or self.face_helper is None:
- return np_image
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
- self.send_model_to(devices.device_codeformer)
+logger = logging.getLogger(__name__)
- self.face_helper.clean_all()
- self.face_helper.read_image(np_image)
- self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- self.face_helper.align_warp_face()
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_download_name = 'codeformer-v0.1.0.pth'
- for cropped_face in self.face_helper.cropped_faces:
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+# used by e.g. postprocessing_codeformer.py
+codeformer: face_restoration.FaceRestoration | None = None
- try:
- with torch.no_grad():
- output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- devices.torch_gc()
- except Exception:
- errors.report('Failed inference for CodeFormer', exc_info=True)
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
- restored_face = restored_face.astype('uint8')
- self.face_helper.add_restored_face(restored_face)
+class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
+ def name(self):
+ return "CodeFormer"
- self.face_helper.get_inverse_affine(None)
+ def load_net(self) -> torch.Module:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
+ ext_filter=['.pth'],
+ ):
+ return modelloader.load_spandrel_model(
+ model_path,
+ device=devices.device_codeformer,
+ expected_architecture='CodeFormer',
+ ).model
+ raise ValueError("No codeformer model found")
- restored_img = self.face_helper.paste_faces_to_input_image()
- restored_img = restored_img[:, :, ::-1]
+ def get_device(self):
+ return devices.device_codeformer
- if original_resolution != restored_img.shape[0:2]:
- restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+ def restore(self, np_image, w: float | None = None):
+ if w is None:
+ w = getattr(shared.opts, "code_former_weight", 0.5)
- self.face_helper.clean_all()
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, w=w, adain=True)[0]
- if shared.opts.face_restoration_unload:
- self.send_model_to(devices.cpu)
+ return self.restore_with_helper(np_image, restore_face)
- return restored_img
- global codeformer
+def setup_model(dirname: str) -> None:
+ global codeformer
+ try:
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
-
except Exception:
errors.report("Error setting up CodeFormer", exc_info=True)
-
- # sys.path = stored_sys_path
diff --git a/modules/dat_model.py b/modules/dat_model.py
new file mode 100644
index 00000000..495d5f49
--- /dev/null
+++ b/modules/dat_model.py
@@ -0,0 +1,79 @@
+import os
+
+from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
+
+
+class UpscalerDAT(Upscaler):
+ def __init__(self, user_path):
+ self.name = "DAT"
+ self.user_path = user_path
+ self.scalers = []
+ super().__init__()
+
+ for file in self.find_models(ext_filter=[".pt", ".pth"]):
+ name = modelloader.friendly_name(file)
+ scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
+ self.scalers.append(scaler_data)
+
+ for model in get_dat_models(self):
+ if model.name in opts.dat_enabled_models:
+ self.scalers.append(model)
+
+ def do_upscale(self, img, path):
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load DAT model {path}", exc_info=True)
+ return img
+
+ model_descriptor = modelloader.load_spandrel_model(
+ info.local_data_path,
+ device=self.device,
+ prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ expected_architecture="DAT",
+ )
+ return upscale_with_model(
+ model_descriptor,
+ img,
+ tile_size=opts.DAT_tile,
+ tile_overlap=opts.DAT_tile_overlap,
+ )
+
+ def load_model(self, path):
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
+
+
+def get_dat_models(scaler):
+ return [
+ UpscalerData(
+ name="DAT x2",
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
+ scale=2,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="DAT x3",
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
+ scale=3,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="DAT x4",
+ path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ ]
diff --git a/modules/devices.py b/modules/devices.py
index f1e56501..c737162a 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -23,6 +23,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
+def cuda_no_autocast(device_id=None) -> bool:
+ if device_id is None:
+ device_id = get_cuda_device_id()
+ return (
+ torch.cuda.get_device_capability(device_id) == (7, 5)
+ and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+ )
+
+
+def get_cuda_device_id():
+ return (
+ int(shared.cmd_opts.device_id)
+ if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+ else 0
+ ) or torch.cuda.current_device()
+
+
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -79,8 +96,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
- if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+ if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -90,6 +106,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
+fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -98,6 +115,7 @@ device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
+dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -110,15 +128,89 @@ def cond_cast_float(input):
nv_rng = None
+patch_module_list = [
+ torch.nn.Linear,
+ torch.nn.Conv2d,
+ torch.nn.MultiheadAttention,
+ torch.nn.GroupNorm,
+ torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(target_dtype):
+ def forward_wrapper(self, *args, **kwargs):
+ if any(
+ isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
+ for arg in args
+ ):
+ args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+
+ org_dtype = target_dtype
+ for param in self.parameters():
+ if param.dtype != target_dtype:
+ org_dtype = param.dtype
+ break
+
+ if org_dtype != target_dtype:
+ self.to(target_dtype)
+ result = self.org_forward(*args, **kwargs)
+ if org_dtype != target_dtype:
+ self.to(org_dtype)
+
+ if target_dtype != dtype_inference:
+ if isinstance(result, tuple):
+ result = tuple(
+ i.to(dtype_inference)
+ if isinstance(i, torch.Tensor)
+ else i
+ for i in result
+ )
+ elif isinstance(result, torch.Tensor):
+ result = result.to(dtype_inference)
+ return result
+ return forward_wrapper
+
+
+@contextlib.contextmanager
+def manual_cast(target_dtype):
+ applied = False
+ for module_type in patch_module_list:
+ if hasattr(module_type, "org_forward"):
+ continue
+ applied = True
+ org_forward = module_type.forward
+ if module_type == torch.nn.MultiheadAttention:
+ module_type.forward = manual_cast_forward(torch.float32)
+ else:
+ module_type.forward = manual_cast_forward(target_dtype)
+ module_type.org_forward = org_forward
+ try:
+ yield None
+ finally:
+ if applied:
+ for module_type in patch_module_list:
+ if hasattr(module_type, "org_forward"):
+ module_type.forward = module_type.org_forward
+ delattr(module_type, "org_forward")
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
- if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+ if fp8 and device==cpu:
+ return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+ if fp8 and dtype_inference == torch.float32:
+ return manual_cast(dtype)
+
+ if dtype == torch.float32 or dtype_inference == torch.float32:
return contextlib.nullcontext()
+ if has_xpu() or has_mps() or cuda_no_autocast():
+ return manual_cast(dtype)
+
return torch.autocast("cuda")
diff --git a/modules/errors.py b/modules/errors.py
index eb234a83..48aa13a1 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -107,8 +107,8 @@ def check_versions():
import torch
import gradio
- expected_torch_version = "2.0.0"
- expected_xformers_version = "0.0.20"
+ expected_torch_version = "2.1.2"
+ expected_xformers_version = "0.0.23.post1"
expected_gradio_version = "3.41.2"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 02a1727d..70041ab0 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,121 +1,7 @@
-import sys
-
-import numpy as np
-import torch
-from PIL import Image
-
-import modules.esrgan_model_arch as arch
-from modules import modelloader, images, devices
+from modules import modelloader, devices, errors
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
-
-
-def mod2normal(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
- if 'conv_first.weight' in state_dict:
- crt_net = {}
- items = list(state_dict)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
- crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
- crt_net['model.3.weight'] = state_dict['upconv1.weight']
- crt_net['model.3.bias'] = state_dict['upconv1.bias']
- crt_net['model.6.weight'] = state_dict['upconv2.weight']
- crt_net['model.6.bias'] = state_dict['upconv2.bias']
- crt_net['model.8.weight'] = state_dict['HRconv.weight']
- crt_net['model.8.bias'] = state_dict['HRconv.bias']
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
- state_dict = crt_net
- return state_dict
-
-
-def resrgan2normal(state_dict, nb=23):
- # this code is copied from https://github.com/victorca25/iNNfer
- if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
- re8x = 0
- crt_net = {}
- items = list(state_dict)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if "rdb" in k:
- ori_k = k.replace('body.', 'model.1.sub.')
- ori_k = ori_k.replace('.rdb', '.RDB')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
- crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
- crt_net['model.3.weight'] = state_dict['conv_up1.weight']
- crt_net['model.3.bias'] = state_dict['conv_up1.bias']
- crt_net['model.6.weight'] = state_dict['conv_up2.weight']
- crt_net['model.6.bias'] = state_dict['conv_up2.bias']
-
- if 'conv_up3.weight' in state_dict:
- # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
- re8x = 3
- crt_net['model.9.weight'] = state_dict['conv_up3.weight']
- crt_net['model.9.bias'] = state_dict['conv_up3.bias']
-
- crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
- crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
- crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
- crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
-
- state_dict = crt_net
- return state_dict
-
-
-def infer_params(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
- scale2x = 0
- scalemin = 6
- n_uplayer = 0
- plus = False
-
- for block in list(state_dict):
- parts = block.split(".")
- n_parts = len(parts)
- if n_parts == 5 and parts[2] == "sub":
- nb = int(parts[3])
- elif n_parts == 3:
- part_num = int(parts[1])
- if (part_num > scalemin
- and parts[0] == "model"
- and parts[2] == "weight"):
- scale2x += 1
- if part_num > n_uplayer:
- n_uplayer = part_num
- out_nc = state_dict[block].shape[0]
- if not plus and "conv1x1" in block:
- plus = True
-
- nf = state_dict["model.0.weight"].shape[0]
- in_nc = state_dict["model.0.weight"].shape[1]
- out_nc = out_nc
- scale = 2 ** scale2x
-
- return in_nc, out_nc, nf, nb, plus, scale
+from modules.upscaler_utils import upscale_with_model
class UpscalerESRGAN(Upscaler):
@@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
def do_upscale(self, img, selected_model):
try:
model = self.load_model(selected_model)
- except Exception as e:
- print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
+ except Exception:
+ errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
return img
model.to(devices.device_esrgan)
- img = esrgan_upscale(model, img)
- return img
+ return esrgan_upscale(model, img)
def load_model(self, path: str):
if path.startswith("http"):
@@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler):
else:
filename = path
- state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
-
- if "params_ema" in state_dict:
- state_dict = state_dict["params_ema"]
- elif "params" in state_dict:
- state_dict = state_dict["params"]
- num_conv = 16 if "realesr-animevideov3" in filename else 32
- model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
- model.load_state_dict(state_dict)
- model.eval()
- return model
-
- if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
- nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
- state_dict = resrgan2normal(state_dict, nb)
- elif "conv_first.weight" in state_dict:
- state_dict = mod2normal(state_dict)
- elif "model.0.weight" not in state_dict:
- raise Exception("The file is not a recognized ESRGAN model.")
-
- in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
-
- model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
- model.load_state_dict(state_dict)
- model.eval()
-
- return model
-
-
-def upscale_without_tiling(model, img):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- return Image.fromarray(output, 'RGB')
+ return modelloader.load_spandrel_model(
+ filename,
+ device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+ expected_architecture='ESRGAN',
+ )
def esrgan_upscale(model, img):
- if opts.ESRGAN_tile == 0:
- return upscale_without_tiling(model, img)
-
- grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
- newtiles = []
- scale_factor = 1
-
- for y, h, row in grid.tiles:
- newrow = []
- for tiledata in row:
- x, w, tile = tiledata
-
- output = upscale_without_tiling(model, tile)
- scale_factor = output.width // tile.width
-
- newrow.append([x * scale_factor, w * scale_factor, output])
- newtiles.append([y * scale_factor, h * scale_factor, newrow])
-
- newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = images.combine_grid(newgrid)
- return output
+ return upscale_with_model(
+ model,
+ img,
+ tile_size=opts.ESRGAN_tile,
+ tile_overlap=opts.ESRGAN_tile_overlap,
+ )
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
deleted file mode 100644
index 2b9888ba..00000000
--- a/modules/esrgan_model_arch.py
+++ /dev/null
@@ -1,465 +0,0 @@
-# this file is adapted from https://github.com/victorca25/iNNfer
-
-from collections import OrderedDict
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-####################
-# RRDBNet Generator
-####################
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
- act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
- finalact=None, gaussian_noise=False, plus=False):
- super(RRDBNet, self).__init__()
- n_upscale = int(math.log(upscale, 2))
- if upscale == 3:
- n_upscale = 1
-
- self.resrgan_scale = 0
- if in_nc % 16 == 0:
- self.resrgan_scale = 1
- elif in_nc != 4 and in_nc % 4 == 0:
- self.resrgan_scale = 2
-
- fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
- rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
- LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
-
- if upsample_mode == 'upconv':
- upsample_block = upconv_block
- elif upsample_mode == 'pixelshuffle':
- upsample_block = pixelshuffle_block
- else:
- raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
- if upscale == 3:
- upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
- else:
- upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
- HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
- HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
-
- outact = act(finalact) if finalact else None
-
- self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
- *upsampler, HR_conv0, HR_conv1, outact)
-
- def forward(self, x, outm=None):
- if self.resrgan_scale == 1:
- feat = pixel_unshuffle(x, scale=4)
- elif self.resrgan_scale == 2:
- feat = pixel_unshuffle(x, scale=2)
- else:
- feat = x
-
- return self.model(feat)
-
-
-class RRDB(nn.Module):
- """
- Residual in Residual Dense Block
- (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
- """
-
- def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(RRDB, self).__init__()
- # This is for backwards compatibility with existing models
- if nr == 3:
- self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- else:
- RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
- self.RDBs = nn.Sequential(*RDB_list)
-
- def forward(self, x):
- if hasattr(self, 'RDB1'):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- else:
- out = self.RDBs(x)
- return out * 0.2 + x
-
-
-class ResidualDenseBlock_5C(nn.Module):
- """
- Residual Dense Block
- The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
- Modified options that can be used:
- - "Partial Convolution based Padding" arXiv:1811.11718
- - "Spectral normalization" arXiv:1802.05957
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
- {Rakotonirina} and A. {Rasoanaivo}
- """
-
- def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(ResidualDenseBlock_5C, self).__init__()
-
- self.noise = GaussianNoise() if gaussian_noise else None
- self.conv1x1 = conv1x1(nf, gc) if plus else None
-
- self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- if mode == 'CNA':
- last_act = None
- else:
- last_act = act_type
- self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
-
- def forward(self, x):
- x1 = self.conv1(x)
- x2 = self.conv2(torch.cat((x, x1), 1))
- if self.conv1x1:
- x2 = x2 + self.conv1x1(x)
- x3 = self.conv3(torch.cat((x, x1, x2), 1))
- x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
- if self.conv1x1:
- x4 = x4 + x2
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- if self.noise:
- return self.noise(x5.mul(0.2) + x)
- else:
- return x5 * 0.2 + x
-
-
-####################
-# ESRGANplus
-####################
-
-class GaussianNoise(nn.Module):
- def __init__(self, sigma=0.1, is_relative_detach=False):
- super().__init__()
- self.sigma = sigma
- self.is_relative_detach = is_relative_detach
- self.noise = torch.tensor(0, dtype=torch.float)
-
- def forward(self, x):
- if self.training and self.sigma != 0:
- self.noise = self.noise.to(x.device)
- scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
- sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
- x = x + sampled_noise
- return x
-
-def conv1x1(in_planes, out_planes, stride=1):
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-####################
-# SRVGGNetCompact
-####################
-
-class SRVGGNetCompact(nn.Module):
- """A compact VGG-style network structure for super-resolution.
- This class is copied from https://github.com/xinntao/Real-ESRGAN
- """
-
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
- super(SRVGGNetCompact, self).__init__()
- self.num_in_ch = num_in_ch
- self.num_out_ch = num_out_ch
- self.num_feat = num_feat
- self.num_conv = num_conv
- self.upscale = upscale
- self.act_type = act_type
-
- self.body = nn.ModuleList()
- # the first conv
- self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
- # the first activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the body structure
- for _ in range(num_conv):
- self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
- # activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the last conv
- self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
- # upsample
- self.upsampler = nn.PixelShuffle(upscale)
-
- def forward(self, x):
- out = x
- for i in range(0, len(self.body)):
- out = self.body[i](out)
-
- out = self.upsampler(out)
- # add the nearest upsampled image, so that the network learns the residual
- base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
- out += base
- return out
-
-
-####################
-# Upsampler
-####################
-
-class Upsample(nn.Module):
- r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
- The input data is assumed to be of the form
- `minibatch x channels x [optional depth] x [optional height] x width`.
- """
-
- def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
- super(Upsample, self).__init__()
- if isinstance(scale_factor, tuple):
- self.scale_factor = tuple(float(factor) for factor in scale_factor)
- else:
- self.scale_factor = float(scale_factor) if scale_factor else None
- self.mode = mode
- self.size = size
- self.align_corners = align_corners
-
- def forward(self, x):
- return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
-
- def extra_repr(self):
- if self.scale_factor is not None:
- info = f'scale_factor={self.scale_factor}'
- else:
- info = f'size={self.size}'
- info += f', mode={self.mode}'
- return info
-
-
-def pixel_unshuffle(x, scale):
- """ Pixel unshuffle.
- Args:
- x (Tensor): Input feature with shape (b, c, hh, hw).
- scale (int): Downsample ratio.
- Returns:
- Tensor: the pixel unshuffled feature.
- """
- b, c, hh, hw = x.size()
- out_channel = c * (scale**2)
- assert hh % scale == 0 and hw % scale == 0
- h = hh // scale
- w = hw // scale
- x_view = x.view(b, c, h, scale, w, scale)
- return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
-
-
-def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
- """
- Pixel shuffle layer
- (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
- Neural Network, CVPR17)
- """
- conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
- pixel_shuffle = nn.PixelShuffle(upscale_factor)
-
- n = norm(norm_type, out_nc) if norm_type else None
- a = act(act_type) if act_type else None
- return sequential(conv, pixel_shuffle, n, a)
-
-
-def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
- """ Upconv layer """
- upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
- upsample = Upsample(scale_factor=upscale_factor, mode=mode)
- conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
- return sequential(upsample, conv)
-
-
-
-
-
-
-
-
-####################
-# Basic blocks
-####################
-
-
-def make_layer(basic_block, num_basic_block, **kwarg):
- """Make layers by stacking the same blocks.
- Args:
- basic_block (nn.module): nn.module class for basic block. (block)
- num_basic_block (int): number of blocks. (n_layers)
- Returns:
- nn.Sequential: Stacked blocks in nn.Sequential.
- """
- layers = []
- for _ in range(num_basic_block):
- layers.append(basic_block(**kwarg))
- return nn.Sequential(*layers)
-
-
-def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
- """ activation helper """
- act_type = act_type.lower()
- if act_type == 'relu':
- layer = nn.ReLU(inplace)
- elif act_type in ('leakyrelu', 'lrelu'):
- layer = nn.LeakyReLU(neg_slope, inplace)
- elif act_type == 'prelu':
- layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
- elif act_type == 'tanh': # [-1, 1] range output
- layer = nn.Tanh()
- elif act_type == 'sigmoid': # [0, 1] range output
- layer = nn.Sigmoid()
- else:
- raise NotImplementedError(f'activation layer [{act_type}] is not found')
- return layer
-
-
-class Identity(nn.Module):
- def __init__(self, *kwargs):
- super(Identity, self).__init__()
-
- def forward(self, x, *kwargs):
- return x
-
-
-def norm(norm_type, nc):
- """ Return a normalization layer """
- norm_type = norm_type.lower()
- if norm_type == 'batch':
- layer = nn.BatchNorm2d(nc, affine=True)
- elif norm_type == 'instance':
- layer = nn.InstanceNorm2d(nc, affine=False)
- elif norm_type == 'none':
- def norm_layer(x): return Identity()
- else:
- raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
- return layer
-
-
-def pad(pad_type, padding):
- """ padding layer helper """
- pad_type = pad_type.lower()
- if padding == 0:
- return None
- if pad_type == 'reflect':
- layer = nn.ReflectionPad2d(padding)
- elif pad_type == 'replicate':
- layer = nn.ReplicationPad2d(padding)
- elif pad_type == 'zero':
- layer = nn.ZeroPad2d(padding)
- else:
- raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
- return layer
-
-
-def get_valid_padding(kernel_size, dilation):
- kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- padding = (kernel_size - 1) // 2
- return padding
-
-
-class ShortcutBlock(nn.Module):
- """ Elementwise sum the output of a submodule to its input """
- def __init__(self, submodule):
- super(ShortcutBlock, self).__init__()
- self.sub = submodule
-
- def forward(self, x):
- output = x + self.sub(x)
- return output
-
- def __repr__(self):
- return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
-
-
-def sequential(*args):
- """ Flatten Sequential. It unwraps nn.Sequential. """
- if len(args) == 1:
- if isinstance(args[0], OrderedDict):
- raise NotImplementedError('sequential does not support OrderedDict input.')
- return args[0] # No sequential is needed.
- modules = []
- for module in args:
- if isinstance(module, nn.Sequential):
- for submodule in module.children():
- modules.append(submodule)
- elif isinstance(module, nn.Module):
- modules.append(module)
- return nn.Sequential(*modules)
-
-
-def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
- spectral_norm=False):
- """ Conv layer with padding, normalization, activation """
- assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
- padding = get_valid_padding(kernel_size, dilation)
- p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
- padding = padding if pad_type == 'zero' else 0
-
- if convtype=='PartialConv2D':
- from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
- c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='DeformConv2D':
- from torchvision.ops import DeformConv2d # not tested
- c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='Conv3D':
- c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- else:
- c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
-
- if spectral_norm:
- c = nn.utils.spectral_norm(c)
-
- a = act(act_type) if act_type else None
- if 'CNA' in mode:
- n = norm(norm_type, out_nc) if norm_type else None
- return sequential(p, c, n, a)
- elif mode == 'NAC':
- if norm_type is None and act_type is not None:
- a = act(act_type, inplace=False)
- n = norm(norm_type, in_nc) if norm_type else None
- return sequential(n, a, p, c)
diff --git a/modules/extensions.py b/modules/extensions.py
index 1899cd52..04bda297 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -32,11 +32,12 @@ class ExtensionMetadata:
self.config = configparser.ConfigParser()
filepath = os.path.join(path, self.filename)
- if os.path.isfile(filepath):
- try:
- self.config.read(filepath)
- except Exception:
- errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
+ # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
+ # so no need to check whether the file exists beforehand.
+ try:
+ self.config.read(filepath)
+ except Exception:
+ errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
self.canonical_name = canonical_name.lower().strip()
@@ -223,13 +224,16 @@ def list_extensions():
# check for requirements
for extension in extensions:
+ if not extension.enabled:
+ continue
+
for req in extension.metadata.requires:
required_extension = loaded_extensions.get(req)
if required_extension is None:
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
continue
- if not extension.enabled:
+ if not required_extension.enabled:
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
continue
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index b9533677..04249dff 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -206,7 +206,7 @@ def parse_prompts(prompts):
return res, extra_data
-def get_user_metadata(filename):
+def get_user_metadata(filename, lister=None):
if filename is None:
return {}
@@ -215,7 +215,8 @@ def get_user_metadata(filename):
metadata = {}
try:
- if os.path.isfile(metadata_filename):
+ exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
+ if exists:
with open(metadata_filename, "r", encoding="utf8") as file:
metadata = json.load(file)
except Exception as e:
diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py
new file mode 100644
index 00000000..1cbac236
--- /dev/null
+++ b/modules/face_restoration_utils.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+import logging
+import os
+from functools import cached_property
+from typing import TYPE_CHECKING, Callable
+
+import cv2
+import numpy as np
+import torch
+
+from modules import devices, errors, face_restoration, shared
+
+if TYPE_CHECKING:
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
+logger = logging.getLogger(__name__)
+
+
+def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
+ """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
+ assert img.shape[2] == 3, "image must be RGB"
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return torch.from_numpy(img.transpose(2, 0, 1)).float()
+
+
+def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
+ """
+ Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
+ """
+ tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
+ assert tensor.dim() == 3, "tensor must be RGB"
+ img_np = tensor.numpy().transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image, no RGB/BGR required
+ return np.squeeze(img_np, axis=2)
+ return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+
+
+def create_face_helper(device) -> FaceRestoreHelper:
+ from facexlib.detection import retinaface
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ if hasattr(retinaface, 'device'):
+ retinaface.device = device
+ return FaceRestoreHelper(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ use_parse=True,
+ device=device,
+ )
+
+
+def restore_with_face_helper(
+ np_image: np.ndarray,
+ face_helper: FaceRestoreHelper,
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
+) -> np.ndarray:
+ """
+ Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
+
+ `restore_face` should take a cropped face image and return a restored face image.
+ """
+ from torchvision.transforms.functional import normalize
+ np_image = np_image[:, :, ::-1]
+ original_resolution = np_image.shape[0:2]
+
+ try:
+ logger.debug("Detecting faces...")
+ face_helper.clean_all()
+ face_helper.read_image(np_image)
+ face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ face_helper.align_warp_face()
+ logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
+ for cropped_face in face_helper.cropped_faces:
+ cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+
+ try:
+ with torch.no_grad():
+ cropped_face_t = restore_face(cropped_face_t)
+ devices.torch_gc()
+ except Exception:
+ errors.report('Failed face-restoration inference', exc_info=True)
+
+ restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
+ restored_face = (restored_face * 255.0).astype('uint8')
+ face_helper.add_restored_face(restored_face)
+
+ logger.debug("Merging restored faces into image")
+ face_helper.get_inverse_affine(None)
+ img = face_helper.paste_faces_to_input_image()
+ img = img[:, :, ::-1]
+ if original_resolution != img.shape[0:2]:
+ img = cv2.resize(
+ img,
+ (0, 0),
+ fx=original_resolution[1] / img.shape[1],
+ fy=original_resolution[0] / img.shape[0],
+ interpolation=cv2.INTER_LINEAR,
+ )
+ logger.debug("Face restoration complete")
+ finally:
+ face_helper.clean_all()
+ return img
+
+
+class CommonFaceRestoration(face_restoration.FaceRestoration):
+ net: torch.Module | None
+ model_url: str
+ model_download_name: str
+
+ def __init__(self, model_path: str):
+ super().__init__()
+ self.net = None
+ self.model_path = model_path
+ os.makedirs(model_path, exist_ok=True)
+
+ @cached_property
+ def face_helper(self) -> FaceRestoreHelper:
+ return create_face_helper(self.get_device())
+
+ def send_model_to(self, device):
+ if self.net:
+ logger.debug("Sending %s to %s", self.net, device)
+ self.net.to(device)
+ if self.face_helper:
+ logger.debug("Sending face helper to %s", device)
+ self.face_helper.face_det.to(device)
+ self.face_helper.face_parse.to(device)
+
+ def get_device(self):
+ raise NotImplementedError("get_device must be implemented by subclasses")
+
+ def load_net(self) -> torch.Module:
+ raise NotImplementedError("load_net must be implemented by subclasses")
+
+ def restore_with_helper(
+ self,
+ np_image: np.ndarray,
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
+ ) -> np.ndarray:
+ try:
+ if self.net is None:
+ self.net = self.load_net()
+ except Exception:
+ logger.warning("Unable to load face-restoration model", exc_info=True)
+ return np_image
+
+ try:
+ self.send_model_to(self.get_device())
+ return restore_with_face_helper(np_image, self.face_helper, restore_face)
+ finally:
+ if shared.opts.face_restoration_unload:
+ self.send_model_to(devices.cpu)
+
+
+def patch_facexlib(dirname: str) -> None:
+ import facexlib.detection
+ import facexlib.parsing
+
+ det_facex_load_file_from_url = facexlib.detection.load_file_from_url
+ par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
+
+ def update_kwargs(kwargs):
+ return dict(kwargs, save_dir=dirname, model_dir=None)
+
+ def facex_load_file_from_url(**kwargs):
+ return det_facex_load_file_from_url(**update_kwargs(kwargs))
+
+ def facex_load_file_from_url2(**kwargs):
+ return par_facex_load_file_from_url(**update_kwargs(kwargs))
+
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 01d668ec..445b0409 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,125 +1,71 @@
+from __future__ import annotations
+
+import logging
import os
-import facexlib
-import gfpgan
+import torch
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_file_path = None
+logger = logging.getLogger(__name__)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
- global loaded_gfpgan_model
- global model_path
- global model_file_path
- if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
- return loaded_gfpgan_model
-
- if gfpgan_constructor is None:
- return None
-
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
-
- if len(models) == 1 and models[0].startswith("http"):
- model_file = models[0]
- elif len(models) != 0:
- gfp_models = []
- for item in models:
- if 'GFPGAN' in os.path.basename(item):
- gfp_models.append(item)
- latest_file = max(gfp_models, key=os.path.getctime)
- model_file = latest_file
- else:
- print("Unable to load gfpgan model!")
- return None
-
- if hasattr(facexlib.detection.retinaface, 'device'):
- facexlib.detection.retinaface.device = devices.device_gfpgan
- model_file_path = model_file
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
- loaded_gfpgan_model = model
-
- return model
-
-
-def send_model_to(model, device):
- model.gfpgan.to(device)
- model.face_helper.face_det.to(device)
- model.face_helper.face_parse.to(device)
+model_download_name = "GFPGANv1.4.pth"
+gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
+
+
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+ def name(self):
+ return "GFPGAN"
+
+ def get_device(self):
+ return devices.device_gfpgan
+
+ def load_net(self) -> torch.Module:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
+ ext_filter=['.pth'],
+ ):
+ if 'GFPGAN' in os.path.basename(model_path):
+ model = modelloader.load_spandrel_model(
+ model_path,
+ device=self.get_device(),
+ expected_architecture='GFPGAN',
+ ).model
+ model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
+ return model
+ raise ValueError("No GFPGAN model found")
+
+ def restore(self, np_image):
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, return_rgb=False)[0]
+
+ return self.restore_with_helper(np_image, restore_face)
def gfpgan_fix_faces(np_image):
- model = gfpgann()
- if model is None:
- return np_image
-
- send_model_to(model, devices.device_gfpgan)
-
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- model.face_helper.clean_all()
-
- if shared.opts.face_restoration_unload:
- send_model_to(model, devices.cpu)
-
+ if gfpgan_face_restorer:
+ return gfpgan_face_restorer.restore(np_image)
+ logger.warning("GFPGAN face restorer not set up")
return np_image
-gfpgan_constructor = None
+def setup_model(dirname: str) -> None:
+ global gfpgan_face_restorer
-
-def setup_model(dirname):
try:
- os.makedirs(model_path, exist_ok=True)
- from gfpgan import GFPGANer
- from facexlib import detection, parsing # noqa: F401
- global user_path
- global have_gfpgan
- global gfpgan_constructor
- global model_file_path
-
- facexlib_path = model_path
-
- if dirname is not None:
- facexlib_path = dirname
-
- load_file_from_url_orig = gfpgan.utils.load_file_from_url
- facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
- facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
- def my_load_file_from_url(**kwargs):
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
-
- def facex_load_file_from_url(**kwargs):
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
- def facex_load_file_from_url2(**kwargs):
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
- gfpgan.utils.load_file_from_url = my_load_file_from_url
- facexlib.detection.load_file_from_url = facex_load_file_from_url
- facexlib.parsing.load_file_from_url = facex_load_file_from_url2
- user_path = dirname
- have_gfpgan = True
- gfpgan_constructor = GFPGANer
-
- class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
- def name(self):
- return "GFPGAN"
-
- def restore(self, np_image):
- return gfpgan_fix_faces(np_image)
-
- shared.face_restorers.append(FaceRestorerGFPGAN())
+ face_restoration_utils.patch_facexlib(dirname)
+ gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+ shared.face_restorers.append(gfpgan_face_restorer)
except Exception:
errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/hat_model.py b/modules/hat_model.py
new file mode 100644
index 00000000..7f2abb41
--- /dev/null
+++ b/modules/hat_model.py
@@ -0,0 +1,43 @@
+import os
+import sys
+
+from modules import modelloader, devices
+from modules.shared import opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
+
+
+class UpscalerHAT(Upscaler):
+ def __init__(self, dirname):
+ self.name = "HAT"
+ self.scalers = []
+ self.user_path = dirname
+ super().__init__()
+ for file in self.find_models(ext_filter=[".pt", ".pth"]):
+ name = modelloader.friendly_name(file)
+ scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
+ scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
+ return img
+ model.to(devices.device_esrgan) # TODO: should probably be device_hat
+ return upscale_with_model(
+ model,
+ img,
+ tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
+ tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
+ )
+
+ def load_model(self, path: str):
+ if not os.path.isfile(path):
+ raise FileNotFoundError(f"Model file {path} not found")
+ return modelloader.load_spandrel_model(
+ path,
+ device=devices.device_esrgan, # TODO: should probably be device_hat
+ expected_architecture='HAT',
+ )
diff --git a/modules/images.py b/modules/images.py
index daf4eebe..b6f2358c 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None):
return grid
-Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
+ @property
+ def tile_count(self) -> int:
+ """
+ The total number of tiles in the grid.
+ """
+ return sum(len(row[2]) for row in self.tiles)
-def split_grid(image, tile_w=512, tile_h=512, overlap=64):
- w = image.width
- h = image.height
+def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
+ w, h = image.size
non_overlap_width = tile_w - overlap
non_overlap_height = tile_h - overlap
@@ -316,7 +321,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
return res
-invalid_filename_chars = '<>:"/\\|?*\n\r\t'
+invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
@@ -791,3 +796,4 @@ def flatten(img, bgcolor):
img = background
return img.convert('RGB')
+
diff --git a/modules/img2img.py b/modules/img2img.py
index c583290a..f81405df 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr
import gradio as gr
from modules import images as imgutil
-from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
+from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
from modules.sd_models import get_closet_checkpoint_match
@@ -51,7 +51,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
if state.skipped:
state.skipped = False
- if state.interrupted:
+ if state.interrupted or state.stopping_generation:
break
try:
@@ -222,9 +222,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if shared.opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
- if mask:
- p.extra_generation_params["Mask blur"] = mask_blur
-
with closing(p):
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
diff --git a/modules/generation_parameters_copypaste.py b/modules/infotext_utils.py
index 4efe53e0..1049c6c3 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/infotext_utils.py
@@ -4,12 +4,15 @@ import io
import json
import os
import re
+import sys
import gradio as gr
from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks, processing
+from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
from PIL import Image
+sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
+
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
@@ -28,6 +31,19 @@ class ParamBinding:
self.paste_field_names = paste_field_names or []
+class PasteField(tuple):
+ def __new__(cls, component, target, *, api=None):
+ return super().__new__(cls, (component, target))
+
+ def __init__(self, component, target, *, api=None):
+ super().__init__()
+
+ self.api = api
+ self.component = component
+ self.label = target if isinstance(target, str) else None
+ self.function = target if callable(target) else None
+
+
paste_fields: dict[str, dict] = {}
registered_param_bindings: list[ParamBinding] = []
@@ -84,6 +100,12 @@ def image_from_url_text(filedata):
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+
+ if fields:
+ for i in range(len(fields)):
+ if not isinstance(fields[i], PasteField):
+ fields[i] = PasteField(*fields[i])
+
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
# backwards compatibility for existing extensions
@@ -208,7 +230,7 @@ def restore_old_hires_fix_params(res):
res['Hires resize-2'] = height
-def parse_generation_parameters(x: str):
+def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
@@ -218,6 +240,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
returns a dict with field values
"""
+ if skip_fields is None:
+ skip_fields = shared.opts.infotext_skip_pasting
res = {}
@@ -290,6 +314,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Hires negative prompt" not in res:
res["Hires negative prompt"] = ""
+ if "Mask mode" not in res:
+ res["Mask mode"] = "Inpaint masked"
+
+ if "Masked content" not in res:
+ res["Masked content"] = 'original'
+
+ if "Inpaint area" not in res:
+ res["Inpaint area"] = "Whole picture"
+
+ if "Masked area padding" not in res:
+ res["Masked area padding"] = 32
+
restore_old_hires_fix_params(res)
# Missing RNG means the default was set, which is GPU RNG
@@ -314,8 +350,16 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "VAE Decoder" not in res:
res["VAE Decoder"] = "Full"
- skip = set(shared.opts.infotext_skip_pasting)
- res = {k: v for k, v in res.items() if k not in skip}
+ if "FP8 weight" not in res:
+ res["FP8 weight"] = "Disable"
+
+ if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
+ res["Cache FP16 weight for LoRA"] = False
+
+ infotext_versions.backcompat(res)
+
+ for key in skip_fields:
+ res.pop(key, None)
return res
@@ -365,13 +409,57 @@ def create_override_settings_dict(text_pairs):
return res
+def get_override_settings(params, *, skip_fields=None):
+ """Returns a list of settings overrides from the infotext parameters dictionary.
+
+ This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
+ a list of tuples containing the parameter name, setting name, and new value cast to correct type.
+
+ It checks for conditions before adding an override:
+ - ignores settings that match the current value
+ - ignores parameter keys present in skip_fields argument.
+
+ Example input:
+ {"Clip skip": "2"}
+
+ Example output:
+ [("Clip skip", "CLIP_stop_at_last_layers", 2)]
+ """
+
+ res = []
+
+ mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
+ for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
+ if param_name in (skip_fields or {}):
+ continue
+
+ v = params.get(param_name, None)
+ if v is None:
+ continue
+
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+ continue
+
+ v = shared.opts.cast_value(setting_name, v)
+ current_value = getattr(shared.opts, setting_name, None)
+
+ if v == current_value:
+ continue
+
+ res.append((param_name, setting_name, v))
+
+ return res
+
+
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(data_path, "params.txt")
- if os.path.exists(filename):
+ try:
with open(filename, "r", encoding="utf8") as file:
prompt = file.read()
+ except OSError:
+ pass
params = parse_generation_parameters(prompt)
script_callbacks.infotext_pasted_callback(prompt, params)
@@ -393,6 +481,8 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
if valtype == bool and v == "False":
val = False
+ elif valtype == int:
+ val = float(v)
else:
val = valtype(v)
@@ -406,29 +496,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
already_handled_fields = {key: 1 for _, key in paste_fields}
def paste_settings(params):
- vals = {}
-
- mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
- for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
- if param_name in already_handled_fields:
- continue
-
- v = params.get(param_name, None)
- if v is None:
- continue
-
- if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
- continue
-
- v = shared.opts.cast_value(setting_name, v)
- current_value = getattr(shared.opts, setting_name, None)
-
- if v == current_value:
- continue
-
- vals[param_name] = v
+ vals = get_override_settings(params, skip_fields=already_handled_fields)
- vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+ vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py
new file mode 100644
index 00000000..23b45c3f
--- /dev/null
+++ b/modules/infotext_versions.py
@@ -0,0 +1,42 @@
+from modules import shared
+from packaging import version
+import re
+
+
+v160 = version.parse("1.6.0")
+v170_tsnr = version.parse("v1.7.0-225")
+
+
+def parse_version(text):
+ if text is None:
+ return None
+
+ m = re.match(r'([^-]+-[^-]+)-.*', text)
+ if m:
+ text = m.group(1)
+
+ try:
+ return version.parse(text)
+ except Exception:
+ return None
+
+
+def backcompat(d):
+ """Checks infotext Version field, and enables backwards compatibility options according to it."""
+
+ if not shared.opts.auto_backcompat:
+ return
+
+ ver = parse_version(d.get("Version"))
+ if ver is None:
+ return
+
+ if ver < v160 and '[' in d.get('Prompt', ''):
+ d["Old prompt editing timelines"] = True
+
+ if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
+ d["Pad conds v0"] = True
+
+ if ver < v170_tsnr:
+ d["Downcast alphas_cumprod"] = True
+
diff --git a/modules/initialize.py b/modules/initialize.py
index 3285cc3c..cc34fd6f 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -1,5 +1,6 @@
import importlib
import logging
+import os
import sys
import warnings
from threading import Thread
@@ -18,6 +19,7 @@ def imports():
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
+ os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
import gradio # noqa: F401
startup_timer.record("import gradio")
@@ -54,9 +56,6 @@ def initialize():
initialize_util.configure_sigint_handler()
initialize_util.configure_opts_onchange()
- from modules import modelloader
- modelloader.cleanup_models()
-
from modules import sd_models
sd_models.setup_model()
startup_timer.record("setup SD model")
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2e9b6d89..b6767138 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -177,6 +177,8 @@ def configure_opts_onchange():
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+ shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
startup_timer.record("opts onchange")
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 3045560d..c93e7aa8 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -10,14 +10,14 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
-re_topn = re.compile(r"\.top(\d+)\.")
+re_topn = re.compile(r"\.top(\d+)$")
def category_types():
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
@@ -131,7 +131,7 @@ class InterrogateModels:
self.clip_model = self.clip_model.to(devices.device_interrogate)
- self.dtype = next(self.clip_model.parameters()).dtype
+ self.dtype = torch_utils.get_param(self.clip_model).dtype
def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 29506f24..3ff4576a 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -27,8 +27,7 @@ dir_repos = "repositories"
# Whether to default to printing command output
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
-if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
- os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
def check_python_version():
@@ -245,11 +244,13 @@ def list_extensions(settings_file):
settings = {}
try:
- if os.path.isfile(settings_file):
- with open(settings_file, "r", encoding="utf8") as file:
- settings = json.load(file)
+ with open(settings_file, "r", encoding="utf8") as file:
+ settings = json.load(file)
+ except FileNotFoundError:
+ pass
except Exception:
- errors.report("Could not load settings", exc_info=True)
+ errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
+ os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
@@ -314,8 +315,8 @@ def requirements_met(requirements_file):
def prepare_environment():
- torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
- torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
if args.use_ipex:
if platform.system() == "Windows":
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
@@ -338,20 +339,20 @@ def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
+ assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
- codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
+ assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
- codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:
@@ -405,18 +406,14 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
+ git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
- git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
startup_timer.record("clone repositores")
- if not is_installed("lpips"):
- run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
- startup_timer.record("install CodeFormer requirements")
-
if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
diff --git a/modules/logging_config.py b/modules/logging_config.py
index 79269875..8e31d8c9 100644
--- a/modules/logging_config.py
+++ b/modules/logging_config.py
@@ -1,41 +1,58 @@
-import os
import logging
+import os
try:
- from tqdm.auto import tqdm
+ from tqdm import tqdm
+
class TqdmLoggingHandler(logging.Handler):
- def __init__(self, level=logging.INFO):
- super().__init__(level)
+ def __init__(self, fallback_handler: logging.Handler):
+ super().__init__()
+ self.fallback_handler = fallback_handler
def emit(self, record):
try:
- msg = self.format(record)
- tqdm.write(msg)
- self.flush()
+ # If there are active tqdm progress bars,
+ # attempt to not interfere with them.
+ if tqdm._instances:
+ tqdm.write(self.format(record))
+ else:
+ self.fallback_handler.emit(record)
except Exception:
- self.handleError(record)
+ self.fallback_handler.emit(record)
- TQDM_IMPORTED = True
except ImportError:
- # tqdm does not exist before first launch
- # I will import once the UI finishes seting up the enviroment and reloads.
- TQDM_IMPORTED = False
+ TqdmLoggingHandler = None
+
def setup_logging(loglevel):
if loglevel is None:
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
- loghandlers = []
+ if not loglevel:
+ return
+
+ if logging.root.handlers:
+ # Already configured, do not interfere
+ return
+
+ formatter = logging.Formatter(
+ '%(asctime)s %(levelname)s [%(name)s] %(message)s',
+ '%Y-%m-%d %H:%M:%S',
+ )
+
+ if os.environ.get("SD_WEBUI_RICH_LOG"):
+ from rich.logging import RichHandler
+ handler = RichHandler()
+ else:
+ handler = logging.StreamHandler()
+ handler.setFormatter(formatter)
+
+ if TqdmLoggingHandler:
+ handler = TqdmLoggingHandler(handler)
- if TQDM_IMPORTED:
- loghandlers.append(TqdmLoggingHandler())
+ handler.setFormatter(formatter)
- if loglevel:
- log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
- logging.basicConfig(
- level=log_level,
- format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S',
- handlers=loghandlers
- )
+ log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
+ logging.root.setLevel(log_level)
+ logging.root.addHandler(handler)
diff --git a/modules/masking.py b/modules/masking.py
index be9f84c7..29a39452 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -3,40 +3,15 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0):
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
- For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
-
- h, w = mask.shape
-
- crop_left = 0
- for i in range(w):
- if not (mask[:, i] == 0).all():
- break
- crop_left += 1
-
- crop_right = 0
- for i in reversed(range(w)):
- if not (mask[:, i] == 0).all():
- break
- crop_right += 1
-
- crop_top = 0
- for i in range(h):
- if not (mask[i] == 0).all():
- break
- crop_top += 1
-
- crop_bottom = 0
- for i in reversed(range(h)):
- if not (mask[i] == 0).all():
- break
- crop_bottom += 1
-
- return (
- int(max(crop_left-pad, 0)),
- int(max(crop_top-pad, 0)),
- int(min(w - crop_right + pad, w)),
- int(min(h - crop_bottom + pad, h))
- )
+ For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
+ mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
+ box = mask_img.getbbox()
+ if box:
+ x1, y1, x2, y2 = box
+ else: # when no box is found
+ x1, y1 = mask_img.size
+ x2 = y2 = 0
+ return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 098bcb79..e100bb24 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,13 +1,20 @@
from __future__ import annotations
-import os
-import shutil
import importlib
+import logging
+import os
+from typing import TYPE_CHECKING
from urllib.parse import urlparse
+import torch
+
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
-from modules.paths import script_path, models_path
+
+if TYPE_CHECKING:
+ import spandrel
+
+logger = logging.getLogger(__name__)
def load_file_from_url(
@@ -90,54 +97,6 @@ def friendly_name(file: str):
return model_name
-def cleanup_models():
- # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
- # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
- # somehow auto-register and just do these things...
- root_path = script_path
- src_path = models_path
- dest_path = os.path.join(models_path, "Stable-diffusion")
- move_files(src_path, dest_path, ".ckpt")
- move_files(src_path, dest_path, ".safetensors")
- src_path = os.path.join(root_path, "ESRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(models_path, "BSRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path, ".pth")
- src_path = os.path.join(root_path, "gfpgan")
- dest_path = os.path.join(models_path, "GFPGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "SwinIR")
- dest_path = os.path.join(models_path, "SwinIR")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
- dest_path = os.path.join(models_path, "LDSR")
- move_files(src_path, dest_path)
-
-
-def move_files(src_path: str, dest_path: str, ext_filter: str = None):
- try:
- os.makedirs(dest_path, exist_ok=True)
- if os.path.exists(src_path):
- for file in os.listdir(src_path):
- fullpath = os.path.join(src_path, file)
- if os.path.isfile(fullpath):
- if ext_filter is not None:
- if ext_filter not in file:
- continue
- print(f"Moving {file} from {src_path} to {dest_path}.")
- try:
- shutil.move(fullpath, dest_path)
- except Exception:
- pass
- if len(os.listdir(src_path)) == 0:
- print(f"Removing empty folder: {src_path}")
- shutil.rmtree(src_path, True)
- except Exception:
- pass
-
-
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -177,3 +136,34 @@ def load_upscalers():
# Special case for UpscalerNone keeps it at the beginning of the list.
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
)
+
+
+def load_spandrel_model(
+ path: str | os.PathLike,
+ *,
+ device: str | torch.device | None,
+ prefer_half: bool = False,
+ dtype: str | torch.dtype | None = None,
+ expected_architecture: str | None = None,
+) -> spandrel.ModelDescriptor:
+ import spandrel
+ model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
+ if expected_architecture and model_descriptor.architecture != expected_architecture:
+ logger.warning(
+ f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
+ )
+ half = False
+ if prefer_half:
+ if model_descriptor.supports_half:
+ model_descriptor.model.half()
+ half = True
+ else:
+ logger.info("Model %s does not support half precision, ignoring --half", path)
+ if dtype:
+ model_descriptor.model.to(dtype=dtype)
+ model_descriptor.model.eval()
+ logger.debug(
+ "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
+ model_descriptor, path, device, half, dtype,
+ )
+ return model_descriptor
diff --git a/modules/options.py b/modules/options.py
index 4fead690..35ccade2 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -1,3 +1,4 @@
+import os
import json
import sys
from dataclasses import dataclass
@@ -6,6 +7,7 @@ import gradio as gr
from modules import errors
from modules.shared_cmd_options import cmd_opts
+from modules.paths_internal import script_path
class OptionInfo:
@@ -91,18 +93,35 @@ class Options:
if self.data is not None:
if key in self.data or key in self.data_labels:
+
+ # Check that settings aren't globally frozen
assert not cmd_opts.freeze_settings, "changing settings is disabled"
+ # Get the info related to the setting being changed
info = self.data_labels.get(key, None)
if info.do_not_save:
return
+ # Restrict component arguments
comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
- raise RuntimeError(f"not possible to set {key} because it is restricted")
+ raise RuntimeError(f"not possible to set '{key}' because it is restricted")
+
+ # Check that this section isn't frozen
+ if cmd_opts.freeze_settings_in_sections is not None:
+ frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(','))) # Trim whitespace from section names
+ section_key = info.section[0]
+ section_name = info.section[1]
+ assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections"
+
+ # Check that this section of the settings isn't frozen
+ if cmd_opts.freeze_specific_settings is not None:
+ frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(','))) # Trim whitespace from setting keys
+ assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings"
+ # Check shorthand option which disables editing options in "saving-paths"
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
- raise RuntimeError(f"not possible to set {key} because it is restricted")
+ raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config")
self.data[key] = value
return
@@ -176,9 +195,15 @@ class Options:
return type_x == type_y
def load(self, filename):
- with open(filename, "r", encoding="utf8") as file:
- self.data = json.load(file)
-
+ try:
+ with open(filename, "r", encoding="utf8") as file:
+ self.data = json.load(file)
+ except FileNotFoundError:
+ self.data = {}
+ except Exception:
+ errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
+ os.replace(filename, os.path.join(script_path, "tmp", "config.json"))
+ self.data = {}
# 1.6.0 VAE defaults
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
diff --git a/modules/paths.py b/modules/paths.py
index 187b9496..03064651 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,7 +38,6 @@ mute_sdxl_imports()
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
- (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 89131a54..b86ecd7f 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
+default_output_dir = os.path.join(data_path, "output")
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 0c59fad4..f1488232 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -2,7 +2,7 @@ import os
from PIL import Image
-from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
+from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils
from modules.shared import opts
@@ -62,8 +62,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
else:
image_data = image_placeholder
- shared.state.assign_current_image(image_data)
-
parameters, existing_pnginfo = images.read_info_from_image(image_data)
if parameters:
existing_pnginfo["parameters"] = parameters
@@ -86,22 +84,25 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
basename = ''
forced_filename = None
- infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+ infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None])
if opts.enable_pnginfo:
pp.image.info = existing_pnginfo
pp.image.info["postprocessing"] = infotext
+ shared.state.assign_current_image(pp.image)
+
if save_output:
fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
if pp.caption:
caption_filename = os.path.splitext(fullfn)[0] + ".txt"
- if os.path.isfile(caption_filename):
+ existing_caption = ""
+ try:
with open(caption_filename, encoding="utf8") as file:
existing_caption = file.read().strip()
- else:
- existing_caption = ""
+ except FileNotFoundError:
+ pass
action = shared.opts.postprocessing_existing_caption_action
if action == 'Prepend' and existing_caption:
diff --git a/modules/processing.py b/modules/processing.py
index 6f01c95f..52f00bfb 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -16,7 +16,7 @@ from skimage import exposure
from typing import Any
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -62,18 +62,22 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')
-def apply_overlay(image, paste_loc, index, overlays):
- if overlays is None or index >= len(overlays):
- return image
+def uncrop(image, dest_size, paste_loc):
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', dest_size)
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ return image
- overlay = overlays[index]
+
+def apply_overlay(image, paste_loc, overlay):
+ if overlay is None:
+ return image
if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
+ image = uncrop(image, (overlay.width, overlay.height), paste_loc)
image = image.convert('RGBA')
image.alpha_composite(overlay)
@@ -81,9 +85,12 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
- image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ if round:
+ image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ else:
+ image = image.split()[-1].convert("L")
else:
image = image.convert('L')
return image
@@ -106,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
else:
+ sd = sd_model.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+ image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+ image_conditioning = images_tensor_to_samples(image_conditioning,
+ approximation_indexes.get(opts.sd_vae_encode_method))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
@@ -157,6 +179,7 @@ class StableDiffusionProcessing:
token_merging_ratio = 0
token_merging_ratio_hr = 0
disable_extra_networks: bool = False
+ firstpass_image: Image = None
scripts_value: scripts.ScriptRunner = field(default=None, init=False)
script_args_value: list = field(default=None, init=False)
@@ -308,7 +331,7 @@ class StableDiffusionProcessing:
c_adm = torch.cat((c_adm, noise_level_emb), 1)
return c_adm
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -320,8 +343,10 @@ class StableDiffusionProcessing:
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
+ if round_image_mask:
+ # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+
else:
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -345,7 +370,7 @@ class StableDiffusionProcessing:
return image_conditioning
- def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,11 +382,17 @@ class StableDiffusionProcessing:
return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
+ sd = self.sampler.model_wrap.inner_model.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -422,6 +453,8 @@ class StableDiffusionProcessing:
opts.sdxl_crop_top,
self.width,
self.height,
+ opts.fp8_storage,
+ opts.cache_fp16_weight,
)
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
@@ -596,20 +629,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
sample = decode_first_stage(model, batch[i:i + 1])[0]
if check_for_nans:
+
try:
devices.test_for_nans(sample, "vae")
except devices.NansException as e:
- if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+ if shared.opts.auto_vae_precision_bfloat16:
+ autofix_dtype = torch.bfloat16
+ autofix_dtype_text = "bfloat16"
+ autofix_dtype_setting = "Automatically convert VAE to bfloat16"
+ autofix_dtype_comment = ""
+ elif shared.opts.auto_vae_precision:
+ autofix_dtype = torch.float32
+ autofix_dtype_text = "32-bit float"
+ autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
+ autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
+ else:
+ raise e
+
+ if devices.dtype_vae == autofix_dtype:
raise e
errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n"
- "Web UI will now convert VAE into 32-bit float and retry.\n"
- "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
- "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+ f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
+ f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
)
- devices.dtype_vae = torch.float32
+ devices.dtype_vae = autofix_dtype
model.first_stage_model.to(devices.dtype_vae)
batch = batch.to(devices.dtype_vae)
@@ -679,12 +725,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
+ "FP8 weight": opts.fp8_storage if devices.fp8 else None,
+ "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
- "Denoising strength": getattr(p, 'denoising_strength', None),
+ "Denoising strength": p.extra_generation_params.get("Denoising strength"),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
@@ -699,7 +747,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"User": p.user if opts.add_user_name_to_info else None,
}
- generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
@@ -818,7 +866,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.skipped:
state.skipped = False
- if state.interrupted:
+ if state.interrupted or state.stopping_generation:
break
sd_models.reload_model_weights() # model can be changed for example by refiner
@@ -864,9 +912,42 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+ def rescale_zero_terminal_snr_abar(alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas_bar[-1] = 4.8973451890853435e-08
+ return alphas_bar
+
+ if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+ if p.scripts is not None:
+ ps = scripts.PostSampleArgs(samples_ddim)
+ p.scripts.post_sample(p, ps)
+ samples_ddim = ps.samples
+
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
@@ -922,13 +1003,42 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image
+
+ mask_for_overlay = getattr(p, "mask_for_overlay", None)
+
+ if not shared.opts.overlay_inpaint:
+ overlay_image = None
+ elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
+ overlay_image = p.overlay_images[i]
+ else:
+ overlay_image = None
+
+ if p.scripts is not None:
+ ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+ p.scripts.postprocess_maskoverlay(p, ppmo)
+ mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
+
if p.color_corrections is not None and i < len(p.color_corrections):
if save_samples and opts.save_images_before_color_correction:
- image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
- image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ # If the intention is to show the output from the model
+ # that is being composited over the original image,
+ # we need to keep the original image around
+ # and use it in the composite step.
+ original_denoised_image = image.copy()
+
+ if p.paste_to is not None:
+ original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
+
+ image = apply_overlay(image, p.paste_to, overlay_image)
+
+ if p.scripts is not None:
+ pp = scripts.PostprocessImageArgs(image)
+ p.scripts.postprocess_image_after_composite(p, pp)
+ image = pp.image
if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -938,16 +1048,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
- if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+
+ if mask_for_overlay is not None:
if opts.return_mask or opts.save_mask:
- image_mask = p.mask_for_overlay.convert('RGB')
+ image_mask = mask_for_overlay.convert('RGB')
if save_samples and opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.return_mask:
output_images.append(image_mask)
if opts.return_mask_composite or opts.save_mask_composite:
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if save_samples and opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask_composite:
@@ -1025,6 +1136,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_sampler_name: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''
+ force_task_id: str = None
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
@@ -1097,7 +1209,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if self.hr_checkpoint_name:
+ self.extra_generation_params["Denoising strength"] = self.denoising_strength
+
+ if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
if self.hr_checkpoint_info is None:
@@ -1124,8 +1238,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not state.processing_has_refined_job_count:
if state.job_count == -1:
state.job_count = self.n_iter
-
- shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
+ if getattr(self, 'txt2img_upscale', False):
+ total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
+ else:
+ total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
+ shared.total_tqdm.updateTotal(total_steps)
state.job_count = state.job_count * 2
state.processing_has_refined_job_count = True
@@ -1138,18 +1255,45 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- x = self.rng.next()
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
- del x
+ if self.firstpass_image is not None and self.enable_hr:
+ # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix
- if not self.enable_hr:
- return samples
- devices.torch_gc()
+ if self.latent_scale_mode is None:
+ image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
+ image = np.moveaxis(image, 2, 0)
+
+ samples = None
+ decoded_samples = torch.asarray(np.expand_dims(image, 0))
+
+ else:
+ image = np.array(self.firstpass_image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ image = torch.from_numpy(np.expand_dims(image, axis=0))
+ image = image.to(shared.device, dtype=devices.dtype_vae)
+
+ if opts.sd_vae_encode_method != 'Full':
+ self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+ samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+ decoded_samples = None
+ devices.torch_gc()
- if self.latent_scale_mode is None:
- decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
else:
- decoded_samples = None
+ # here we generate an image normally
+
+ x = self.rng.next()
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+ del x
+
+ if not self.enable_hr:
+ return samples
+
+ devices.torch_gc()
+
+ if self.latent_scale_mode is None:
+ decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+ else:
+ decoded_samples = None
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
@@ -1351,12 +1495,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
+ mask_round: bool = True
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
inpainting_mask_invert: int = 0
initial_noise_multiplier: float = None
latent_mask: Image = None
+ force_task_id: str = None
image_mask: Any = field(default=None, init=False)
@@ -1386,6 +1532,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.mask_blur_y = value
def init(self, all_prompts, all_seeds, all_subseeds):
+ self.extra_generation_params["Denoising strength"] = self.denoising_strength
+
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
@@ -1396,10 +1544,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if image_mask is not None:
# image_mask is passed in as RGBA by Gradio to support alpha masks,
# but we still want to support binary masks.
- image_mask = create_binary_mask(image_mask)
+ image_mask = create_binary_mask(image_mask, round=self.mask_round)
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
+ self.extra_generation_params["Mask mode"] = "Inpaint not masked"
if self.mask_blur_x > 0:
np_mask = np.array(image_mask)
@@ -1413,16 +1562,22 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
image_mask = Image.fromarray(np_mask)
+ if self.mask_blur_x > 0 or self.mask_blur_y > 0:
+ self.extra_generation_params["Mask blur"] = self.mask_blur
+
if self.inpaint_full_res:
self.mask_for_overlay = image_mask
mask = image_mask.convert('L')
- crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
+ crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region)
image_mask = images.resize_image(2, mask, self.width, self.height)
self.paste_to = (x1, y1, x2-x1, y2-y1)
+
+ self.extra_generation_params["Inpaint area"] = "Only masked"
+ self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
np_mask = np.array(image_mask)
@@ -1442,7 +1597,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# Save init image
if opts.save_init_img:
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
- images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+ images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
image = images.flatten(img, opts.img2img_background_color)
@@ -1464,6 +1619,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_fill != 1:
image = masking.fill(image, latent_mask)
+ if self.inpainting_fill == 0:
+ self.extra_generation_params["Masked content"] = 'fill'
+
if add_color_corrections:
self.color_corrections.append(setup_color_correction(image))
@@ -1503,7 +1661,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0]
- latmask = np.around(latmask)
+ if self.mask_round:
+ latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
@@ -1512,10 +1671,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# this needs to be fixed to be done in sample() using actual seeds for batches
if self.inpainting_fill == 2:
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+ self.extra_generation_params["Masked content"] = 'latent noise'
+
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
+ self.extra_generation_params["Masked content"] = 'latent nothing'
- self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
@@ -1527,7 +1689,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
- samples = samples * self.nmask + self.init_latent * self.mask
+ blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+ if self.scripts is not None:
+ mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
+ self.scripts.on_mask_blend(self, mba)
+ blended_samples = mba.blended_latent
+
+ samples = blended_samples
del x
devices.torch_gc()
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
index 29ccb78f..ba33d8a4 100644
--- a/modules/processing_scripts/refiner.py
+++ b/modules/processing_scripts/refiner.py
@@ -1,6 +1,7 @@
import gradio as gr
from modules import scripts, sd_models
+from modules.infotext_utils import PasteField
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
@@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
return None if info is None else info.title
self.infotext_fields = [
- (enable_refiner, lambda d: 'Refiner' in d),
- (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
- (refiner_switch_at, 'Refiner switch at'),
+ PasteField(enable_refiner, lambda d: 'Refiner' in d),
+ PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
+ PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
]
return enable_refiner, refiner_checkpoint, refiner_switch_at
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index dc9c2da5..7a4c0159 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -3,8 +3,10 @@ import json
import gradio as gr
from modules import scripts, ui, errors
+from modules.infotext_utils import PasteField
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
+from modules import infotext_utils
class ScriptSeed(scripts.ScriptBuiltinUI):
@@ -51,12 +53,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
self.infotext_fields = [
- (self.seed, "Seed"),
- (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
+ PasteField(self.seed, "Seed", api="seed"),
+ PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
+ PasteField(subseed, "Variation seed", api="subseed"),
+ PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
+ PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"),
+ PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"),
]
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
@@ -76,7 +78,6 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
p.seed_resize_from_h = seed_resize_from_h
-
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
@@ -84,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
def copy_seed(gen_info_string: str, index):
res = -1
-
try:
gen_info = json.loads(gen_info_string)
- index -= gen_info.get('index_of_first_image', 0)
-
- if is_subseed and gen_info.get('subseed_strength', 0) > 0:
- all_subseeds = gen_info.get('all_subseeds', [-1])
- res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
- else:
- all_seeds = gen_info.get('all_seeds', [-1])
- res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
- except json.decoder.JSONDecodeError:
+ infotext = gen_info.get('infotexts')[index]
+ gen_parameters = infotext_utils.parse_generation_parameters(infotext, [])
+ res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1))
+ except Exception:
if gen_info_string:
- errors.report(f"Error parsing JSON generation info: {gen_info_string}")
+ errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True)
return [res, gr.update()]
diff --git a/modules/progress.py b/modules/progress.py
index 69921de7..85255e82 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -8,10 +8,13 @@ from pydantic import BaseModel, Field
from modules.shared import opts
import modules.shared as shared
-
+from collections import OrderedDict
+import string
+import random
+from typing import List
current_task = None
-pending_tasks = {}
+pending_tasks = OrderedDict()
finished_tasks = []
recorded_results = []
recorded_results_limit = 2
@@ -34,6 +37,11 @@ def finish_task(id_task):
if len(finished_tasks) > 16:
finished_tasks.pop(0)
+def create_task_id(task_type):
+ N = 7
+ res = ''.join(random.choices(string.ascii_uppercase +
+ string.digits, k=N))
+ return f"task({task_type}-{res})"
def record_results(id_task, res):
recorded_results.append((id_task, res))
@@ -44,6 +52,9 @@ def record_results(id_task, res):
def add_task_to_queue(id_job):
pending_tasks[id_job] = time.time()
+class PendingTasksResponse(BaseModel):
+ size: int = Field(title="Pending task size")
+ tasks: List[str] = Field(title="Pending task ids")
class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@@ -63,9 +74,16 @@ class ProgressResponse(BaseModel):
def setup_progress_api(app):
+ app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+def get_pending_tasks():
+ pending_tasks_ids = list(pending_tasks)
+ pending_len = len(pending_tasks_ids)
+ return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
+
+
def progressapi(req: ProgressRequest):
active = req.id_task == current_task
queued = req.id_task in pending_tasks
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 02841c30..ff9d8ac0 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,12 +1,9 @@
import os
-import numpy as np
-from PIL import Image
-from realesrgan import RealESRGANer
-
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler):
@@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler):
self.name = "RealESRGAN"
self.user_path = path
super().__init__()
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
- from realesrgan import RealESRGANer # noqa: F401
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
- self.enable = True
- self.scalers = []
- scalers = self.load_models(path)
+ self.enable = True
+ self.scalers = []
+ scalers = get_realesrgan_models(self)
- local_model_paths = self.find_models(ext_filter=[".pth"])
- for scaler in scalers:
- if scaler.local_data_path.startswith("http"):
- filename = modelloader.friendly_name(scaler.local_data_path)
- local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
- if local_model_candidates:
- scaler.local_data_path = local_model_candidates[0]
+ local_model_paths = self.find_models(ext_filter=[".pth"])
+ for scaler in scalers:
+ if scaler.local_data_path.startswith("http"):
+ filename = modelloader.friendly_name(scaler.local_data_path)
+ local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
+ if local_model_candidates:
+ scaler.local_data_path = local_model_candidates[0]
- if scaler.name in opts.realesrgan_enabled_models:
- self.scalers.append(scaler)
-
- except Exception:
- errors.report("Error importing Real-ESRGAN", exc_info=True)
- self.enable = False
- self.scalers = []
+ if scaler.name in opts.realesrgan_enabled_models:
+ self.scalers.append(scaler)
def do_upscale(self, img, path):
if not self.enable:
@@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler):
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
- upsampler = RealESRGANer(
- scale=info.scale,
- model_path=info.local_data_path,
- model=info.model(),
- half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
- tile=opts.ESRGAN_tile,
- tile_pad=opts.ESRGAN_tile_overlap,
+ model_descriptor = modelloader.load_spandrel_model(
+ info.local_data_path,
device=self.device,
+ prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
+ )
+ return upscale_with_model(
+ model_descriptor,
+ img,
+ tile_size=opts.ESRGAN_tile,
+ tile_overlap=opts.ESRGAN_tile_overlap,
+ # TODO: `outscale`?
)
-
- upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-
- image = Image.fromarray(upsampled)
- return image
def load_model(self, path):
for scaler in self.scalers:
@@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler):
return scaler
raise ValueError(f"Unable to find model info: {path}")
- def load_models(self, _):
- return get_realesrgan_models(self)
-
-def get_realesrgan_models(scaler):
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
- models = [
- UpscalerData(
- name="R-ESRGAN General 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN General WDN 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN AnimeVideo",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN 4x+",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
- ),
- UpscalerData(
- name="R-ESRGAN 4x+ Anime6B",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
- ),
- UpscalerData(
- name="R-ESRGAN 2x+",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
- scale=2,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
- ),
- ]
- return models
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
+def get_realesrgan_models(scaler: UpscalerRealESRGAN):
+ return [
+ UpscalerData(
+ name="R-ESRGAN General 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN General WDN 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN AnimeVideo",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 4x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 4x+ Anime6B",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 2x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+ scale=2,
+ upscaler=scaler,
+ ),
+ ]
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 9ed7ad21..a54cb3eb 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -41,7 +41,7 @@ class ExtraNoiseParams:
class CFGDenoiserParams:
- def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -63,6 +63,9 @@ class CFGDenoiserParams:
self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt"""
+ self.denoiser = denoiser
+ """Current CFGDenoiser object with processing parameters"""
+
class CFGDenoisedParams:
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
diff --git a/modules/scripts.py b/modules/scripts.py
index 7f9454eb..94690a22 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object()
+class MaskBlendArgs:
+ def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
+ self.current_latent = current_latent
+ self.nmask = nmask
+ self.init_latent = init_latent
+ self.mask = mask
+ self.blended_latent = blended_latent
+
+ self.denoiser = denoiser
+ self.is_final_blend = denoiser is None
+ self.sigma = sigma
+
+class PostSampleArgs:
+ def __init__(self, samples):
+ self.samples = samples
class PostprocessImageArgs:
def __init__(self, image):
self.image = image
+class PostProcessMaskOverlayArgs:
+ def __init__(self, index, mask_for_overlay, overlay_image):
+ self.index = index
+ self.mask_for_overlay = mask_for_overlay
+ self.overlay_image = overlay_image
class PostprocessBatchListArgs:
def __init__(self, images):
@@ -71,6 +91,9 @@ class Script:
setup_for_ui_only = False
"""If true, the script setup will only be run in Gradio UI, not in API"""
+ controls = None
+ """A list of controls retured by the ui()."""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -206,6 +229,25 @@ class Script:
pass
+ def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
+ """
+ Called in inpainting mode when the original content is blended with the inpainted content.
+ This is called at every step in the denoising process and once at the end.
+ If is_final_blend is true, this is called for the final blending stage.
+ Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+ """
+
+ pass
+
+ def post_sample(self, p, ps: PostSampleArgs, *args):
+ """
+ Called after the samples have been generated,
+ but before they have been decoded by the VAE, if applicable.
+ Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+ """
+
+ pass
+
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
@@ -213,6 +255,22 @@ class Script:
pass
+ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+ """
+ Called for every image after it has been generated.
+ """
+
+ pass
+
+ def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
+ """
+ Called for every image after it has been generated.
+ Same as postprocess_image but after inpaint_full_res composite
+ So that it operates on the full image instead of the inpaint_full_res crop region.
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -520,7 +578,12 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
for script_data in auto_processing_scripts + scripts_data:
- script = script_data.script_class()
+ try:
+ script = script_data.script_class()
+ except Exception:
+ errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True)
+ continue
+
script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -573,6 +636,7 @@ class ScriptRunner:
import modules.api.models as api_models
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
+ script.controls = controls
if controls is None:
return
@@ -645,6 +709,8 @@ class ScriptRunner:
self.setup_ui_for_section(None, self.selectable_scripts)
def select_script(script_index):
+ if script_index is None:
+ script_index = 0
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
@@ -688,7 +754,7 @@ class ScriptRunner:
def run(self, p, *args):
script_index = args[0]
- if script_index == 0:
+ if script_index == 0 or script_index is None:
return None
script = self.selectable_scripts[script_index-1]
@@ -767,6 +833,22 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+ def post_sample(self, p, ps: PostSampleArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.post_sample(p, ps, *script_args)
+ except Exception:
+ errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
+ def on_mask_blend(self, p, mba: MaskBlendArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.on_mask_blend(p, mba, *script_args)
+ except Exception:
+ errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
@@ -775,6 +857,22 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_maskoverlay(p, ppmo, *script_args)
+ except Exception:
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+
+ def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_image_after_composite(p, pp, *script_args)
+ except Exception:
+ errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True)
+
def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
@@ -841,6 +939,35 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running setup: {script.filename}", exc_info=True)
+ def set_named_arg(self, args, script_name, arg_elem_id, value, fuzzy=False):
+ """Locate an arg of a specific script in script_args and set its value
+ Args:
+ args: all script args of process p, p.script_args
+ script_name: the name target script name to
+ arg_elem_id: the elem_id of the target arg
+ value: the value to set
+ fuzzy: if True, arg_elem_id can be a substring of the control.elem_id else exact match
+ Returns:
+ Updated script args
+ when script_name in not found or arg_elem_id is not found in script controls, raise RuntimeError
+ """
+ script = next((x for x in self.scripts if x.name == script_name), None)
+ if script is None:
+ raise RuntimeError(f"script {script_name} not found")
+
+ for i, control in enumerate(script.controls):
+ if arg_elem_id in control.elem_id if fuzzy else arg_elem_id == control.elem_id:
+ index = script.args_from + i
+
+ if isinstance(args, tuple):
+ return args[:index] + (value,) + args[index + 1:]
+ elif isinstance(args, list):
+ args[index] = value
+ return args
+ else:
+ raise RuntimeError(f"args is not a list or tuple, but {type(args)}")
+ raise RuntimeError(f"arg_elem_id {arg_elem_id} not found in script {script_name}")
+
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
index f8684475..79bf6e46 100644
--- a/modules/sd_hijack_utils.py
+++ b/modules/sd_hijack_utils.py
@@ -11,10 +11,14 @@ class CondFunc:
break
except ImportError:
pass
- for attr_name in func_path[i:-1]:
- resolved_obj = getattr(resolved_obj, attr_name)
- orig_func = getattr(resolved_obj, func_path[-1])
- setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+ try:
+ for attr_name in func_path[i:-1]:
+ resolved_obj = getattr(resolved_obj, attr_name)
+ orig_func = getattr(resolved_obj, func_path[-1])
+ setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+ except AttributeError:
+ print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
+ pass
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 9355f1e1..2c045771 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -348,10 +348,28 @@ class SkipWritingToConfig:
SkipWritingToConfig.skip = self.previous
+def check_fp8(model):
+ if model is None:
+ return None
+ if devices.get_optimal_device_name() == "mps":
+ enable_fp8 = False
+ elif shared.opts.fp8_storage == "Enable":
+ enable_fp8 = True
+ elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+ enable_fp8 = True
+ else:
+ enable_fp8 = False
+ return enable_fp8
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
+ if devices.fp8:
+ # prevent model to load state dict in fp8
+ model.half()
+
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -383,6 +401,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.no_half:
model.float()
+ model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
@@ -396,7 +415,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
+ alphas_cumprod = model.alphas_cumprod
+ model.alphas_cumprod = None
model.half()
+ model.alphas_cumprod = alphas_cumprod
+ model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
@@ -404,6 +427,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
+ for module in model.modules():
+ if hasattr(module, 'fp16_weight'):
+ del module.fp16_weight
+ if hasattr(module, 'fp16_bias'):
+ del module.fp16_bias
+
+ if check_fp8(model):
+ devices.fp8 = True
+ first_stage = model.first_stage_model
+ model.first_stage_model = None
+ for module in model.modules():
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ if shared.opts.cache_fp16_weight:
+ module.fp16_weight = module.weight.data.clone().cpu().half()
+ if module.bias is not None:
+ module.fp16_bias = module.bias.data.clone().cpu().half()
+ module.to(torch.float8_e4m3fn)
+ model.first_stage_model = first_stage
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -651,6 +696,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
weight_dtype_conversion = {
'first_stage_model': None,
+ 'alphas_cumprod': None,
'': torch.float16,
}
@@ -746,7 +792,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
return None
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -758,11 +804,14 @@ def reload_model_weights(sd_model=None, info=None):
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
- if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+ if check_fp8(sd_model) != devices.fp8:
+ # load from state dict again to prevent extra numerical errors
+ forced_reload = True
+ elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
- if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:
@@ -793,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
- script_callbacks.model_loaded_callback(sd_model)
- timer.record("script callbacks")
-
if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
+ script_callbacks.model_loaded_callback(sd_model)
+ timer.record("script callbacks")
+
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index deab2f6e..b38137eb 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
+config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
- return config_sdxl
+ if diffusion_model_input.shape[1] == 9:
+ return config_sdxl_inpainting
+ else:
+ return config_sdxl
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 01123321..0de17af3 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -6,6 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
+from modules import torch_utils
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -34,6 +35,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+ sd = self.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ x = torch.cat([x] + cond['c_concat'], dim=1)
+
return self.model(x, t, cond)
@@ -84,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
- dtype = next(model.model.diffusion_model.parameters()).dtype
+ dtype = torch_utils.get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt'
@@ -93,7 +100,7 @@ def extend_sdxl(model):
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
- model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 45faae62..a58528a0 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,4 +1,4 @@
-from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
+from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
@@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image #
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
*sd_samplers_timesteps.samplers_data_timesteps,
+ *sd_samplers_lcm.samplers_data_lcm,
]
all_samplers_map = {x.name: x for x in all_samplers}
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38..941dff4b 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -53,9 +53,13 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
+ self.padded_cond_uncond_v0 = False
self.sampler = sampler
self.model_wrap = None
self.p = None
+
+ # NOTE: masking before denoising can cause the original latents to be oversmoothed
+ # as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -88,6 +92,62 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['cond'] = c
self.sampler.sampler_extra_args['uncond'] = uc
+ def pad_cond_uncond(self, cond, uncond):
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ cond = pad_cond(cond, -num_repeats, empty)
+ self.padded_cond_uncond = True
+ elif num_repeats > 0:
+ uncond = pad_cond(uncond, num_repeats, empty)
+ self.padded_cond_uncond = True
+
+ return cond, uncond
+
+ def pad_cond_uncond_v0(self, cond, uncond):
+ """
+ Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
+
+ If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
+ If 'uncond' is a tensor, it is padded directly.
+
+ If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
+ is repeated to match the number of columns in 'cond'.
+
+ If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
+ to match the number of columns in 'cond'.
+
+ Args:
+ cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
+ uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
+
+ Returns:
+ tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
+
+ Note:
+ This is the padding that was always used in DDIM before version 1.6.0
+ """
+
+ is_dict_cond = isinstance(uncond, dict)
+ uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
+
+ if uncond_vec.shape[1] < cond.shape[1]:
+ last_vector = uncond_vec[:, -1:]
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
+ uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
+ self.padded_cond_uncond_v0 = True
+ elif uncond_vec.shape[1] > cond.shape[1]:
+ uncond_vec = uncond_vec[:, :cond.shape[1]]
+ self.padded_cond_uncond_v0 = True
+
+ if is_dict_cond:
+ uncond['crossattn'] = uncond_vec
+ else:
+ uncond = uncond_vec
+
+ return cond, uncond
+
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -105,8 +165,21 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ # If we use masks, blending between the denoised and original latent images occurs here.
+ def apply_blend(current_latent):
+ blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+ if self.p.scripts is not None:
+ from modules import scripts
+ mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+ self.p.scripts.on_mask_blend(self.p, mba)
+ blended_latent = mba.blended_latent
+
+ return blended_latent
+
+ # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
+ x = apply_blend(x)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -130,7 +203,7 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
@@ -146,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = sigma_in[:-batch_size]
self.padded_cond_uncond = False
+ self.padded_cond_uncond_v0 = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
- empty = shared.sd_model.cond_stage_model_empty_prompt
- num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
- if num_repeats < 0:
- tensor = pad_cond(tensor, -num_repeats, empty)
- self.padded_cond_uncond = True
- elif num_repeats > 0:
- uncond = pad_cond(uncond, num_repeats, empty)
- self.padded_cond_uncond = True
+ tensor, uncond = self.pad_cond_uncond(tensor, uncond)
+ elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
+ tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
@@ -207,8 +275,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ # Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ denoised = apply_blend(denoised)
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 58efcad2..6bd38e12 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -335,3 +335,10 @@ class Sampler:
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()
+
+ def add_infotext(self, p):
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
+ if self.model_wrap_cfg.padded_cond_uncond_v0:
+ p.extra_generation_params["Pad conds v0"] = True
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 8a8c87e0..337106c0 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -187,8 +187,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
@@ -234,8 +233,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py
new file mode 100644
index 00000000..59839b72
--- /dev/null
+++ b/modules/sd_samplers_lcm.py
@@ -0,0 +1,104 @@
+import torch
+
+from k_diffusion import utils, sampling
+from k_diffusion.external import DiscreteEpsDDPMDenoiser
+from k_diffusion.sampling import default_noise_sampler, trange
+
+from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common
+
+
+class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
+ def __init__(self, model):
+ timesteps = 1000
+ original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM)
+ self.skip_steps = timesteps // original_timesteps
+
+ alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
+ for x in range(original_timesteps):
+ alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
+
+ super().__init__(model, alphas_cumprod_valid, quantize=None)
+
+
+ def get_sigmas(self, n=None,):
+ if n is None:
+ return sampling.append_zero(self.sigmas.flip(0))
+
+ start = self.sigma_to_t(self.sigma_max)
+ end = self.sigma_to_t(self.sigma_min)
+
+ t = torch.linspace(start, end, n, device=shared.sd_model.device)
+
+ return sampling.append_zero(self.t_to_sigma(t))
+
+
+ def sigma_to_t(self, sigma, quantize=None):
+ log_sigma = sigma.log()
+ dists = log_sigma - self.log_sigmas[:, None]
+ return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
+
+
+ def t_to_sigma(self, timestep):
+ t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
+ return super().t_to_sigma(t)
+
+
+ def get_eps(self, *args, **kwargs):
+ return self.inner_model.apply_model(*args, **kwargs)
+
+
+ def get_scaled_out(self, sigma, output, input):
+ sigma_data = 0.5
+ scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0
+
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+
+ return c_out * output + c_skip * input
+
+
+ def forward(self, input, sigma, **kwargs):
+ c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
+ return self.get_scaled_out(sigma, input + eps * c_out, input)
+
+
+def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
+ extra_args = {} if extra_args is None else extra_args
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+ x = denoised
+ if sigmas[i + 1] > 0:
+ x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
+ return x
+
+
+class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):
+ @property
+ def inner_model(self):
+ if self.model_wrap is None:
+ denoiser = LCMCompVisDenoiser
+ self.model_wrap = denoiser(shared.sd_model)
+
+ return self.model_wrap
+
+
+class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):
+ def __init__(self, funcname, sd_model, options=None):
+ super().__init__(funcname, sd_model, options)
+ self.model_wrap_cfg = CFGDenoiserLCM(self)
+ self.model_wrap = self.model_wrap_cfg.inner_model
+
+
+samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})]
+samplers_data_lcm = [
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_lcm
+]
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index b17a8f93..8cc7d384 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
- return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+ return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
@@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
+ self.model_wrap = self.model_wrap_cfg.inner_model
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -132,8 +133,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
@@ -157,8 +157,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
}
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 31306d8b..43687e48 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
if not sd_model.lowvram:
sd_model.to(devices.device)
+ script_callbacks.model_loaded_callback(sd_model)
+
print("VAE weights loaded.")
return sd_model
diff --git a/modules/shared.py b/modules/shared.py
index 63661939..ccdca4e7 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,3 +1,4 @@
+import os
import sys
import gradio as gr
@@ -11,7 +12,7 @@ parser = shared_cmd_options.parser
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
parallel_processing_allowed = True
-styles_filename = cmd_opts.styles_file
+styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py
index 822db0a9..b6dc3145 100644
--- a/modules/shared_gradio_themes.py
+++ b/modules/shared_gradio_themes.py
@@ -65,3 +65,7 @@ def reload_gradio_theme(theme_name=None):
except Exception as e:
errors.display(e, "changing gradio theme")
shared.gradio_theme = gr.themes.Default(**default_theme_args)
+
+ # append additional values gradio_theme
+ shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity
+ shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity
diff --git a/modules/shared_init.py b/modules/shared_init.py
index d3fb687e..935e3a21 100644
--- a/modules/shared_init.py
+++ b/modules/shared_init.py
@@ -18,8 +18,10 @@ def initialize():
shared.options_templates = shared_options.options_templates
shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
shared.restricted_opts = shared_options.restricted_opts
- if os.path.exists(shared.config_filename):
+ try:
shared.opts.load(shared.config_filename)
+ except FileNotFoundError:
+ pass
from modules import devices
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
@@ -27,6 +29,7 @@ def initialize():
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
+ devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
shared.device = devices.device
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 991971ad..88f63645 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -8,6 +8,11 @@ def realesrgan_models_names():
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
+def dat_models_names():
+ import modules.dat_model
+ return [x.name for x in modules.dat_model.get_dat_models(None)]
+
+
def postprocessing_scripts():
import modules.scripts
@@ -67,14 +72,14 @@ def reload_hypernetworks():
def get_infotext_names():
- from modules import generation_parameters_copypaste, shared
+ from modules import infotext_utils, shared
res = {}
for info in shared.opts.data_labels.values():
if info.infotext:
res[info.infotext] = 1
- for tab_data in generation_parameters_copypaste.paste_fields.values():
+ for tab_data in infotext_utils.paste_fields.values():
for _, name in tab_data.get("fields") or []:
if isinstance(name, str):
res[name] = 1
diff --git a/modules/shared_options.py b/modules/shared_options.py
index d2e86ff1..bdd066c4 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -1,7 +1,8 @@
+import os
import gradio as gr
-from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
+from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401
from modules.shared_cmd_options import cmd_opts
from modules.options import options_section, OptionInfo, OptionHTML, categories
@@ -74,14 +75,14 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
- "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
- "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
- "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
+ "outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs),
+ "outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs),
+ "outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs),
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
- "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
- "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
- "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
- "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
+ "outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs),
+ "outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs),
+ "outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs),
+ "outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
@@ -96,6 +97,9 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
+ "dat_enabled_models": OptionInfo(["DAT x2", "DAT x3", "DAT x4"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}),
+ "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
+ "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
}))
@@ -114,6 +118,7 @@ options_templates.update(options_section(('system', "System", "system"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
+ "enable_upscale_progressbar": OptionInfo(True, "Show a progress bar in the console for tiled upscaling."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
@@ -176,6 +181,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
+ "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"),
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
@@ -195,6 +201,7 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), {
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
+ "overlay_inpaint": OptionInfo(True, "Overlay original for inpaint").info("when inpainting, overlay the original image over the areas that weren't inpainted."),
}))
options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
@@ -203,12 +210,16 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
- "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
+ "pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; ignored if the above is set; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+ "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+ "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
}))
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
+ "auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
@@ -216,6 +227,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd"
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
+ "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
}))
options_templates.update(options_section(('interrogate', "Interrogate"), {
@@ -244,6 +256,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
+ "extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
@@ -267,6 +280,8 @@ options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"),
"js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"),
+ "sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
+ "sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(),
}))
@@ -279,6 +294,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives",
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
+ "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"),
}))
options_templates.update(options_section(('ui', "User interface", "ui"), {
@@ -354,6 +370,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
+ 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
}))
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
diff --git a/modules/shared_state.py b/modules/shared_state.py
index a68789cc..33996691 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -12,6 +12,7 @@ log = logging.getLogger(__name__)
class State:
skipped = False
interrupted = False
+ stopping_generation = False
job = ""
job_no = 0
job_count = 0
@@ -79,6 +80,10 @@ class State:
self.interrupted = True
log.info("Received interrupt request")
+ def stop_generating(self):
+ self.stopping_generation = True
+ log.info("Received stop generating request")
+
def nextjob(self):
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
@@ -91,6 +96,7 @@ class State:
obj = {
"skipped": self.skipped,
"interrupted": self.interrupted,
+ "stopping_generation": self.stopping_generation,
"job": self.job,
"job_count": self.job_count,
"job_timestamp": self.job_timestamp,
@@ -114,6 +120,7 @@ class State:
self.id_live_preview = 0
self.skipped = False
self.interrupted = False
+ self.stopping_generation = False
self.textinfo = None
self.job = job
devices.torch_gc()
diff --git a/modules/styles.py b/modules/styles.py
index 81d9800d..9edcc7e4 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,16 +1,15 @@
+from pathlib import Path
import csv
-import fnmatch
import os
-import os.path
import typing
import shutil
class PromptStyle(typing.NamedTuple):
name: str
- prompt: str
- negative_prompt: str
- path: str = None
+ prompt: str | None
+ negative_prompt: str | None
+ path: str | None = None
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -30,38 +29,29 @@ def apply_styles_to_prompt(prompt, styles):
return prompt
-def unwrap_style_text_from_prompt(style_text, prompt):
- """
- Checks the prompt to see if the style text is wrapped around it. If so,
- returns True plus the prompt text without the style text. Otherwise, returns
- False with the original prompt.
+def extract_style_text_from_prompt(style_text, prompt):
+ """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
- Note that the "cleaned" version of the style text is only used for matching
- purposes here. It isn't returned; the original style text is not modified.
+ extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
+ extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
+ extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
"""
- stripped_prompt = prompt
- stripped_style_text = style_text
+
+ stripped_prompt = prompt.strip()
+ stripped_style_text = style_text.strip()
+
if "{prompt}" in stripped_style_text:
- # Work out whether the prompt is wrapped in the style text. If so, we
- # return True and the "inner" prompt text that isn't part of the style.
- try:
- left, right = stripped_style_text.split("{prompt}", 2)
- except ValueError as e:
- # If the style text has multple "{prompt}"s, we can't split it into
- # two parts. This is an error, but we can't do anything about it.
- print(f"Unable to compare style text to prompt:\n{style_text}")
- print(f"Error: {e}")
- return False, prompt
+ left, right = stripped_style_text.split("{prompt}", 2)
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
- prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
+ prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
return True, prompt
else:
- # Work out whether the given prompt ends with the style text. If so, we
- # return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
- prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
- if prompt.endswith(", "):
+ prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
+
+ if prompt.endswith(', '):
prompt = prompt[:-2]
+
return True, prompt
return False, prompt
@@ -76,15 +66,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
- match_positive, extracted_positive = unwrap_style_text_from_prompt(
- style.prompt, prompt
- )
+ match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
if not match_positive:
return False, prompt, negative_prompt
- match_negative, extracted_negative = unwrap_style_text_from_prompt(
- style.negative_prompt, negative_prompt
- )
+ match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
if not match_negative:
return False, prompt, negative_prompt
@@ -92,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
class StyleDatabase:
- def __init__(self, path: str):
+ def __init__(self, paths: list[str | Path]):
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
- self.path = path
-
- folder, file = os.path.split(self.path)
- filename, _, ext = file.partition('*')
- self.default_path = os.path.join(folder, filename + ext)
+ self.paths = paths
+ self.all_styles_files: list[Path] = []
+
+ folder, file = os.path.split(self.paths[0])
+ if '*' in file or '?' in file:
+ # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
+ self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
+ self.paths.insert(0, self.default_path)
+ else:
+ self.default_path = Path(self.paths[0])
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -112,33 +103,31 @@ class StyleDatabase:
"""
self.styles.clear()
- path, filename = os.path.split(self.path)
-
- if "*" in filename:
- fileglob = filename.split("*")[0] + "*.csv"
- filelist = []
- for file in os.listdir(path):
- if fnmatch.fnmatch(file, fileglob):
- filelist.append(file)
- # Add a visible divider to the style list
- half_len = round(len(file) / 2)
- divider = f"{'-' * (20 - half_len)} {file.upper()}"
- divider = f"{divider} {'-' * (40 - len(divider))}"
- self.styles[divider] = PromptStyle(
- f"{divider}", None, None, "do_not_save"
- )
- # Add styles from this CSV file
- self.load_from_csv(os.path.join(path, file))
- if len(filelist) == 0:
- print(f"No styles found in {path} matching {fileglob}")
- return
- elif not os.path.exists(self.path):
- print(f"Style database not found: {self.path}")
- return
- else:
- self.load_from_csv(self.path)
-
- def load_from_csv(self, path: str):
+ # scans for all styles files
+ all_styles_files = []
+ for pattern in self.paths:
+ folder, file = os.path.split(pattern)
+ if '*' in file or '?' in file:
+ found_files = Path(folder).glob(file)
+ [all_styles_files.append(file) for file in found_files]
+ else:
+ # if os.path.exists(pattern):
+ all_styles_files.append(Path(pattern))
+
+ # Remove any duplicate entries
+ seen = set()
+ self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
+
+ for styles_file in self.all_styles_files:
+ if len(all_styles_files) > 1:
+ # add divider when more than styles file
+ # '---------------- STYLES ----------------'
+ divider = f' {styles_file.stem.upper()} '.center(40, '-')
+ self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
+ if styles_file.is_file():
+ self.load_from_csv(styles_file)
+
+ def load_from_csv(self, path: str | Path):
with open(path, "r", encoding="utf-8-sig", newline="") as file:
reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
@@ -150,7 +139,7 @@ class StyleDatabase:
negative_prompt = row.get("negative_prompt", "")
# Add style to database
self.styles[row["name"]] = PromptStyle(
- row["name"], prompt, negative_prompt, path
+ row["name"], prompt, negative_prompt, str(path)
)
def get_style_paths(self) -> set:
@@ -158,11 +147,11 @@ class StyleDatabase:
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
- self.styles[style.name] = style._replace(path=self.default_path)
+ self.styles[style.name] = style._replace(path=str(self.default_path))
# Create a list of all distinct paths, including the default path
style_paths = set()
- style_paths.add(self.default_path)
+ style_paths.add(str(self.default_path))
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)
@@ -190,7 +179,6 @@ class StyleDatabase:
def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
- _ = path
style_paths = self.get_style_paths()
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index b669edd0..f336251e 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -24,13 +24,13 @@ environment_whitelist = {
"XFORMERS_PACKAGE",
"CLIP_PACKAGE",
"OPENCLIP_PACKAGE",
+ "ASSETS_REPO",
"STABLE_DIFFUSION_REPO",
"K_DIFFUSION_REPO",
- "CODEFORMER_REPO",
"BLIP_REPO",
+ "ASSETS_COMMIT_HASH",
"STABLE_DIFFUSION_COMMIT_HASH",
"K_DIFFUSION_COMMIT_HASH",
- "CODEFORMER_COMMIT_HASH",
"BLIP_COMMIT_HASH",
"COMMANDLINE_ARGS",
"IGNORE_CMD_ARGS_ERRORS",
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 9c062503..d16e3b9a 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,6 @@ import safetensors.torch
import numpy as np
from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
@@ -348,6 +347,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
def tensorboard_setup(log_directory):
+ from torch.utils.tensorboard import SummaryWriter
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
return SummaryWriter(
log_dir=os.path.join(log_directory, "tensorboard"),
@@ -452,8 +452,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
+ tensorboard_writer = None
if shared.opts.training_enable_tensorboard:
- tensorboard_writer = tensorboard_setup(log_directory)
+ try:
+ tensorboard_writer = tensorboard_setup(log_directory)
+ except ImportError:
+ errors.report("Error initializing tensorboard", exc_info=True)
pin_memory = shared.opts.pin_memory
@@ -626,7 +630,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ if tensorboard_writer and shared.opts.training_tensorboard_save_images:
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
diff --git a/modules/torch_utils.py b/modules/torch_utils.py
new file mode 100644
index 00000000..e5b52393
--- /dev/null
+++ b/modules/torch_utils.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import torch.nn
+
+
+def get_param(model) -> torch.nn.Parameter:
+ """
+ Find the first parameter in a model or module.
+ """
+ if hasattr(model, "model") and hasattr(model.model, "parameters"):
+ # Unpeel a model descriptor to get at the actual Torch module.
+ model = model.model
+
+ for param in model.parameters():
+ return param
+
+ raise ValueError(f"No parameters found in model {model!r}")
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e4e18ceb..fc56b8a8 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,17 +1,22 @@
+import json
from contextlib import closing
import modules.scripts
-from modules import processing
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import processing, infotext_utils
+from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
from modules.shared import opts
import modules.shared as shared
from modules.ui import plaintext_to_html
+from PIL import Image
import gradio as gr
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
override_settings = create_override_settings_dict(override_settings_texts)
+ if force_enable_hr:
+ enable_hr = True
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -27,7 +32,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
width=width,
height=height,
enable_hr=enable_hr,
- denoising_strength=denoising_strength if enable_hr else None,
+ denoising_strength=denoising_strength,
hr_scale=hr_scale,
hr_upscaler=hr_upscaler,
hr_second_pass_steps=hr_second_pass_steps,
@@ -48,8 +53,58 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
if shared.opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
+ return p
+
+
+def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
+ assert len(gallery) > 0, 'No image to upscale'
+ assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
+
+ p = txt2img_create_processing(id_task, request, *args, force_enable_hr=True)
+ p.batch_size = 1
+ p.n_iter = 1
+ # txt2img_upscale attribute that signifies this is called by txt2img_upscale
+ p.txt2img_upscale = True
+
+ geninfo = json.loads(generation_info)
+
+ image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
+ p.firstpass_image = infotext_utils.image_from_url_text(image_info)
+
+ parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], [])
+ p.seed = parameters.get('Seed', -1)
+ p.subseed = parameters.get('Variation seed', -1)
+
+ p.override_settings['save_images_before_highres_fix'] = False
+
+ with closing(p):
+ processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
+
+ if processed is None:
+ processed = processing.process_images(p)
+
+ shared.total_tqdm.clear()
+
+ new_gallery = []
+ for i, image in enumerate(gallery):
+ if i == gallery_index:
+ geninfo["infotexts"][gallery_index: gallery_index+1] = processed.infotexts
+ new_gallery.extend(processed.images)
+ else:
+ fake_image = Image.new(mode="RGB", size=(1, 1))
+ fake_image.already_saved_as = image["name"].rsplit('?', 1)[0]
+ new_gallery.append(fake_image)
+
+ geninfo["infotexts"][gallery_index] = processed.info
+
+ return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
+
+
+def txt2img(id_task: str, request: gr.Request, *args):
+ p = txt2img_create_processing(id_task, request, *args)
+
with closing(p):
- processed = modules.scripts.scripts_txt2img.run(p, *args)
+ processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
if processed is None:
processed = processing.process_images(p)
diff --git a/modules/ui.py b/modules/ui.py
index d80486dd..177c6872 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -21,14 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript
from modules.shared import opts, cmd_opts
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext_utils as parameters_copypaste
import modules.hypernetworks.ui as hypernetworks_ui
import modules.textual_inversion.ui as textual_inversion_ui
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
from modules import prompt_parser
from modules.sd_hijack import model_hijack
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext_utils import image_from_url_text, PasteField
create_setting_component = ui_settings.create_setting_component
@@ -177,7 +177,6 @@ def update_negative_prompt_token_counter(text, steps):
return update_token_counter(text, steps, is_positive=False)
-
def setup_progressbar(*args, **kwargs):
pass
@@ -267,7 +266,7 @@ def create_ui():
dummy_component = gr.Label(visible=False)
- extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
+ extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs", elem_classes=["extra-networks"])
extra_tabs.__enter__()
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
@@ -376,50 +375,60 @@ def create_ui():
show_progress=False,
)
- txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
+ output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
+
+ txt2img_inputs = [
+ dummy_component,
+ toprow.prompt,
+ toprow.negative_prompt,
+ toprow.ui_styles.dropdown,
+ steps,
+ sampler_name,
+ batch_count,
+ batch_size,
+ cfg_scale,
+ height,
+ width,
+ enable_hr,
+ denoising_strength,
+ hr_scale,
+ hr_upscaler,
+ hr_second_pass_steps,
+ hr_resize_x,
+ hr_resize_y,
+ hr_checkpoint_name,
+ hr_sampler_name,
+ hr_prompt,
+ hr_negative_prompt,
+ override_settings,
+ ] + custom_inputs
+
+ txt2img_outputs = [
+ output_panel.gallery,
+ output_panel.generation_info,
+ output_panel.infotext,
+ output_panel.html_log,
+ ]
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
- inputs=[
- dummy_component,
- toprow.prompt,
- toprow.negative_prompt,
- toprow.ui_styles.dropdown,
- steps,
- sampler_name,
- batch_count,
- batch_size,
- cfg_scale,
- height,
- width,
- enable_hr,
- denoising_strength,
- hr_scale,
- hr_upscaler,
- hr_second_pass_steps,
- hr_resize_x,
- hr_resize_y,
- hr_checkpoint_name,
- hr_sampler_name,
- hr_prompt,
- hr_negative_prompt,
- override_settings,
-
- ] + custom_inputs,
-
- outputs=[
- txt2img_gallery,
- generation_info,
- html_info,
- html_log,
- ],
+ inputs=txt2img_inputs,
+ outputs=txt2img_outputs,
show_progress=False,
)
toprow.prompt.submit(**txt2img_args)
toprow.submit.click(**txt2img_args)
+ output_panel.button_upscale.click(
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
+ _js="submit_txt2img_upscale",
+ inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:],
+ outputs=txt2img_outputs,
+ show_progress=False,
+ )
+
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
toprow.restore_progress_button.click(
@@ -427,37 +436,37 @@ def create_ui():
_js="restoreProgressTxt2img",
inputs=[dummy_component],
outputs=[
- txt2img_gallery,
- generation_info,
- html_info,
- html_log,
+ output_panel.gallery,
+ output_panel.generation_info,
+ output_panel.infotext,
+ output_panel.html_log,
],
show_progress=False,
)
txt2img_paste_fields = [
- (toprow.prompt, "Prompt"),
- (toprow.negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_name, "Sampler"),
- (cfg_scale, "CFG scale"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- (hr_checkpoint_name, "Hires checkpoint"),
- (hr_sampler_name, "Hires sampler"),
- (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
- (hr_prompt, "Hires prompt"),
- (hr_negative_prompt, "Hires negative prompt"),
- (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
+ PasteField(toprow.prompt, "Prompt", api="prompt"),
+ PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
+ PasteField(steps, "Steps", api="steps"),
+ PasteField(sampler_name, "Sampler", api="sampler_name"),
+ PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
+ PasteField(width, "Size-1", api="width"),
+ PasteField(height, "Size-2", api="height"),
+ PasteField(batch_size, "Batch size", api="batch_size"),
+ PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
+ PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
+ PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
+ PasteField(hr_scale, "Hires upscale", api="hr_scale"),
+ PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
+ PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
+ PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
+ PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
+ PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
+ PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
+ PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
+ PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
+ PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
+ PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
@@ -480,7 +489,7 @@ def create_ui():
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
- ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
+ ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)
extra_tabs.__exit__()
@@ -490,7 +499,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as img2img_interface:
toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
- extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
+ extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs", elem_classes=["extra-networks"])
extra_tabs.__enter__()
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
@@ -523,7 +532,7 @@ def create_ui():
if category == "image":
with gr.Tabs(elem_id="mode_img2img"):
- img2img_selected_tab = gr.State(0)
+ img2img_selected_tab = gr.Number(value=0, visible=False)
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
@@ -604,7 +613,7 @@ def create_ui():
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
- selected_scale_tab = gr.State(value=0)
+ selected_scale_tab = gr.Number(value=0, visible=False)
with gr.Tabs():
with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
@@ -711,7 +720,7 @@ def create_ui():
outputs=[inpaint_controls, mask_alpha],
)
- img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
+ output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
@@ -756,10 +765,10 @@ def create_ui():
img2img_batch_png_info_dir,
] + custom_inputs,
outputs=[
- img2img_gallery,
- generation_info,
- html_info,
- html_log,
+ output_panel.gallery,
+ output_panel.generation_info,
+ output_panel.infotext,
+ output_panel.html_log,
],
show_progress=False,
)
@@ -797,10 +806,10 @@ def create_ui():
_js="restoreProgressImg2img",
inputs=[dummy_component],
outputs=[
- img2img_gallery,
- generation_info,
- html_info,
- html_log,
+ output_panel.gallery,
+ output_panel.generation_info,
+ output_panel.infotext,
+ output_panel.html_log,
],
show_progress=False,
)
@@ -831,6 +840,10 @@ def create_ui():
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
+ (inpainting_mask_invert, 'Mask mode'),
+ (inpainting_fill, 'Masked content'),
+ (inpaint_full_res, 'Inpaint area'),
+ (inpaint_full_res_padding, 'Masked area padding'),
*scripts.scripts_img2img.infotext_fields
]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
@@ -840,7 +853,7 @@ def create_ui():
))
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
- ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery)
extra_tabs.__exit__()
@@ -1086,6 +1099,7 @@ def create_ui():
)
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
+ ui_settings_from_file = loadsave.ui_settings.copy()
settings = ui_settings.UiSettings()
settings.create_ui(loadsave, dummy_component)
@@ -1146,7 +1160,8 @@ def create_ui():
modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
- loadsave.dump_defaults()
+ if ui_settings_from_file != loadsave.ui_settings:
+ loadsave.dump_defaults()
demo.ui_loadsave = loadsave
return demo
@@ -1208,3 +1223,5 @@ def setup_ui_api(app):
app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
+ import fastapi.staticfiles
+ app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets")
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 032ec4af..29fe7d0e 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -1,3 +1,5 @@
+import csv
+import dataclasses
import json
import html
import os
@@ -8,10 +10,10 @@ import gradio as gr
import subprocess as sp
from modules import call_queue, shared
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext_utils import image_from_url_text
import modules.images
from modules.ui_components import ToolButton
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext_utils as parameters_copypaste
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -35,12 +37,38 @@ def plaintext_to_html(text, classname=None):
return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
+def update_logfile(logfile_path, fields):
+ """Update a logfile from old format to new format to maintain CSV integrity."""
+ with open(logfile_path, "r", encoding="utf8", newline="") as file:
+ reader = csv.reader(file)
+ rows = list(reader)
+
+ # blank file: leave it as is
+ if not rows:
+ return
+
+ # file is already synced, do nothing
+ if len(rows[0]) == len(fields):
+ return
+
+ rows[0] = fields
+
+ # append new fields to each row as empty values
+ for row in rows[1:]:
+ while len(row) < len(fields):
+ row.append("")
+
+ with open(logfile_path, "w", encoding="utf8", newline="") as file:
+ writer = csv.writer(file)
+ writer.writerows(rows)
+
+
def save_files(js_data, images, do_make_zip, index):
- import csv
filenames = []
fullfns = []
+ parsed_infotexts = []
- #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
+ # quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
class MyObject:
def __init__(self, d=None):
if d is not None:
@@ -48,35 +76,55 @@ def save_files(js_data, images, do_make_zip, index):
setattr(self, key, value)
data = json.loads(js_data)
-
p = MyObject(data)
+
path = shared.opts.outdir_save
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format
start_index = 0
- only_one = False
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
- only_one = True
images = [images[index]]
start_index = index
os.makedirs(shared.opts.outdir_save, exist_ok=True)
- with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
+ fields = [
+ "prompt",
+ "seed",
+ "width",
+ "height",
+ "sampler",
+ "cfgs",
+ "steps",
+ "filename",
+ "negative_prompt",
+ "sd_model_name",
+ "sd_model_hash",
+ ]
+ logfile_path = os.path.join(shared.opts.outdir_save, "log.csv")
+
+ # NOTE: ensure csv integrity when fields are added by
+ # updating headers and padding with delimeters where needed
+ if os.path.exists(logfile_path):
+ update_logfile(logfile_path, fields)
+
+ with open(logfile_path, "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0
writer = csv.writer(file)
if at_start:
- writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
+ writer.writerow(fields)
for image_index, filedata in enumerate(images, start_index):
image = image_from_url_text(filedata)
is_grid = image_index < p.index_of_first_image
- i = 0 if is_grid else (image_index - p.index_of_first_image)
p.batch_index = image_index-1
- fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
+
+ parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], [])
+ parsed_infotexts.append(parameters)
+ fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
filenames.append(filename)
@@ -85,12 +133,12 @@ def save_files(js_data, images, do_make_zip, index):
filenames.append(os.path.basename(txt_fullfn))
fullfns.append(txt_fullfn)
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+ writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]])
# Make Zip
if do_make_zip:
- zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
- namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
+ p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts]
+ namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True)
zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
zip_filepath = os.path.join(path, f"{zip_filename}.zip")
@@ -104,7 +152,17 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
+@dataclasses.dataclass
+class OutputPanel:
+ gallery = None
+ generation_info = None
+ infotext = None
+ html_log = None
+ button_upscale = None
+
+
def create_output_panel(tabname, outdir, toprow=None):
+ res = OutputPanel()
def open_folder(f):
if not os.path.exists(f):
@@ -136,9 +194,8 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
+ res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
- generation_info = None
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
@@ -152,6 +209,9 @@ Requested path was: {f}
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
}
+ if tabname == 'txt2img':
+ res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.")
+
open_folder_button.click(
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
inputs=[],
@@ -162,17 +222,17 @@ Requested path was: {f}
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
- html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
+ res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
+ res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
- inputs=[generation_info, html_info, html_info],
- outputs=[html_info, html_info],
+ inputs=[res.generation_info, res.infotext, res.infotext],
+ outputs=[res.infotext, res.infotext],
show_progress=False,
)
@@ -180,14 +240,14 @@ Requested path was: {f}
fn=call_queue.wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
+ res.generation_info,
+ res.gallery,
+ res.infotext,
+ res.infotext,
],
outputs=[
download_files,
- html_log,
+ res.html_log,
],
show_progress=False,
)
@@ -196,21 +256,21 @@ Requested path was: {f}
fn=call_queue.wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
+ res.generation_info,
+ res.gallery,
+ res.infotext,
+ res.infotext,
],
outputs=[
download_files,
- html_log,
+ res.html_log,
]
)
else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+ res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
+ res.html_log = gr.HTML(elem_id=f'html_log_{tabname}')
paste_field_names = []
if tabname == "txt2img":
@@ -220,11 +280,11 @@ Requested path was: {f}
for paste_tabname, paste_button in buttons.items():
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
+ paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery,
paste_field_names=paste_field_names
))
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+ return res
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index fe5d3ba3..325d848e 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -2,23 +2,22 @@ import functools
import os.path
import urllib.parse
from pathlib import Path
+from typing import Optional, Union
+from dataclasses import dataclass
-from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
+from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util
from modules.images import read_info_from_image, save_image_with_geninfo
import gradio as gr
import json
import html
from fastapi.exceptions import HTTPException
-from modules.generation_parameters_copypaste import image_from_url_text
-from modules.ui_components import ToolButton
+from modules.infotext_utils import image_from_url_text
extra_pages = []
allowed_dirs = set()
-
default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]
-
@functools.cache
def allowed_preview_extensions_with_extra(extra_extensions=None):
return set(default_allowed_preview_extensions) | set(extra_extensions or [])
@@ -28,6 +27,62 @@ def allowed_preview_extensions():
return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))
+@dataclass
+class ExtraNetworksItem:
+ """Wrapper for dictionaries representing ExtraNetworks items."""
+ item: dict
+
+
+def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) -> dict:
+ """Recursively builds a directory tree.
+
+ Args:
+ paths: Path or list of paths to directories. These paths are treated as roots from which
+ the tree will be built.
+ items: A dictionary associating filepaths to an ExtraNetworksItem instance.
+
+ Returns:
+ The result directory tree.
+ """
+ if isinstance(paths, (str,)):
+ paths = [paths]
+
+ def _get_tree(_paths: list[str], _root: str):
+ _res = {}
+ for path in _paths:
+ relpath = os.path.relpath(path, _root)
+ if os.path.isdir(path):
+ dir_items = os.listdir(path)
+ # Ignore empty directories.
+ if not dir_items:
+ continue
+ dir_tree = _get_tree([os.path.join(path, x) for x in dir_items], _root)
+ # We only want to store non-empty folders in the tree.
+ if dir_tree:
+ _res[relpath] = dir_tree
+ else:
+ if path not in items:
+ continue
+ # Add the ExtraNetworksItem to the result.
+ _res[relpath] = items[path]
+ return _res
+
+ res = {}
+ # Handle each root directory separately.
+ # Each root WILL have a key/value at the root of the result dict though
+ # the value can be an empty dict if the directory is empty. We want these
+ # placeholders for empty dirs so we can inform the user later.
+ for path in paths:
+ root = os.path.dirname(path)
+ relpath = os.path.relpath(path, root)
+ # Wrap the path in a list since that is what the `_get_tree` expects.
+ res[relpath] = _get_tree([path], root)
+ if res[relpath]:
+ # We need to pull the inner path out one for these root dirs.
+ res[relpath] = res[relpath][relpath]
+
+ return res
+
def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
@@ -80,7 +135,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
item = page.items.get(name)
page.read_user_metadata(item)
- item_html = page.create_html_for_item(item, tabname)
+ item_html = page.create_item_html(tabname, item)
return JSONResponse({"html": item_html})
@@ -96,24 +151,31 @@ def quote_js(s):
s = s.replace('"', '\\"')
return f'"{s}"'
-
class ExtraNetworksPage:
def __init__(self, title):
self.title = title
self.name = title.lower()
- self.id_page = self.name.replace(" ", "_")
- self.card_page = shared.html("extra-networks-card.html")
+ # This is the actual name of the extra networks tab (not txt2img/img2img).
+ self.extra_networks_tabname = self.name.replace(" ", "_")
self.allow_prompt = True
self.allow_negative_prompt = False
self.metadata = {}
self.items = {}
+ self.lister = util.MassFileLister()
+ # HTML Templates
+ self.pane_tpl = shared.html("extra-networks-pane.html")
+ self.card_tpl = shared.html("extra-networks-card.html")
+ self.btn_tree_tpl = shared.html("extra-networks-tree-button.html")
+ self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html")
+ self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html")
+ self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html")
def refresh(self):
pass
def read_user_metadata(self, item):
filename = item.get("filename", None)
- metadata = extra_networks.get_user_metadata(filename)
+ metadata = extra_networks.get_user_metadata(filename, lister=self.lister)
desc = metadata.get("description", None)
if desc is not None:
@@ -123,117 +185,74 @@ class ExtraNetworksPage:
def link_preview(self, filename):
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
- mtime = os.path.getmtime(filename)
+ mtime, _ = self.lister.mctime(filename)
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
def search_terms_from_path(self, filename, possible_directories=None):
abspath = os.path.abspath(filename)
-
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
- parentdir = os.path.abspath(parentdir)
+ parentdir = os.path.dirname(os.path.abspath(parentdir))
if abspath.startswith(parentdir):
- return abspath[len(parentdir):].replace('\\', '/')
+ return os.path.relpath(abspath, parentdir)
return ""
- def create_html(self, tabname):
- items_html = ''
-
- self.metadata = {}
-
- subdirs = {}
- for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
- for dirname in sorted(dirs, key=shared.natural_sort_key):
- x = os.path.join(root, dirname)
-
- if not os.path.isdir(x):
- continue
-
- subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
-
- if shared.opts.extra_networks_dir_button_function:
- if not subdir.startswith("/"):
- subdir = "/" + subdir
- else:
- while subdir.startswith("/"):
- subdir = subdir[1:]
-
- is_empty = len(os.listdir(x)) == 0
- if not is_empty and not subdir.endswith("/"):
- subdir = subdir + "/"
-
- if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
- continue
-
- subdirs[subdir] = 1
-
- if subdirs:
- subdirs = {"": 1, **subdirs}
-
- subdirs_html = "".join([f"""
-<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
-{html.escape(subdir if subdir!="" else "all")}
-</button>
-""" for subdir in subdirs])
-
- self.items = {x["name"]: x for x in self.list_items()}
- for item in self.items.values():
- metadata = item.get("metadata")
- if metadata:
- self.metadata[item["name"]] = metadata
-
- if "user_metadata" not in item:
- self.read_user_metadata(item)
-
- items_html += self.create_html_for_item(item, tabname)
-
- if items_html == '':
- dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
- items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
-
- self_name_id = self.name.replace(" ", "_")
-
- res = f"""
-<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
-{subdirs_html}
-</div>
-<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
-{items_html}
-</div>
-"""
-
- return res
-
- def create_item(self, name, index=None):
- raise NotImplementedError()
-
- def list_items(self):
- raise NotImplementedError()
-
- def allowed_directories_for_previews(self):
- return []
-
- def create_html_for_item(self, item, tabname):
+ def create_item_html(
+ self,
+ tabname: str,
+ item: dict,
+ template: Optional[str] = None,
+ ) -> Union[str, dict]:
+ """Generates HTML for a single ExtraNetworks Item.
+
+ Args:
+ tabname: The name of the active tab.
+ item: Dictionary containing item information.
+ template: Optional template string to use.
+
+ Returns:
+ If a template is passed: HTML string generated for this item.
+ Can be empty if the item is not meant to be shown.
+ If no template is passed: A dictionary containing the generated item's attributes.
"""
- Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
- """
-
preview = item.get("preview", None)
+ style_height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
+ style_width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
+ style_font_size = f"font-size: {shared.opts.extra_networks_card_text_scale*100}%;"
+ card_style = style_height + style_width + style_font_size
+ background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
onclick = item.get("onclick", None)
if onclick is None:
- onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
-
- height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
- width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
- background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
- metadata_button = ""
+ # Don't quote prompt/neg_prompt since they are stored as js strings already.
+ onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});"
+ onclick = onclick_js_tpl.format(
+ **{
+ "tabname": tabname,
+ "prompt": item["prompt"],
+ "neg_prompt": item.get("negative_prompt", "''"),
+ "allow_neg": str(self.allow_negative_prompt).lower(),
+ }
+ )
+ onclick = html.escape(onclick)
+
+ btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]})
+ btn_metadata = ""
metadata = item.get("metadata")
if metadata:
- metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
-
- edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
+ btn_metadata = self.btn_metadata_tpl.format(
+ **{
+ "extra_networks_tabname": self.extra_networks_tabname,
+ "name": html.escape(item["name"]),
+ }
+ )
+ btn_edit_item = self.btn_edit_item_tpl.format(
+ **{
+ "tabname": tabname,
+ "extra_networks_tabname": self.extra_networks_tabname,
+ "name": html.escape(item["name"]),
+ }
+ )
local_path = ""
filename = item.get("filename", "")
@@ -253,36 +272,292 @@ class ExtraNetworksPage:
if search_only and shared.opts.extra_networks_hidden_models == "Never":
return ""
- sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
-
+ sort_keys = " ".join(
+ [
+ f'data-sort-{k}="{html.escape(str(v))}"'
+ for k, v in item.get("sort_keys", {}).items()
+ ]
+ ).strip()
+
+ search_terms_html = ""
+ search_term_template = "<span class='hidden {class}'>{search_term}</span>"
+ for search_term in item.get("search_terms", []):
+ search_terms_html += search_term_template.format(
+ **{
+ "class": f"search_terms{' search_only' if search_only else ''}",
+ "search_term": search_term,
+ }
+ )
+
+ # Some items here might not be used depending on HTML template used.
args = {
"background_image": background_image,
- "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
- "prompt": item.get("prompt", None),
- "tabname": quote_js(tabname),
+ "card_clicked": onclick,
+ "copy_path_button": btn_copy_path,
+ "description": (item.get("description", "") or "" if shared.opts.extra_networks_card_show_desc else ""),
+ "edit_button": btn_edit_item,
"local_preview": quote_js(item["local_preview"]),
+ "metadata_button": btn_metadata,
"name": html.escape(item["name"]),
- "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
- "card_clicked": onclick,
- "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
- "search_term": item.get("search_term", ""),
- "metadata_button": metadata_button,
- "edit_button": edit_button,
+ "prompt": item.get("prompt", None),
+ "save_card_preview": html.escape(f"return saveCardPreview(event, '{tabname}', '{item['local_preview']}');"),
"search_only": " search_only" if search_only else "",
+ "search_terms": search_terms_html,
"sort_keys": sort_keys,
+ "style": card_style,
+ "tabname": tabname,
+ "extra_networks_tabname": self.extra_networks_tabname,
}
- return self.card_page.format(**args)
+ if template:
+ return template.format(**args)
+ else:
+ return args
+
+ def create_tree_dir_item_html(
+ self,
+ tabname: str,
+ dir_path: str,
+ content: Optional[str] = None,
+ ) -> Optional[str]:
+ """Generates HTML for a directory item in the tree.
+
+ The generated HTML is of the format:
+ ```html
+ <li class="tree-list-item tree-list-item--has-subitem">
+ <div class="tree-list-content tree-list-content-dir"></div>
+ <ul class="tree-list tree-list--subgroup">
+ {content}
+ </ul>
+ </li>
+ ```
+
+ Args:
+ tabname: The name of the active tab.
+ dir_path: Path to the directory for this item.
+ content: Optional HTML string that will be wrapped by this <ul>.
+
+ Returns:
+ HTML formatted string.
+ """
+ if not content:
+ return None
+
+ btn = self.btn_tree_tpl.format(
+ **{
+ "search_terms": "",
+ "subclass": "tree-list-content-dir",
+ "tabname": tabname,
+ "extra_networks_tabname": self.extra_networks_tabname,
+ "onclick_extra": "",
+ "data_path": dir_path,
+ "data_hash": "",
+ "action_list_item_action_leading": "<i class='tree-list-item-action-chevron'></i>",
+ "action_list_item_visual_leading": "🗀",
+ "action_list_item_label": os.path.basename(dir_path),
+ "action_list_item_visual_trailing": "",
+ "action_list_item_action_trailing": "",
+ }
+ )
+ ul = f"<ul class='tree-list tree-list--subgroup' hidden>{content}</ul>"
+ return (
+ "<li class='tree-list-item tree-list-item--has-subitem' data-tree-entry-type='dir'>"
+ f"{btn}{ul}"
+ "</li>"
+ )
+
+ def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) -> str:
+ """Generates HTML for a file item in the tree.
+
+ The generated HTML is of the format:
+ ```html
+ <li class="tree-list-item tree-list-item--subitem">
+ <span data-filterable-item-text hidden></span>
+ <div class="tree-list-content tree-list-content-file"></div>
+ </li>
+ ```
+
+ Args:
+ tabname: The name of the active tab.
+ file_path: The path to the file for this item.
+ item: Dictionary containing the item information.
+
+ Returns:
+ HTML formatted string.
+ """
+ item_html_args = self.create_item_html(tabname, item)
+ action_buttons = "".join(
+ [
+ item_html_args["copy_path_button"],
+ item_html_args["metadata_button"],
+ item_html_args["edit_button"],
+ ]
+ )
+ action_buttons = f"<div class=\"button-row\">{action_buttons}</div>"
+ btn = self.btn_tree_tpl.format(
+ **{
+ "search_terms": "",
+ "subclass": "tree-list-content-file",
+ "tabname": tabname,
+ "extra_networks_tabname": self.extra_networks_tabname,
+ "onclick_extra": item_html_args["card_clicked"],
+ "data_path": file_path,
+ "data_hash": item["shorthash"],
+ "action_list_item_action_leading": "<i class='tree-list-item-action-chevron'></i>",
+ "action_list_item_visual_leading": "🗎",
+ "action_list_item_label": item["name"],
+ "action_list_item_visual_trailing": "",
+ "action_list_item_action_trailing": action_buttons,
+ }
+ )
+ return (
+ "<li class='tree-list-item tree-list-item--subitem' data-tree-entry-type='file'>"
+ f"{btn}"
+ "</li>"
+ )
+
+ def create_tree_view_html(self, tabname: str) -> str:
+ """Generates HTML for displaying folders in a tree view.
+
+ Args:
+ tabname: The name of the active tab.
+
+ Returns:
+ HTML string generated for this tree view.
+ """
+ res = ""
+
+ # Setup the tree dictionary.
+ roots = self.allowed_directories_for_previews()
+ tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()}
+ tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items)
+
+ if not tree:
+ return res
+
+ def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional[str]:
+ """Recursively builds HTML for a tree.
+
+ Args:
+ data: Dictionary representing a directory tree. Can be NoneType.
+ Data keys should be absolute paths from the root and values
+ should be subdirectory trees or an ExtraNetworksItem.
+
+ Returns:
+ If data is not None: HTML string
+ Else: None
+ """
+ if not data:
+ return None
+
+ # Lists for storing <li> items html for directories and files separately.
+ _dir_li = []
+ _file_li = []
+
+ for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])):
+ if isinstance(v, (ExtraNetworksItem,)):
+ _file_li.append(self.create_tree_file_item_html(tabname, k, v.item))
+ else:
+ _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v)))
+
+ # Directories should always be displayed before files so we order them here.
+ return "".join(_dir_li) + "".join(_file_li)
+
+ # Add each root directory to the tree.
+ for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])):
+ item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v))
+ # Only add non-empty entries to the tree.
+ if item_html is not None:
+ res += item_html
+
+ return f"<ul class='tree-list tree-list--tree'>{res}</ul>"
+
+ def create_card_view_html(self, tabname: str) -> str:
+ """Generates HTML for the network Card View section for a tab.
+
+ This HTML goes into the `extra-networks-pane.html` <div> with
+ `id='{tabname}_{extra_networks_tabname}_cards`.
+
+ Args:
+ tabname: The name of the active tab.
+
+ Returns:
+ HTML formatted string.
+ """
+ res = ""
+ for item in self.items.values():
+ res += self.create_item_html(tabname, item, self.card_tpl)
+
+ if res == "":
+ dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
+ res = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+ return res
+
+ def create_html(self, tabname):
+ """Generates an HTML string for the current pane.
+
+ The generated HTML uses `extra-networks-pane.html` as a template.
+
+ Args:
+ tabname: The name of the active tab.
+
+ Returns:
+ HTML formatted string.
+ """
+ self.lister.reset()
+ self.metadata = {}
+ self.items = {x["name"]: x for x in self.list_items()}
+ # Populate the instance metadata for each item.
+ for item in self.items.values():
+ metadata = item.get("metadata")
+ if metadata:
+ self.metadata[item["name"]] = metadata
+
+ if "user_metadata" not in item:
+ self.read_user_metadata(item)
+
+ data_sortdir = shared.opts.extra_networks_card_order
+ data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip()
+ data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}"
+ tree_view_btn_extra_class = ""
+ tree_view_div_extra_class = "hidden"
+ if shared.opts.extra_networks_tree_view_default_enabled:
+ tree_view_btn_extra_class = "extra-network-control--enabled"
+ tree_view_div_extra_class = ""
+
+ return self.pane_tpl.format(
+ **{
+ "tabname": tabname,
+ "extra_networks_tabname": self.extra_networks_tabname,
+ "data_sortmode": data_sortmode,
+ "data_sortkey": data_sortkey,
+ "data_sortdir": data_sortdir,
+ "tree_view_btn_extra_class": tree_view_btn_extra_class,
+ "tree_view_div_extra_class": tree_view_div_extra_class,
+ "tree_html": self.create_tree_view_html(tabname),
+ "items_html": self.create_card_view_html(tabname),
+ }
+ )
+
+ def create_item(self, name, index=None):
+ raise NotImplementedError()
+
+ def list_items(self):
+ raise NotImplementedError()
+
+ def allowed_directories_for_previews(self):
+ return []
def get_sort_keys(self, path):
"""
List of default keys used for sorting in the UI.
"""
pth = Path(path)
- stat = pth.stat()
+ mtime, ctime = self.lister.mctime(path)
return {
- "date_created": int(stat.st_ctime or 0),
- "date_modified": int(stat.st_mtime or 0),
+ "date_created": int(mtime),
+ "date_modified": int(ctime),
"name": pth.name.lower(),
"path": str(pth.parent).lower(),
}
@@ -292,10 +567,10 @@ class ExtraNetworksPage:
Find a preview PNG for a given path (without extension) and call link_preview on it.
"""
- potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
+ potential_files = sum([[f"{path}.{ext}", f"{path}.preview.{ext}"] for ext in allowed_preview_extensions()], [])
for file in potential_files:
- if os.path.isfile(file):
+ if self.lister.exists(file):
return self.link_preview(file)
return None
@@ -305,6 +580,9 @@ class ExtraNetworksPage:
Find and read a description file for a given path (without extension).
"""
for file in [f"{path}.txt", f"{path}.description.txt"]:
+ if not self.lister.exists(file):
+ continue
+
try:
with open(file, "r", encoding="utf-8", errors="replace") as f:
return f.read()
@@ -360,10 +638,7 @@ def pages_in_preferred_order(pages):
return sorted(pages, key=lambda x: tab_scores[x.name])
-
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
- from modules.ui import switch_values_symbol
-
ui = ExtraNetworksUi()
ui.pages = []
ui.pages_contents = []
@@ -373,62 +648,53 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs = []
+ button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_extra_refresh_internal", visible=False)
+
for page in ui.stored_extra_pages:
- with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
- with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
+ with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab:
+ with gr.Column(elem_id=f"{tabname}_{page.extra_networks_tabname}_prompts", elem_classes=["extra-page-prompts"]):
pass
- elem_id = f"{tabname}_{page.id_page}_cards_html"
+ elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
-
- page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
-
editor = page.create_user_metadata_editor(ui, tabname)
editor.create_ui()
ui.user_metadata_editors.append(editor)
-
related_tabs.append(tab)
- edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
- dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
- button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
- button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
- checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
-
- ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
- ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
-
- tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
+ ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False)
+ ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f"{tabname}_preview_filename", visible=False)
for tab in unrelated_tabs:
- tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
+ tab.select(fn=None, _js=f"function(){{extraNetworksUnrelatedTabSelected('{tabname}');}}", inputs=[], outputs=[], show_progress=False)
for page, tab in zip(ui.stored_extra_pages, related_tabs):
- allow_prompt = "true" if page.allow_prompt else "false"
- allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
-
- jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
-
- tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
-
- dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
+ jscode = (
+ "function(){{"
+ f"extraNetworksTabSelected('{tabname}', '{tabname}_{page.extra_networks_tabname}_prompts', {str(page.allow_prompt).lower()}, {str(page.allow_negative_prompt).lower()}, '{tabname}_{page.extra_networks_tabname}');"
+ f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');"
+ "}}"
+ )
+ tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False)
+
+ def create_html():
+ ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
def pages_html():
if not ui.pages_contents:
- return refresh()
-
+ create_html()
return ui.pages_contents
def refresh():
for pg in ui.stored_extra_pages:
pg.refresh()
-
- ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
-
+ create_html()
return ui.pages_contents
- interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
+ interface.load(fn=pages_html, inputs=[], outputs=ui.pages)
+ # NOTE: Event is manually fired in extraNetworks.js:extraNetworksTreeRefreshOnClick()
+ # button is unused and hidden at all times. Only used in order to fire this event.
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
return ui
@@ -478,5 +744,3 @@ def setup_ui(ui, gallery):
for editor in ui.user_metadata_editors:
editor.setup_ui(gallery)
-
-
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 1693e71f..a8c33671 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -2,7 +2,6 @@ import html
import os
from modules import shared, ui_extra_networks, sd_models
-from modules.ui_extra_networks import quote_js
from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor
@@ -21,14 +20,17 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
return
path, ext = os.path.splitext(checkpoint.filename)
+ search_terms = [self.search_terms_from_path(checkpoint.filename)]
+ if checkpoint.sha256:
+ search_terms.append(checkpoint.sha256)
return {
"name": checkpoint.name_for_extra,
"filename": checkpoint.filename,
"shorthash": checkpoint.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
- "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
- "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
+ "search_terms": search_terms,
+ "onclick": html.escape(f"return selectCheckpoint('{name}');"),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": checkpoint.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index c96c4fa3..2fb4bd19 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -20,14 +20,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
path, ext = os.path.splitext(full_path)
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
shorthash = sha256[0:10] if sha256 else None
-
+ search_terms = [self.search_terms_from_path(path)]
+ if sha256:
+ search_terms.append(sha256)
return {
"name": name,
"filename": full_path,
"shorthash": shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
- "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
+ "search_terms": search_terms,
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index 1b334fda..deb7cb87 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -18,13 +18,16 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
return
path, ext = os.path.splitext(embedding.filename)
+ search_terms = [self.search_terms_from_path(embedding.filename)]
+ if embedding.hash:
+ search_terms.append(embedding.hash)
return {
"name": name,
"filename": embedding.filename,
"shorthash": embedding.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
- "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
+ "search_terms": search_terms,
"prompt": quote_js(embedding.name),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index 36a807fc..2ca937fd 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -5,7 +5,7 @@ import os.path
import gradio as gr
-from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks
+from modules import infotext_utils, images, sysinfo, errors, ui_extra_networks
class UserMetadataEditor:
@@ -14,7 +14,7 @@ class UserMetadataEditor:
self.ui = ui
self.tabname = tabname
self.page = page
- self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
+ self.id_part = f"{self.tabname}_{self.page.extra_networks_tabname}_edit_user_metadata"
self.box = None
@@ -181,7 +181,7 @@ class UserMetadataEditor:
index = len(gallery) - 1 if index >= len(gallery) else index
img_info = gallery[index if index >= 0 else 0]
- image = generation_parameters_copypaste.image_from_url_text(img_info)
+ image = infotext_utils.image_from_url_text(img_info)
geninfo, items = images.read_info_from_image(image)
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
index 0d368f8b..f5278d22 100644
--- a/modules/ui_gradio_extensions.py
+++ b/modules/ui_gradio_extensions.py
@@ -1,17 +1,12 @@
import os
import gradio as gr
-from modules import localization, shared, scripts
-from modules.paths import script_path, data_path, cwd
+from modules import localization, shared, scripts, util
+from modules.paths import script_path, data_path
def webpath(fn):
- if fn.startswith(cwd):
- web_path = os.path.relpath(fn, cwd)
- else:
- web_path = os.path.abspath(fn)
-
- return f'file={web_path}?{os.path.getmtime(fn)}'
+ return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}'
def javascript_html():
@@ -40,13 +35,11 @@ def css_html():
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
for cssfile in scripts.list_files_with_name("style.css"):
- if not os.path.isfile(cssfile):
- continue
-
head += stylesheet(cssfile)
- if os.path.exists(os.path.join(data_path, "user.css")):
- head += stylesheet(os.path.join(data_path, "user.css"))
+ user_css = os.path.join(data_path, "user.css")
+ if os.path.exists(user_css):
+ head += stylesheet(user_css)
return head
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
index 7826786c..2555cdb6 100644
--- a/modules/ui_loadsave.py
+++ b/modules/ui_loadsave.py
@@ -26,8 +26,9 @@ class UiLoadsave:
self.ui_defaults_review = None
try:
- if os.path.exists(self.filename):
- self.ui_settings = self.read_from_file()
+ self.ui_settings = self.read_from_file()
+ except FileNotFoundError:
+ pass
except Exception as e:
self.error_loading = True
errors.display(e, "loading settings")
@@ -144,7 +145,7 @@ class UiLoadsave:
json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)
def dump_defaults(self):
- """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
+ """saves default values to a file unless the file is present and there was an error loading default values at start"""
if self.error_loading and os.path.exists(self.filename):
return
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index 13d888e4..7261c2df 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,13 +1,14 @@
import gradio as gr
from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext_utils as parameters_copypaste
+from modules.ui_components import ResizeHandleRow
def create_ui():
dummy_component = gr.Label(visible=False)
- tab_index = gr.State(value=0)
+ tab_index = gr.Number(value=0, visible=False)
- with gr.Row(equal_height=False, variant='compact'):
+ with ResizeHandleRow(equal_height=False, variant='compact'):
with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
@@ -28,7 +29,7 @@ def create_ui():
toprow.create_inline_toprow_image()
submit = toprow.submit
- result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
+ output_panel = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
@@ -48,9 +49,9 @@ def create_ui():
*script_inputs
],
outputs=[
- result_images,
- html_info_x,
- html_log,
+ output_panel.gallery,
+ output_panel.generation_info,
+ output_panel.html_log,
],
show_progress=False,
)
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index 0d74c23f..d67e3f17 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt):
if not name:
return gr.update(visible=False)
- style = styles.PromptStyle(name, prompt, negative_prompt)
+ existing_style = shared.prompt_styles.styles.get(name)
+ path = existing_style.path if existing_style is not None else None
+
+ style = styles.PromptStyle(name, prompt, negative_prompt, path)
shared.prompt_styles.styles[style.name] = style
- shared.prompt_styles.save_styles(shared.styles_filename)
+ shared.prompt_styles.save_styles()
return gr.update(visible=True)
@@ -34,7 +37,7 @@ def delete_style(name):
return
shared.prompt_styles.styles.pop(name, None)
- shared.prompt_styles.save_styles(shared.styles_filename)
+ shared.prompt_styles.save_styles()
return '', '', ''
diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py
index 88838f97..fbe705be 100644
--- a/modules/ui_toprow.py
+++ b/modules/ui_toprow.py
@@ -79,11 +79,11 @@ class Toprow:
def create_prompts(self):
with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
- self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
- self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
self.prompt_img.change(
fn=modules.images.image_data,
@@ -106,8 +106,15 @@ class Toprow:
outputs=[],
)
+ def interrupt_function():
+ if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current:
+ shared.state.stop_generating()
+ gr.Info("Generation will stop after finishing this image, click again to stop immediately.")
+ else:
+ shared.state.interrupt()
+
self.interrupt.click(
- fn=lambda: shared.state.interrupt(),
+ fn=interrupt_function,
inputs=[],
outputs=[],
)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index b256e085..3aee69db 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -98,6 +98,9 @@ class UpscalerData:
self.scale = scale
self.model = model
+ def __repr__(self):
+ return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
+
class UpscalerNone(Upscaler):
name = "None"
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
new file mode 100644
index 00000000..afed8b40
--- /dev/null
+++ b/modules/upscaler_utils.py
@@ -0,0 +1,189 @@
+import logging
+from typing import Callable
+
+import numpy as np
+import torch
+import tqdm
+from PIL import Image
+
+from modules import images, shared, torch_utils
+
+logger = logging.getLogger(__name__)
+
+
+def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
+ img = np.array(img.convert("RGB"))
+ img = img[:, :, ::-1] # flip RGB to BGR
+ img = np.transpose(img, (2, 0, 1)) # HWC to CHW
+ img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
+ return torch.from_numpy(img)
+
+
+def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
+ if tensor.ndim == 4:
+ # If we're given a tensor with a batch dimension, squeeze it out
+ # (but only if it's a batch of size 1).
+ if tensor.shape[0] != 1:
+ raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
+ tensor = tensor.squeeze(0)
+ assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
+ # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
+ arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
+ arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
+ arr = arr.round().astype(np.uint8)
+ arr = arr[:, :, ::-1] # flip BGR to RGB
+ return Image.fromarray(arr, "RGB")
+
+
+def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
+ """
+ Upscale a given PIL image using the given model.
+ """
+ param = torch_utils.get_param(model)
+
+ with torch.no_grad():
+ tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
+ tensor = tensor.to(device=param.device, dtype=param.dtype)
+ return torch_bgr_to_pil_image(model(tensor))
+
+
+def upscale_with_model(
+ model: Callable[[torch.Tensor], torch.Tensor],
+ img: Image.Image,
+ *,
+ tile_size: int,
+ tile_overlap: int = 0,
+ desc="tiled upscale",
+) -> Image.Image:
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img)
+ output = upscale_pil_patch(model, img)
+ logger.debug("=> %s", output)
+ return output
+
+ grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
+ newtiles = []
+
+ with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p:
+ for y, h, row in grid.tiles:
+ newrow = []
+ for x, w, tile in row:
+ logger.debug("Tile (%d, %d) %s...", x, y, tile)
+ output = upscale_pil_patch(model, tile)
+ scale_factor = output.width // tile.width
+ logger.debug("=> %s (scale factor %s)", output, scale_factor)
+ newrow.append([x * scale_factor, w * scale_factor, output])
+ p.update(1)
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+ newgrid = images.Grid(
+ newtiles,
+ tile_w=grid.tile_w * scale_factor,
+ tile_h=grid.tile_h * scale_factor,
+ image_w=grid.image_w * scale_factor,
+ image_h=grid.image_h * scale_factor,
+ overlap=grid.overlap * scale_factor,
+ )
+ return images.combine_grid(newgrid)
+
+
+def tiled_upscale_2(
+ img: torch.Tensor,
+ model,
+ *,
+ tile_size: int,
+ tile_overlap: int,
+ scale: int,
+ device: torch.device,
+ desc="Tiled upscale",
+):
+ # Alternative implementation of `upscale_with_model` originally used by
+ # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
+ # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
+ # Pillow space without weighting.
+
+ b, c, h, w = img.size()
+ tile_size = min(tile_size, h, w)
+
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img.shape)
+ return model(img)
+
+ stride = tile_size - tile_overlap
+ h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
+ w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
+ result = torch.zeros(
+ b,
+ c,
+ h * scale,
+ w * scale,
+ device=device,
+ dtype=img.dtype,
+ )
+ weights = torch.zeros_like(result)
+ logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
+ with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
+ for h_idx in h_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ for w_idx in w_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ # Only move this patch to the device if it's not already there.
+ in_patch = img[
+ ...,
+ h_idx : h_idx + tile_size,
+ w_idx : w_idx + tile_size,
+ ].to(device=device)
+
+ out_patch = model(in_patch)
+
+ result[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch)
+
+ out_patch_mask = torch.ones_like(out_patch)
+
+ weights[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch_mask)
+
+ pbar.update(1)
+
+ output = result.div_(weights)
+
+ return output
+
+
+def upscale_2(
+ img: Image.Image,
+ model,
+ *,
+ tile_size: int,
+ tile_overlap: int,
+ scale: int,
+ desc: str,
+):
+ """
+ Convenience wrapper around `tiled_upscale_2` that handles PIL images.
+ """
+ param = torch_utils.get_param(model)
+ tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension
+
+ with torch.no_grad():
+ output = tiled_upscale_2(
+ tensor,
+ model,
+ tile_size=tile_size,
+ tile_overlap=tile_overlap,
+ scale=scale,
+ desc=desc,
+ device=param.device,
+ )
+ return torch_bgr_to_pil_image(output)
diff --git a/modules/util.py b/modules/util.py
index 60afc067..ee373e92 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -2,7 +2,7 @@ import os
import re
from modules import shared
-from modules.paths_internal import script_path
+from modules.paths_internal import script_path, cwd
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
@@ -21,11 +21,11 @@ def html_path(filename):
def html(filename):
path = html_path(filename)
- if os.path.exists(path):
+ try:
with open(path, encoding="utf8") as file:
return file.read()
-
- return ""
+ except OSError:
+ return ""
def walk_files(path, allowed_extensions=None):
@@ -56,3 +56,83 @@ def ldm_print(*args, **kwargs):
return
print(*args, **kwargs)
+
+
+def truncate_path(target_path, base_path=cwd):
+ abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)
+ try:
+ if os.path.commonpath([abs_target, abs_base]) == abs_base:
+ return os.path.relpath(abs_target, abs_base)
+ except ValueError:
+ pass
+ return abs_target
+
+
+class MassFileListerCachedDir:
+ """A class that caches file metadata for a specific directory."""
+
+ def __init__(self, dirname):
+ self.files = None
+ self.files_cased = None
+ self.dirname = dirname
+
+ stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname))
+ files = [(n, s.st_mtime, s.st_ctime) for n, s in stats]
+ self.files = {x[0].lower(): x for x in files}
+ self.files_cased = {x[0]: x for x in files}
+
+
+class MassFileLister:
+ """A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""
+
+ def __init__(self):
+ self.cached_dirs = {}
+
+ def find(self, path):
+ """
+ Find the metadata for a file at the given path.
+
+ Returns:
+ tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not.
+ """
+
+ dirname, filename = os.path.split(path)
+
+ cached_dir = self.cached_dirs.get(dirname)
+ if cached_dir is None:
+ cached_dir = MassFileListerCachedDir(dirname)
+ self.cached_dirs[dirname] = cached_dir
+
+ stats = cached_dir.files_cased.get(filename)
+ if stats is not None:
+ return stats
+
+ stats = cached_dir.files.get(filename.lower())
+ if stats is None:
+ return None
+
+ try:
+ os_stats = os.stat(path, follow_symlinks=False)
+ return filename, os_stats.st_mtime, os_stats.st_ctime
+ except Exception:
+ return None
+
+ def exists(self, path):
+ """Check if a file exists at the given path."""
+
+ return self.find(path) is not None
+
+ def mctime(self, path):
+ """
+ Get the modification and creation times for a file at the given path.
+
+ Returns:
+ tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not.
+ """
+
+ stats = self.find(path)
+ return (0, 0) if stats is None else stats[1:3]
+
+ def reset(self):
+ """Clear the cache of all directories."""
+ self.cached_dirs.clear()
diff --git a/modules/xlmr.py b/modules/xlmr.py
index a407a3ca..319771b7 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
+from modules import torch_utils
+
+
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = next(self.parameters()).device
+ device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
index a727e865..f6055504 100644
--- a/modules/xlmr_m18.py
+++ b/modules/xlmr_m18.py
@@ -4,6 +4,8 @@ import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
+from modules import torch_utils
+
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -68,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = next(self.parameters()).device
+ device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index d8da94a0..2971dbc3 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -27,11 +27,90 @@ def torch_xpu_gc():
has_xpu = check_for_xpu()
+
+# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
+# Here we implement a slicing algorithm to split large batch size into smaller chunks,
+# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
+# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
+# which is the best trade-off between VRAM usage and performance.
+ARC_SINGLE_ALLOCATION_LIMIT = {}
+orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
+def torch_xpu_scaled_dot_product_attention(
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
+):
+ # cast to same dtype first
+ key = key.to(query.dtype)
+ value = value.to(query.dtype)
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(query.dtype)
+
+ N = query.shape[:-2] # Batch size
+ L = query.size(-2) # Target sequence length
+ E = query.size(-1) # Embedding dimension of the query and key
+ S = key.size(-2) # Source sequence length
+ Ev = value.size(-1) # Embedding dimension of the value
+
+ total_batch_size = torch.numel(torch.empty(N))
+ device_id = query.device.index
+ if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+ ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+ batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
+
+ if total_batch_size <= batch_size_limit:
+ return orig_sdp_attn_func(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ *args, **kwargs
+ )
+
+ query = torch.reshape(query, (-1, L, E))
+ key = torch.reshape(key, (-1, S, E))
+ value = torch.reshape(value, (-1, S, Ev))
+ if attn_mask is not None:
+ attn_mask = attn_mask.view(-1, L, S)
+ chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
+ outputs = []
+ for i in range(chunk_count):
+ attn_mask_chunk = (
+ None
+ if attn_mask is None
+ else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
+ )
+ chunk_output = orig_sdp_attn_func(
+ query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ attn_mask_chunk,
+ dropout_p,
+ is_causal,
+ *args, **kwargs
+ )
+ outputs.append(chunk_output)
+ result = torch.cat(outputs, dim=0)
+ return torch.reshape(result, (*N, L, Ev))
+
+
+def is_xpu_device(device: str | torch.device = None):
+ if device is None:
+ return False
+ if isinstance(device, str):
+ return device.startswith("xpu")
+ return device.type == "xpu"
+
+
if has_xpu:
- # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
- CondFunc('torch.Generator',
- lambda orig_func, device=None: torch.xpu.Generator(device),
- lambda orig_func, device=None: device is not None and device.type == "xpu")
+ try:
+ # torch.Generator supports "xpu" device since 2.1
+ torch.Generator("xpu")
+ except RuntimeError:
+ # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
+ CondFunc('torch.Generator',
+ lambda orig_func, device=None: torch.xpu.Generator(device),
+ lambda orig_func, device=None: is_xpu_device(device))
# W/A for some OPs that could not handle different input dtypes
CondFunc('torch.nn.functional.layer_norm',
@@ -55,5 +134,5 @@ if has_xpu:
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
CondFunc('torch.nn.functional.scaled_dot_product_attention',
- lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
- lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
+ lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
+ lambda orig_func, query, *args, **kwargs: query.is_xpu)