Diffstat (limited to 'modules')
-rw-r--r--  modules/api/api.py | 167
-rw-r--r--  modules/api/models.py | 14
-rw-r--r--  modules/cache.py | 120
-rw-r--r--  modules/call_queue.py | 23
-rw-r--r--  modules/cmd_args.py | 3
-rw-r--r--  modules/codeformer_model.py | 6
-rw-r--r--  modules/devices.py | 11
-rw-r--r--  modules/esrgan_model.py | 23
-rw-r--r--  modules/extensions.py | 26
-rw-r--r--  modules/extra_networks.py | 19
-rw-r--r--  modules/extras.py | 3
-rw-r--r--  modules/generation_parameters_copypaste.py | 29
-rw-r--r--  modules/gfpgan_model.py | 2
-rw-r--r--  modules/hashes.py | 33
-rw-r--r--  modules/hypernetworks/hypernetwork.py | 32
-rw-r--r--  modules/images.py | 76
-rw-r--r--  modules/img2img.py | 74
-rw-r--r--  modules/interrogate.py | 3
-rw-r--r--  modules/launch_utils.py | 66
-rw-r--r--  modules/lowvram.py | 53
-rw-r--r--  modules/mac_specific.py | 39
-rw-r--r--  modules/modelloader.py | 31
-rw-r--r--  modules/paths.py | 39
-rw-r--r--  modules/postprocessing.py | 7
-rw-r--r--  modules/processing.py | 104
-rw-r--r--  modules/prompt_parser.py | 95
-rw-r--r--  modules/realesrgan_model.py | 33
-rw-r--r--  modules/script_loading.py | 5
-rw-r--r--  modules/scripts.py | 40
-rw-r--r--  modules/sd_hijack.py | 45
-rw-r--r--  modules/sd_hijack_clip.py | 46
-rw-r--r--  modules/sd_hijack_open_clip.py | 36
-rw-r--r--  modules/sd_hijack_optimizations.py | 51
-rw-r--r--  modules/sd_hijack_unet.py | 8
-rw-r--r--  modules/sd_models.py | 59
-rw-r--r--  modules/sd_models_config.py | 9
-rw-r--r--  modules/sd_models_xl.py | 99
-rw-r--r--  modules/sd_samplers.py | 3
-rw-r--r--  modules/sd_samplers_compvis.py | 6
-rw-r--r--  modules/sd_samplers_kdiffusion.py | 45
-rw-r--r--  modules/sd_vae_approx.py | 59
-rw-r--r--  modules/sd_vae_taesd.py | 26
-rw-r--r--  modules/shared.py | 55
-rw-r--r--  modules/textual_inversion/logging.py | 48
-rw-r--r--  modules/textual_inversion/preprocess.py | 2
-rw-r--r--  modules/textual_inversion/textual_inversion.py | 17
-rw-r--r--  modules/txt2img.py | 19
-rw-r--r--  modules/ui.py | 16
-rw-r--r--  modules/ui_common.py | 9
-rw-r--r--  modules/ui_extensions.py | 40
-rw-r--r--  modules/ui_extra_networks.py | 118
-rw-r--r--  modules/ui_extra_networks_checkpoints.py | 32
-rw-r--r--  modules/ui_extra_networks_hypernets.py | 33
-rw-r--r--  modules/ui_extra_networks_textual_inversion.py | 32
-rw-r--r--  modules/ui_extra_networks_user_metadata.py | 195
-rw-r--r--  modules/ui_settings.py | 21
56 files changed, 1732 insertions(+), 573 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 2e49526e..2a4cd8a2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,5 +1,6 @@
import base64
import io
+import os
import time
import datetime
import uvicorn
@@ -14,7 +15,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -22,7 +23,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
@@ -30,13 +31,7 @@ from modules import devices
from typing import Dict, List, Any
import piexif
import piexif.helper
-
-
-def upscaler_to_index(name: str):
- try:
- return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
- except Exception as e:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
+from contextlib import closing
def script_name_to_index(name, scripts):
@@ -84,6 +79,8 @@ def encode_pil_to_base64(image):
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+ if image.mode == "RGBA":
+ image = image.convert("RGB")
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
@@ -102,14 +99,16 @@ def encode_pil_to_base64(image):
def api_middleware(app: FastAPI):
- rich_available = True
+ rich_available = False
try:
- import anyio # importing just so it can be placed on silent list
- import starlette # importing just so it can be placed on silent list
- from rich.console import Console
- console = Console()
+ if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
+ import anyio # importing just so it can be placed on silent list
+ import starlette # importing just so it can be placed on silent list
+ from rich.console import Console
+ console = Console()
+ rich_available = True
except Exception:
- rich_available = False
+ pass
@app.middleware("http")
async def log_and_time(req: Request, call_next):
@@ -120,14 +119,14 @@ def api_middleware(app: FastAPI):
endpoint = req.scope.get('path', 'err')
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
- t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
- code = res.status_code,
- ver = req.scope.get('http_version', '0.0'),
- cli = req.scope.get('client', ('0:0.0.0', 0))[0],
- prot = req.scope.get('scheme', 'err'),
- method = req.scope.get('method', 'err'),
- endpoint = endpoint,
- duration = duration,
+ t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code=res.status_code,
+ ver=req.scope.get('http_version', '0.0'),
+ cli=req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot=req.scope.get('scheme', 'err'),
+ method=req.scope.get('method', 'err'),
+ endpoint=endpoint,
+ duration=duration,
))
return res
@@ -138,7 +137,7 @@ def api_middleware(app: FastAPI):
"body": vars(e).get('body', ''),
"errors": str(e),
}
- if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
+ if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
@@ -209,6 +208,11 @@ class Api:
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+ if shared.cmd_opts.api_server_stop:
+ self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
+
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -324,19 +328,19 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
-
- shared.state.begin()
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+
+ shared.state.begin(job="scripts_txt2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -380,20 +384,20 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
- p.init_images = [decode_base64_to_image(x) for x in init_images]
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
-
- shared.state.begin()
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+
+ shared.state.begin(job="scripts_img2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -517,6 +521,10 @@ class Api:
return options
def set_config(self, req: Dict[str, Any]):
+ checkpoint_name = req.get("sd_model_checkpoint", None)
+ if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+ raise RuntimeError(f"model {checkpoint_name!r} not found")
+
for k, v in req.items():
shared.opts.set(k, v)
@@ -593,48 +601,47 @@ class Api:
}
def refresh_checkpoints(self):
- shared.refresh_checkpoints()
+ with self.queue_lock:
+ shared.refresh_checkpoints()
def create_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
- shared.state.end()
return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
- shared.state.end()
return models.TrainResponse(info=f"create embedding error: {e}")
+ finally:
+ shared.state.end()
+
def create_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
- shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e:
- shared.state.end()
return models.TrainResponse(info=f"create hypernetwork error: {e}")
+ finally:
+ shared.state.end()
def preprocess(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
- return models.PreprocessResponse(info = 'preprocess complete')
+ return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
- shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
- except AssertionError as e:
- shared.state.end()
+ except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
- except FileNotFoundError as e:
+ finally:
shared.state.end()
- return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_embedding")
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -647,15 +654,15 @@ class Api:
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
- shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError as msg:
- shared.state.end()
+ except Exception as msg:
return models.TrainResponse(info=f"train embedding error: {msg}")
+ finally:
+ shared.state.end()
def train_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_hypernetwork")
shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
@@ -673,9 +680,10 @@ class Api:
sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError:
+ except Exception as exc:
+ return models.TrainResponse(info=f"train embedding error: {exc}")
+ finally:
shared.state.end()
- return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self):
try:
@@ -714,4 +722,17 @@ class Api:
def launch(self, server_name, port):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
+
+ def kill_webui(self):
+ restart.stop_program()
+
+ def restart_webui(self):
+ if restart.is_restartable():
+ restart.restart_program()
+ return Response(status_code=501)
+
+ def stop_webui(request):
+ shared.state.server_command = "stop"
+ return Response("Stopping.")
+
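Note: the /sdapi/v1/server-kill, /sdapi/v1/server-restart and /sdapi/v1/server-stop routes added above are only registered when the webui is launched with --api-server-stop. A minimal sketch of exercising them from Python, assuming the default listen address 127.0.0.1:7860 and the third-party requests package:

    import requests

    # restart the server; responds with HTTP 501 if the launcher cannot restart it
    requests.post("http://127.0.0.1:7860/sdapi/v1/server-restart")

    # ask the server to stop; it replies "Stopping." and shuts down shortly after
    requests.post("http://127.0.0.1:7860/sdapi/v1/server-stop")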
diff --git a/modules/api/models.py b/modules/api/models.py
index b3a745f0..bf97b1a3 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,4 +1,5 @@
import inspect
+
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
@@ -207,11 +208,12 @@ class PreprocessResponse(BaseModel):
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
- optType = opts.typemap.get(type(metadata.default), type(value))
+ optType = opts.typemap.get(type(metadata.default), type(metadata.default))
- if (metadata is not None):
- fields.update({key: (Optional[optType], Field(
- default=metadata.default ,description=metadata.label))})
+ if metadata.default is None:
+ pass
+ elif metadata is not None:
+ fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
else:
fields.update({key: (Optional[optType], Field())})
@@ -274,10 +276,6 @@ class PromptStyleItem(BaseModel):
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
-class ArtistItem(BaseModel):
- name: str = Field(title="Name")
- score: float = Field(title="Score")
- category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
diff --git a/modules/cache.py b/modules/cache.py
new file mode 100644
index 00000000..71fe6302
--- /dev/null
+++ b/modules/cache.py
@@ -0,0 +1,120 @@
+import json
+import os.path
+import threading
+import time
+
+from modules.paths import data_path, script_path
+
+cache_filename = os.path.join(data_path, "cache.json")
+cache_data = None
+cache_lock = threading.Lock()
+
+dump_cache_after = None
+dump_cache_thread = None
+
+
+def dump_cache():
+ """
+ Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written.
+ """
+
+ global dump_cache_after
+ global dump_cache_thread
+
+ def thread_func():
+ global dump_cache_after
+ global dump_cache_thread
+
+ while dump_cache_after is not None and time.time() < dump_cache_after:
+ time.sleep(1)
+
+ with cache_lock:
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+ dump_cache_after = None
+ dump_cache_thread = None
+
+ with cache_lock:
+ dump_cache_after = time.time() + 5
+ if dump_cache_thread is None:
+ dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)
+ dump_cache_thread.start()
+
+
+def cache(subsection):
+ """
+ Retrieves or initializes a cache for a specific subsection.
+
+ Parameters:
+ subsection (str): The subsection identifier for the cache.
+
+ Returns:
+ dict: The cache data for the specified subsection.
+ """
+
+ global cache_data
+
+ if cache_data is None:
+ with cache_lock:
+ if cache_data is None:
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ try:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+ except Exception:
+ os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
+ print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
+ cache_data = {}
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def cached_data_for_file(subsection, title, filename, func):
+ """
+ Retrieves or generates data for a specific file, using a caching mechanism.
+
+ Parameters:
+ subsection (str): The subsection of the cache to use.
+ title (str): The title of the data entry in the subsection of the cache.
+ filename (str): The path to the file to be checked for modifications.
+ func (callable): A function that generates the data if it is not available in the cache.
+
+ Returns:
+ dict or None: The cached or generated data, or None if data generation fails.
+
+ The `cached_data_for_file` function implements a caching mechanism for data stored in files.
+ It checks if the data associated with the given `title` is present in the cache and compares the
+ modification time of the file with the cached modification time. If the file has been modified,
+ the cache is considered invalid and the data is regenerated using the provided `func`.
+ Otherwise, the cached data is returned.
+
+ If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
+ or cached data is returned as a dictionary.
+ """
+
+ existing_cache = cache(subsection)
+ ondisk_mtime = os.path.getmtime(filename)
+
+ entry = existing_cache.get(title)
+ if entry:
+ cached_mtime = entry.get("mtime", 0)
+ if ondisk_mtime > cached_mtime:
+ entry = None
+
+ if not entry or 'value' not in entry:
+ value = func()
+ if value is None:
+ return None
+
+ entry = {'mtime': ondisk_mtime, 'value': value}
+ existing_cache[title] = entry
+
+ dump_cache()
+
+ return entry['value']
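The cache helper introduced above is what later hunks in this patch switch to (see modules/extensions.py and modules/hashes.py below). A minimal usage sketch from inside a running webui process, where example_file and compute_value are hypothetical stand-ins for a real file and an expensive generator function:

    import os.path

    from modules import cache

    example_file = "models/Stable-diffusion/example.safetensors"  # hypothetical path

    def compute_value():
        # expensive work whose result is worth persisting across restarts
        return {"size": os.path.getsize(example_file)}

    # recomputed only when the file's mtime is newer than the cached entry;
    # the result is written to cache.json by the debounced dump_cache() writer
    value = cache.cached_data_for_file("example-subsection", "example.safetensors", example_file, compute_value)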
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 1b5e5273..61aa240f 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,3 +1,4 @@
+from functools import wraps
import html
import threading
import time
@@ -18,6 +19,7 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
+ @wraps(func)
def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id
@@ -28,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
id_task = None
with queue_lock:
- shared.state.begin()
+ shared.state.begin(job=id_task)
progress.start_task(id_task)
try:
@@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+ @wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
@@ -82,9 +85,9 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
- elapsed_text = f"{elapsed_s:.2f}s"
+ elapsed_text = f"{elapsed_s:.1f} sec."
if elapsed_m > 0:
- elapsed_text = f"{elapsed_m}m "+elapsed_text
+ elapsed_text = f"{elapsed_m} min. "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
@@ -92,14 +95,22 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
- sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+ sys_pct = sys_peak/max(sys_total, 1) * 100
- vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+ toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
+ toltip_r = "Reserved: total amout of video memory allocated by the Torch library "
+ toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity"
+
+ text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
+ text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
+ text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
+
+ vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
else:
vram_html = ''
# last item is always HTML
- res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
return tuple(res)
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index de905caa..e401f641 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -15,6 +15,7 @@ parser.add_argument("--update-check", action='store_true', help="launch.py argum
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
+parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
@@ -107,3 +108,5 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
+parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index d974e4b8..da42b5e9 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -15,7 +15,6 @@ model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-have_codeformer = False
codeformer = None
@@ -100,7 +99,7 @@ def setup_model(dirname):
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
- torch.cuda.empty_cache()
+ devices.torch_gc()
except Exception:
errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@@ -123,9 +122,6 @@ def setup_model(dirname):
return restored_img
- global have_codeformer
- have_codeformer = True
-
global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
diff --git a/modules/devices.py b/modules/devices.py
index 1ed6ffdc..57e51da3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -15,13 +15,6 @@ def has_mps() -> bool:
else:
return mac_specific.has_mps
-def extract_device_id(args, name):
- for x in range(len(args)):
- if name in args[x]:
- return args[x + 1]
-
- return None
-
def get_cuda_device_string():
from modules import shared
@@ -56,11 +49,15 @@ def get_device_for(task):
def torch_gc():
+
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
+ if has_mps():
+ mac_specific.torch_mps_gc()
+
def enable_tf32():
if torch.cuda.is_available():
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 2fced999..02a1727d 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,15 +1,13 @@
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
diff --git a/modules/extensions.py b/modules/extensions.py
index abc6e2b1..c561159a 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,7 +1,7 @@
import os
import threading
-from modules import shared, errors
+from modules import shared, errors, cache
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -21,6 +21,7 @@ def active():
class Extension:
lock = threading.Lock()
+ cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
@@ -36,15 +37,29 @@ class Extension:
self.remote = None
self.have_info_from_repo = False
+ def to_dict(self):
+ return {x: getattr(self, x) for x in self.cached_fields}
+
+ def from_dict(self, d):
+ for field in self.cached_fields:
+ setattr(self, field, d[field])
+
def read_info_from_repo(self):
if self.is_builtin or self.have_info_from_repo:
return
- with self.lock:
- if self.have_info_from_repo:
- return
+ def read_from_repo():
+ with self.lock:
+ if self.have_info_from_repo:
+ return
+
+ self.do_read_info_from_repo()
+
+ return self.to_dict()
- self.do_read_info_from_repo()
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
+ self.from_dict(d)
+ self.status = 'unknown'
def do_read_info_from_repo(self):
repo = None
@@ -58,7 +73,6 @@ class Extension:
self.remote = None
else:
try:
- self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
commit = repo.head.commit
self.commit_date = commit.committed_date
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 1f093df2..6ae07e91 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -4,16 +4,22 @@ from collections import defaultdict
from modules import errors
extra_network_registry = {}
+extra_network_aliases = {}
def initialize():
extra_network_registry.clear()
+ extra_network_aliases.clear()
def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
+def register_extra_network_alias(extra_network, alias):
+ extra_network_aliases[alias] = extra_network
+
+
def register_default_extra_networks():
from modules.extra_networks_hypernet import ExtraNetworkHypernet
register_extra_network(ExtraNetworkHypernet())
@@ -82,20 +88,26 @@ def activate(p, extra_network_data):
"""call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""
+ activated = []
+
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
+
+ if extra_network is None:
+ extra_network = extra_network_aliases.get(extra_network_name, None)
+
if extra_network is None:
print(f"Skipping unknown extra network: {extra_network_name}")
continue
try:
extra_network.activate(p, extra_network_args)
+ activated.append(extra_network)
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
+ if extra_network in activated:
continue
try:
@@ -103,6 +115,9 @@ def activate(p, extra_network_data):
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}")
+ if p.scripts is not None:
+ p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
+
def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
diff --git a/modules/extras.py b/modules/extras.py
index 830b53aa..e9c0263e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -73,8 +73,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
- shared.state.begin()
- shared.state.job = 'model-merge'
+ shared.state.begin(job="model-merge")
def fail(message):
shared.state.textinfo = message
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index dd30a1b5..a3448be9 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
return img, w, h
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
- """
- hypernet_name = hypernet_name.lower()
- if hypernet_hash is not None:
- # Try to match the hash in the name
- for hypernet_key in shared.hypernetworks.keys():
- result = re_hypernet_hash.search(hypernet_key)
- if result is not None and result[1] == hypernet_hash:
- return hypernet_key
- else:
- # Fall back to a hypernet with the same name
- for hypernet_key in shared.hypernetworks.keys():
- if hypernet_key.lower().startswith(hypernet_name):
- return hypernet_key
-
- return None
-
-
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
@@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-settings_map = {}
-
-
-
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 6ecd295c..8e0f13bd 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -25,7 +25,7 @@ def gfpgann():
return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
+ if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
diff --git a/modules/hashes.py b/modules/hashes.py
index 8b7ea0ac..b7a33b42 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -1,38 +1,11 @@
import hashlib
-import json
import os.path
-import filelock
-
from modules import shared
-from modules.paths import data_path
-
-
-cache_filename = os.path.join(data_path, "cache.json")
-cache_data = None
-
-
-def dump_cache():
- with filelock.FileLock(f"{cache_filename}.lock"):
- with open(cache_filename, "w", encoding="utf8") as file:
- json.dump(cache_data, file, indent=4)
-
-
-def cache(subsection):
- global cache_data
-
- if cache_data is None:
- with filelock.FileLock(f"{cache_filename}.lock"):
- if not os.path.isfile(cache_filename):
- cache_data = {}
- else:
- with open(cache_filename, "r", encoding="utf8") as file:
- cache_data = json.load(file)
-
- s = cache_data.get(subsection, {})
- cache_data[subsection] = s
+import modules.cache
- return s
+dump_cache = modules.cache.dump_cache
+cache = modules.cache.cache
def calculate_sha256(filename):
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 5d12b449..c4821d21 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -3,6 +3,7 @@ import glob
import html
import os
import inspect
+from contextlib import closing
import modules.textual_inversion.dataset
import torch
@@ -353,17 +354,6 @@ def load_hypernetworks(names, multipliers=None):
shared.loaded_hypernetworks.append(hypernetwork)
-def find_closest_hypernetwork_name(search: str):
- if not search:
- return None
- search = search.lower()
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
- if not applicable:
- return None
- applicable = sorted(applicable, key=lambda name: len(name))
- return applicable[0]
-
-
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
@@ -388,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None):
return context_k, context_v
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
+def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads
q = self.to_q(x)
@@ -446,18 +436,6 @@ def statistics(data):
return total_information, recent_information
-def report_statistics(loss_info:dict):
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
- for key in keys:
- try:
- print("Loss statistics for file " + key)
- info, recent = statistics(list(loss_info[key]))
- print(info)
- print(recent)
- except Exception as e:
- print(e)
-
-
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -734,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
+ with closing(p):
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -770,7 +749,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.leave = False
pbar.close()
hypernetwork.eval()
- #report_statistics(loss_dict)
sd_hijack_checkpoint.remove()
diff --git a/modules/images.py b/modules/images.py
index 7bbfc3e0..38aa933d 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import datetime
import pytz
@@ -10,7 +12,7 @@ import re
import numpy as np
import piexif
import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
import string
import json
import hashlib
@@ -139,6 +141,11 @@ class GridAnnotation:
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
+
+ color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
+ color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
+ color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
+
def wrap(drawing, text, font, line_length):
lines = ['']
for word in text.split():
@@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
fnt = get_font(fontsize)
- color_active = (0, 0, 0)
- color_inactive = (153, 153, 153)
-
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
cols = im.width // width
@@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
- calc_img = Image.new("RGB", (1, 1), "white")
+ calc_img = Image.new("RGB", (1, 1), color_background)
calc_d = ImageDraw.Draw(calc_img)
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
- result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
for row in range(rows):
for col in range(cols):
@@ -302,12 +306,14 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ if fill_height > 0:
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+ if fill_width > 0:
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
return res
@@ -357,7 +363,7 @@ class FilenameGenerator:
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
- 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
@@ -372,8 +378,9 @@ class FilenameGenerator:
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
-
+ 'none': lambda self: '', # Overrides the default so you can get just the sequence number
}
default_time_format = '%Y%m%d%H%M%S'
@@ -497,13 +504,23 @@ def get_next_sequence_number(path, basename):
return result + 1
-def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
+ """
+ Saves image to filename, including geninfo as text information for generation info.
+ For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
+ For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
+ """
+
if extension is None:
extension = os.path.splitext(filename)[1]
image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png':
+ existing_pnginfo = existing_pnginfo or {}
+ if opts.enable_pnginfo:
+ existing_pnginfo[pnginfo_section_name] = geninfo
+
if opts.enable_pnginfo:
pnginfo_data = PngImagePlugin.PngInfo()
for k, v in (existing_pnginfo or {}).items():
@@ -585,13 +602,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
+ file_decoration = namegen.apply(file_decoration) + suffix
+
add_number = opts.save_images_add_number or file_decoration == ''
if file_decoration != "" and add_number:
file_decoration = f"-{file_decoration}"
- file_decoration = namegen.apply(file_decoration) + suffix
-
if add_number:
basecount = get_next_sequence_number(path, basename)
fullfn = None
@@ -622,7 +639,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
temp_file_path = f"{filename_without_extension}.tmp"
- save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
os.replace(temp_file_path, filename_without_extension + extension)
@@ -639,12 +656,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height
-
+ resize_to = None
if oversize and ratio > 1:
- image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
+ resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
elif oversize:
- image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
+ resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
+ if resize_to is not None:
+ try:
+ # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
+ image = image.resize(resize_to, LANCZOS)
+ except Exception:
+ image = image.resize(resize_to)
try:
_atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e:
@@ -662,8 +685,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn
-def read_info_from_image(image):
- items = image.info or {}
+IGNORED_INFO_KEYS = {
+ 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity', 'photoshop',
+}
+
+
+def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
+ items = (image.info or {}).copy()
geninfo = items.pop('parameters', None)
@@ -679,9 +709,7 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
- 'icc_profile', 'chromaticity']:
+ for field in IGNORED_INFO_KEYS:
items.pop(field, None)
if items.get("Software", None) == "NovelAI":
diff --git a/modules/img2img.py b/modules/img2img.py
index 2c497020..a811e7a4 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,23 +1,26 @@
import os
+from contextlib import closing
from pathlib import Path
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import gradio as gr
-from modules import sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import sd_samplers, images as imgutil
+from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
+from modules.images import save_image
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.scripts
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
processing.fix_seed(p)
- images = shared.listfiles(input_dir)
+ images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
is_inpaint_batch = False
if inpaint_mask_dir:
@@ -36,6 +39,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
state.job_count = len(images) * p.n_iter
+ # extract "default" params to use in case getting png info fails
+ prompt = p.prompt
+ negative_prompt = p.negative_prompt
+ seed = p.seed
+ cfg_scale = p.cfg_scale
+ sampler_name = p.sampler_name
+ steps = p.steps
+
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
@@ -79,25 +90,45 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
mask_image = Image.open(mask_image_path)
p.image_mask = mask_image
+ if use_png_info:
+ try:
+ info_img = img
+ if png_info_dir:
+ info_img_path = os.path.join(png_info_dir, os.path.basename(image))
+ info_img = Image.open(info_img_path)
+ geninfo, _ = imgutil.read_info_from_image(info_img)
+ parsed_parameters = parse_generation_parameters(geninfo)
+ parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
+ except Exception:
+ parsed_parameters = {}
+
+ p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
+ p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
+ p.seed = int(parsed_parameters.get("Seed", seed))
+ p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
+ p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
+ p.steps = int(parsed_parameters.get("Steps", steps))
+
proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
proc = process_images(p)
for n, processed_image in enumerate(proc.images):
- filename = image_path.name
+ filename = image_path.stem
+ infotext = proc.infotext(p, n)
+ relpath = os.path.dirname(os.path.relpath(image, input_dir))
if n > 0:
- left, right = os.path.splitext(filename)
- filename = f"{left}-{n}{right}"
+ filename += f"-{n}"
if not save_normally:
- os.makedirs(output_dir, exist_ok=True)
+ os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
if processed_image.mode == 'RGBA':
processed_image = processed_image.convert("RGB")
- processed_image.save(os.path.join(output_dir, filename))
+ save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -180,24 +211,25 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.scripts = modules.scripts.scripts_img2img
p.script_args = args
+ p.user = request.username
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if mask:
p.extra_generation_params["Mask blur"] = mask_blur
- if is_batch:
- assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+ with closing(p):
+ if is_batch:
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by)
-
- processed = Processed(p, [], p.seed, "")
- else:
- processed = modules.scripts.scripts_img2img.run(p, *args)
- if processed is None:
- processed = process_images(p)
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
- p.close()
+ processed = Processed(p, [], p.seed, "")
+ else:
+ processed = modules.scripts.scripts_img2img.run(p, *args)
+ if processed is None:
+ processed = process_images(p)
shared.total_tqdm.clear()
@@ -208,4 +240,4 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9b2c5b60..a3ae1dd5 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -184,8 +184,7 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
- shared.state.begin()
- shared.state.job = 'interrogate'
+ shared.state.begin(job="interrogate")
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 231ebc5f..93838745 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -1,4 +1,5 @@
# this scripts installs necessary requirements and launches main program in webui.py
+import re
import subprocess
import os
import sys
@@ -9,6 +10,9 @@ from functools import lru_cache
from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
+from modules import timer
+
+timer.startup_timer.record("start")
args, _ = cmd_args.parser.parse_known_args()
@@ -69,10 +73,12 @@ def git_tag():
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
try:
- from pathlib import Path
- changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
- with changelog_md.open(encoding="utf-8") as file:
- return next((line.strip() for line in file if line.strip()), "<none>")
+
+ changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md")
+ with open(changelog_md, "r", encoding="utf-8") as file:
+ line = next((line.strip() for line in file if line.strip()), "<none>")
+ line = line.replace("## ", "")
+ return line
except Exception:
return "<none>"
@@ -142,15 +148,15 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None:
return
- current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash:
return
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return
- run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
@@ -224,6 +230,44 @@ def run_extensions_installers(settings_file):
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
+
+
+def requirements_met(requirements_file):
+    """
+    Does a simple parse of a requirements.txt file to determine if all requirements in it
+ are already installed. Returns True if so, False if not installed or parsing fails.
+ """
+
+ import importlib.metadata
+ import packaging.version
+
+ with open(requirements_file, "r", encoding="utf8") as file:
+ for line in file:
+ if line.strip() == "":
+ continue
+
+ m = re.match(re_requirement, line)
+ if m is None:
+ return False
+
+ package = m.group(1).strip()
+ version_required = (m.group(2) or "").strip()
+
+ if version_required == "":
+ continue
+
+ try:
+ version_installed = importlib.metadata.version(package)
+ except Exception:
+ return False
+
+ if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
+ return False
+
+ return True
+
+
def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
@@ -235,11 +279,13 @@ def prepare_environment():
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+ stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+ stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -297,6 +343,7 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+ git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
@@ -306,7 +353,9 @@ def prepare_environment():
if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
- run_pip(f"install -r \"{requirements_file}\"", "requirements")
+
+    if not requirements_met(requirements_file):
+ run_pip(f"install -r \"{requirements_file}\"", "requirements")
run_extensions_installers(settings_file=args.ui_settings_file)
@@ -321,6 +370,7 @@ def prepare_environment():
exit(0)
+
def configure_for_tests():
if "--api" not in sys.argv:
sys.argv.append("--api")
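
Note on the requirements check added above: it skips the `pip install -r` step only when every pinned package already matches. A minimal sketch of how its parsing regex behaves on individual lines; the example pins are illustrative only:

    import re

    re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")

    # a pinned requirement yields both the package name and the version
    m = re.match(re_requirement, "torch==2.0.1")
    print(m.group(1), m.group(2))   # torch 2.0.1

    # an unpinned requirement yields only the name; the check ignores its version
    m = re.match(re_requirement, "gradio")
    print(m.group(1), m.group(2))   # gradio None
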
diff --git a/modules/lowvram.py b/modules/lowvram.py
index d95bcfbf..6bbc11eb 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
- # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
-
- # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
- # send the model to GPU. Then put modules back. the modules will be in CPU.
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+ to_remain_in_cpu = [
+ (sd_model, 'first_stage_model'),
+ (sd_model, 'depth_model'),
+ (sd_model, 'embedder'),
+ (sd_model, 'model'),
+ ]
+
+ is_sdxl = hasattr(sd_model, 'conditioner')
+ is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
+
+ if is_sdxl:
+ to_remain_in_cpu.append((sd_model, 'conditioner'))
+ elif is_sd2:
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
+ else:
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
+
+ # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
+ stored = []
+ for obj, field in to_remain_in_cpu:
+ module = getattr(obj, field, None)
+ stored.append(module)
+ setattr(obj, field, None)
+
+ # send the model to GPU.
sd_model.to(devices.device)
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+
+ # put modules back. the modules will be in CPU.
+ for (obj, field), module in zip(to_remain_in_cpu, stored):
+ setattr(obj, field, module)
    # register hooks for the first three models
- sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+ if is_sdxl:
+ sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
+ elif is_sd2:
+ sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+ else:
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
@@ -73,11 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
- parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
- del sd_model.cond_stage_model.transformer
+ if hasattr(sd_model, 'cond_stage_model'):
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
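
The rewritten setup above generalizes the old hard-coded tuple into a store/detach/restore loop. A self-contained sketch of that pattern, with made-up module names, assuming a working torch install:

    import torch

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.big = torch.nn.Linear(4, 4)    # should stay on CPU
            self.rest = torch.nn.Linear(4, 4)   # moves with the parent

    model = TinyModel()
    to_remain_in_cpu = [(model, "big")]

    stored = []
    for obj, field in to_remain_in_cpu:
        stored.append(getattr(obj, field, None))
        setattr(obj, field, None)               # detach so .to() skips it

    model.to("cuda" if torch.cuda.is_available() else "cpu")

    for (obj, field), module in zip(to_remain_in_cpu, stored):
        setattr(obj, field, module)             # reattach; still on the CPU
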
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index d74c6b95..9ceb43ba 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,20 +1,43 @@
+import logging
+
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
+log = logging.getLogger(__name__)
+
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
+# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
+# so check with `getattr` and try it for compatibility.
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() were introduced to check MPS availability;
+# since the torch 2.0.1+ nightly builds, getattr(torch, 'has_mps', False) has been deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+ else:
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
+has_mps = check_for_mps()
+
+
+def torch_mps_gc() -> None:
try:
- torch.zeros(1).to(torch.device("mps"))
- return True
+ from modules.shared import state
+ if state.current_latent is not None:
+ log.debug("`current_latent` is set, skipping MPS garbage collection")
+ return
+ from torch.mps import empty_cache
+ empty_cache()
except Exception:
- return False
-has_mps = check_for_mps()
+ log.warning("MPS garbage collection failed", exc_info=True)
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
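
A short usage sketch for the helpers added above; the call site is hypothetical (in the webui they would be consulted from the device-management code):

    from modules import mac_specific

    if mac_specific.has_mps:
        # frees cached MPS memory, unless a live-preview latent is currently held
        mac_specific.torch_mps_gc()
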
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 75f01247..098bcb79 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import shutil
import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+def load_file_from_url(
+ url: str,
+ *,
+ model_dir: str,
+ progress: bool = True,
+ file_name: str | None = None,
+) -> str:
+ """Download a file from `url` into `model_dir`, using the file present if possible.
+
+ Returns the path to the downloaded file.
+ """
+ os.makedirs(model_dir, exist_ok=True)
+ if not file_name:
+ parts = urlparse(url)
+ file_name = os.path.basename(parts.path)
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ from torch.hub import download_url_to_file
+ download_url_to_file(url, cached_file, progress=progress)
+ return cached_file
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
    A one-and-done loader to try finding the desired models in the specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
- from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, places[0], True, download_name)
- output.append(dl)
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str):
- if "http" in file:
+ if file.startswith("http"):
file = urlparse(file).path
file = os.path.basename(file)
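
A hedged usage sketch for the new load_file_from_url helper above; the URL and target directory are illustrative:

    from modules import modelloader

    path = modelloader.load_file_from_url(
        "https://example.com/some_model.pth",   # illustrative URL
        model_dir="models/ESRGAN",
        file_name="some_model.pth",
    )
    print(path)   # absolute path; the download is skipped if the file already exists
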
diff --git a/modules/paths.py b/modules/paths.py
index 5171df4f..25052339 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio
import modules.safe # noqa: F401
+def mute_sdxl_imports():
+ """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
+
+ class Dummy:
+ pass
+
+ module = Dummy()
+ module.LPIPS = None
+ sys.modules['taming.modules.losses.lpips'] = module
+
+ module = Dummy()
+ module.StableDataModuleFromConfig = None
+ sys.modules['sgm.data'] = module
+
+
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)
@@ -18,8 +33,11 @@ for possible_sd_path in possible_sd_paths:
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
+mute_sdxl_imports()
+
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
+ (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -35,20 +53,13 @@ for d, must_exist, what, options in path_dirs:
d = os.path.abspath(d)
if "atstart" in options:
sys.path.insert(0, d)
+ elif "sgm" in options:
+        # Stable Diffusion XL repo has a scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
+        # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.
+
+ sys.path.insert(0, d)
+ import sgm # noqa: F401
+ sys.path.pop(0)
else:
sys.path.append(d)
paths[what] = d
-
-
-class Prioritize:
- def __init__(self, name):
- self.name = name
- self.path = None
-
- def __enter__(self):
- self.path = sys.path.copy()
- sys.path = [paths[self.name]] + sys.path
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.path = self.path
- self.path = None
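
mute_sdxl_imports() above relies on pre-registering stand-in objects in sys.modules so that SDXL's imports resolve without the real packages. A self-contained illustration of the same trick, using a made-up module name:

    import sys
    import types

    stub = types.ModuleType("heavy_optional_dep")
    stub.SomeClass = None                    # attribute the importing code expects to find
    sys.modules["heavy_optional_dep"] = stub

    import heavy_optional_dep                # resolves to the stub, no real package needed
    print(heavy_optional_dep.SomeClass)      # None
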
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 736315e2..136e9c88 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()
- shared.state.begin()
- shared.state.job = 'extras'
+ shared.state.begin(job="extras")
image_data = []
image_names = []
@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
for image, name in zip(image_data, image_names):
shared.state.textinfo = name
- existing_pnginfo = image.info or {}
+ parameters, existing_pnginfo = images.read_info_from_image(image)
+ if parameters:
+ existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
diff --git a/modules/processing.py b/modules/processing.py
index 8da73884..a74a5302 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -14,7 +14,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -184,6 +184,8 @@ class StableDiffusionProcessing:
self.uc = None
self.c = None
+ self.user = None
+
@property
def sd_model(self):
return shared.sd_model
@@ -328,8 +330,21 @@ class StableDiffusionProcessing:
caches is a list with items described above.
"""
+
+ cached_params = (
+ required_prompts,
+ steps,
+ opts.CLIP_stop_at_last_layers,
+ shared.sd_model.sd_checkpoint_info,
+ extra_network_data,
+ opts.sdxl_crop_left,
+ opts.sdxl_crop_top,
+ self.width,
+ self.height,
+ )
+
for cache in caches:
- if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+ if cache[0] is not None and cached_params == cache[0]:
return cache[1]
cache = caches[0]
@@ -337,14 +352,17 @@ class StableDiffusionProcessing:
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
+ cache[0] = cached_params
return cache[1]
def setup_conds(self):
+ prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+ negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
+
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -520,9 +538,42 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+ samples = []
+
+ for i in range(batch.shape[0]):
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+ if check_for_nans:
+ try:
+ devices.test_for_nans(sample, "vae")
+ except devices.NansException as e:
+ if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+ raise e
+
+ errors.print_error_explanation(
+ "A tensor with all NaNs was produced in VAE.\n"
+ "Web UI will now convert VAE into 32-bit float and retry.\n"
+ "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
+ "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+ )
+
+ devices.dtype_vae = torch.float32
+ model.first_stage_model.to(devices.dtype_vae)
+ batch = batch.to(devices.dtype_vae)
+
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+ if target_device is not None:
+ sample = sample.to(target_device)
+
+ samples.append(sample)
+
+ return samples
+
+
def decode_first_stage(model, x):
- with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x)
+ x = model.decode_first_stage(x.to(devices.dtype_vae))
return x
@@ -549,7 +600,7 @@ def program_version():
return res
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -570,10 +621,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
+ "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
- "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+ "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
@@ -585,13 +636,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
+ "User": p.user if opts.add_user_name_to_info else None,
}
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
- return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+ return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -602,7 +655,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
- if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+ if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights()
@@ -663,8 +716,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
- def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -728,9 +781,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds()
- if len(model_hijack.comments) > 0:
- for comment in model_hijack.comments:
- comments[comment] = 1
+ for comment in model_hijack.comments:
+ comments[comment] = 1
+
+ p.extra_generation_params.update(model_hijack.extra_generation_params)
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
@@ -738,10 +792,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
- for x in x_samples_ddim:
- devices.test_for_nans(x, "vae")
-
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -824,7 +875,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
- text = infotext()
+ text = infotext(use_main_prompt=True)
infotexts.insert(0, text)
if opts.enable_pnginfo:
grid.info["parameters"] = text
@@ -832,7 +883,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, p.extra_network_data)
@@ -1009,7 +1060,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
image = sd_samplers.sample_to_image(image, index, approximation=0)
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
if latent_scale_mode is not None:
for i in range(samples.shape[0]):
@@ -1074,6 +1125,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+ if self.scripts is not None:
+ self.scripts.before_hr(self)
+
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
@@ -1280,7 +1334,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(shared.device)
+ image = image.to(shared.device, dtype=devices.dtype_vae)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
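
The widened cache key in get_conds_with_caching above is matched by plain tuple equality against the previously stored key. A stripped-down sketch of that caching pattern; the names and key fields are illustrative:

    # cache layout mirrors the one above: [key_tuple, value]
    cache = [None, None]

    def get_with_caching(key, compute):
        if cache[0] is not None and key == cache[0]:
            return cache[1]                  # hit: same prompts, steps, checkpoint, size, ...
        cache[1] = compute()
        cache[0] = key
        return cache[1]

    cond = get_with_caching(("a cat", 20, 1024, 1024), lambda: "conditioning tensor")
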
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 0069d8b0..b29d079d 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import re
from collections import namedtuple
from typing import List
@@ -109,7 +111,25 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+ """
+ A list with prompts for stable diffusion's conditioner model.
+    Can also specify the width and height of the created image - SDXL needs it.
+ """
+ def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
+ super().__init__()
+ self.extend(prompts)
+
+ if copy_from is None:
+ copy_from = prompts
+
+ self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
+ self.width = width or getattr(copy_from, 'width', None)
+ self.height = height or getattr(copy_from, 'height', None)
+
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -139,12 +159,17 @@ def get_learned_conditioning(model, prompts, steps):
res.append(cached)
continue
- texts = [x[1] for x in prompt_schedule]
+ texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
conds = model.get_learned_conditioning(texts)
cond_schedule = []
for i, (end_at_step, _) in enumerate(prompt_schedule):
- cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+ if isinstance(conds, dict):
+ cond = {k: v[i] for k, v in conds.items()}
+ else:
+ cond = conds[i]
+
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
cache[prompt] = cond_schedule
res.append(cond_schedule)
@@ -155,11 +180,13 @@ def get_learned_conditioning(model, prompts, steps):
re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
res_indexes = []
- prompt_flat_list = []
prompt_indexes = {}
+ prompt_flat_list = SdConditioning(prompts)
+ prompt_flat_list.clear()
for prompt in prompts:
subprompts = re_AND.split(prompt)
@@ -196,6 +223,7 @@ class MulticondLearnedConditioning:
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -214,20 +242,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+class DictWithShape(dict):
+ def __init__(self, x, shape):
+ super().__init__()
+ self.update(x)
+
+ @property
+ def shape(self):
+ return self["crossattn"].shape
+
+
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
- res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+ is_dict = isinstance(param, dict)
+
+ if is_dict:
+ dict_cond = param
+ res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+ res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+ else:
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
for i, cond_schedule in enumerate(c):
target_index = 0
for current, entry in enumerate(cond_schedule):
if current_step <= entry.end_at_step:
target_index = current
break
- res[i] = cond_schedule[target_index].cond
+
+ if is_dict:
+ for k, param in cond_schedule[target_index].cond.items():
+ res[k][i] = param
+ else:
+ res[i] = cond_schedule[target_index].cond
return res
+def stack_conds(tensors):
+ # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+ return torch.stack(tensors)
+
+
+
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
param = c.batch[0][0].schedules[0].cond
@@ -249,16 +314,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
conds_list.append(conds_for_batch)
- # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
- # and won't be able to torch.stack them. So this fixes that.
- token_count = max([x.shape[0] for x in tensors])
- for i in range(len(tensors)):
- if tensors[i].shape[0] != token_count:
- last_vector = tensors[i][-1:]
- last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
- tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+ if isinstance(tensors[0], dict):
+ keys = list(tensors[0].keys())
+ stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+ stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+ else:
+ stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
- return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+ return conds_list, stacked
re_attention = re.compile(r"""
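
A short construction sketch for the SdConditioning wrapper introduced above; the prompt text and sizes are illustrative:

    from modules.prompt_parser import SdConditioning

    prompts = SdConditioning(["a photo of a cat"], width=1024, height=1024)
    negative = SdConditioning(["blurry"], is_negative_prompt=True, copy_from=prompts)

    print(list(prompts), prompts.width, prompts.height)   # ['a photo of a cat'] 1024 1024
    print(negative.is_negative_prompt, negative.width)    # True 1024
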
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 2d27b321..0700b853 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -2,7 +2,6 @@ import os
import numpy as np
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable:
return img
- info = self.load_model(path)
- if not os.path.exists(info.local_data_path):
- print(f"Unable to load RealESRGAN model: {info.name}")
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image
def load_model(self, path):
- try:
- info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
- if info is None:
- print(f"Unable to find model info: {path}")
- return None
-
- if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
- return info
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
- return None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
def load_models(self, _):
return get_realesrgan_models(self)
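
With the rewrite above, load_model signals failure by raising instead of returning None, so callers branch on exceptions rather than null checks. A hedged sketch of such a caller; the upscaler instance and path are assumed to exist:

    def try_load(upscaler, path):
        # upscaler: an UpscalerRealESRGAN instance; path: a scaler's data_path
        try:
            return upscaler.load_model(path)
        except FileNotFoundError as e:
            print("RealESRGAN data missing:", e)
        except ValueError as e:
            print("unknown RealESRGAN model:", e)
        return None
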
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 306a1f35..0d55f193 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -12,11 +12,12 @@ def load_module(path):
return module
-def preload_extensions(extensions_dir, parser):
+def preload_extensions(extensions_dir, parser, extension_list=None):
if not os.path.isdir(extensions_dir):
return
- for dirname in sorted(os.listdir(extensions_dir)):
+ extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)
+ for dirname in sorted(extensions):
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
if not os.path.isfile(preload_script):
continue
diff --git a/modules/scripts.py b/modules/scripts.py
index 99bf836a..7d9dd59f 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,6 +1,7 @@
import os
import re
import sys
+import inspect
from collections import namedtuple
import gradio as gr
@@ -116,6 +117,21 @@ class Script:
pass
+ def after_extra_networks_activate(self, p, *args, **kwargs):
+ """
+        Called after extra networks activation, before conds calculation.
+        Allows modification of the network after extra networks activation has been applied.
+        Won't be called if p.disable_extra_networks.
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ - extra_network_data - list of ExtraNetworkParams for current stage
+ """
+ pass
+
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -186,6 +202,11 @@ class Script:
return f'script_{tabname}{title}_{item_id}'
+ def before_hr(self, p, *args):
+ """
+        This function is called before hires fix starts.
+ """
+ pass
current_basedir = paths.script_path
@@ -249,7 +270,7 @@ def load_scripts():
def register_scripts_from_module(module):
for script_class in module.__dict__.values():
- if type(script_class) != type:
+ if not inspect.isclass(script_class):
continue
if issubclass(script_class, Script):
@@ -483,6 +504,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
+ def after_extra_networks_activate(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.after_extra_networks_activate(p, *script_args, **kwargs)
+ except Exception:
+ errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
+
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
@@ -548,6 +577,15 @@ class ScriptRunner:
self.scripts[si].args_to = args_to
+ def before_hr(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_hr(p, *script_args)
+ except Exception:
+ errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+
+
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
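
A minimal sketch of an always-on extension script using the new callbacks added above; everything except the callback names is illustrative:

    from modules import scripts

    class ExampleScript(scripts.Script):
        def title(self):
            return "Example"

        def show(self, is_img2img):
            return scripts.AlwaysVisible

        def after_extra_networks_activate(self, p, *args, **kwargs):
            print("extra networks active for batch", kwargs.get("batch_number"))

        def before_hr(self, p, *args):
            print("hires fix pass is about to start")
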
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 3b6f95ce..c8fdd4f1 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -56,6 +61,9 @@ def apply_optimizations(option=None):
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+ sgm.modules.diffusionmodules.model.nonlinearity = silu
+ sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
if current_optimizer is not None:
current_optimizer.undo()
current_optimizer = None
@@ -89,6 +97,10 @@ def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+ sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
def fix_checkpoint():
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -147,7 +159,6 @@ def undo_weighted_forward(sd_model):
class StableDiffusionModelHijack:
fixes = None
- comments = []
layers = None
circular_enabled = False
clip = None
@@ -156,6 +167,9 @@ class StableDiffusionModelHijack:
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
def __init__(self):
+ self.extra_generation_params = {}
+ self.comments = []
+
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
@@ -166,6 +180,32 @@ class StableDiffusionModelHijack:
undo_optimizations()
def hijack(self, m):
+ conditioner = getattr(m, 'conditioner', None)
+ if conditioner:
+ text_cond_models = []
+
+ for i in range(len(conditioner.embedders)):
+ embedder = conditioner.embedders[i]
+ typename = type(embedder).__name__
+ if typename == 'FrozenOpenCLIPEmbedder':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+ if typename == 'FrozenCLIPEmbedder':
+ model_embeddings = embedder.transformer.text_model.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+ if typename == 'FrozenOpenCLIPEmbedder2':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+
+ if len(text_cond_models) == 1:
+ m.cond_stage_model = text_cond_models[0]
+ else:
+ m.cond_stage_model = conditioner
+
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
@@ -203,7 +243,7 @@ class StableDiffusionModelHijack:
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
@@ -236,6 +276,7 @@ class StableDiffusionModelHijack:
def clear_comments(self):
self.comments = []
+ self.extra_generation_params = {}
def get_prompt_lengths(self, text):
if self.clip is None:
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 3b5a7666..5443e609 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
self.chunk_length = 75
+ self.is_trainable = getattr(wrapped, 'is_trainable', False)
+ self.input_key = getattr(wrapped, 'input_key', 'txt')
+ self.legacy_ucg_val = None
+
def empty_chunk(self):
"""creates an empty PromptChunk and returns it"""
@@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
"""
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
- be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
An example shape returned by this function can be: (2, 77, 768).
+            For SDXL, instead of returning one tensor as above, it returns a tuple of two: the second one, with shape (B, 1280), contains the pooled values.
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
"""
@@ -229,11 +234,23 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.process_tokens(tokens, multipliers)
zs.append(z)
- if len(used_embeddings) > 0:
- embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
- self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+ if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
+ hashes = []
+ for name, embedding in used_embeddings.items():
+ shorthash = embedding.shorthash
+ if not shorthash:
+ continue
+
+ name = name.replace(":", "").replace(",", "")
+ hashes.append(f"{name}: {shorthash}")
+
+ if hashes:
+ self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
- return torch.hstack(zs)
+ if getattr(self.wrapped, 'return_pooled', False):
+ return torch.hstack(zs), zs[0].pooled
+ else:
+ return torch.hstack(zs)
def process_tokens(self, remade_batch_tokens, batch_multipliers):
"""
@@ -256,9 +273,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z = z * (original_mean / new_mean)
+ z *= (original_mean / new_mean)
return z
@@ -315,3 +332,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
+
+
+class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ def encode_with_transformers(self, tokens):
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
+
+ if self.wrapped.layer == "last":
+ z = outputs.last_hidden_state
+ else:
+ z = outputs.hidden_states[self.wrapped.layer_idx]
+
+ return z
diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
index f733e852..bb0b96c7 100644
--- a/modules/sd_hijack_open_clip.py
+++ b/modules/sd_hijack_open_clip.py
@@ -32,6 +32,40 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
- embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+ embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
+
+ return embedded
+
+
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+ self.id_start = tokenizer.encoder["<start_of_text>"]
+ self.id_end = tokenizer.encoder["<end_of_text>"]
+ self.id_pad = 0
+
+ def tokenize(self, texts):
+ assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+ tokenized = [tokenizer.encode(text) for text in texts]
+
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ d = self.wrapped.encode_with_transformer(tokens)
+ z = d[self.wrapped.layer]
+
+ pooled = d.get("pooled")
+ if pooled is not None:
+ z.pooled = pooled
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ ids = tokenizer.encode(init_text)
+ ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+ embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 53e27ade..b5f85ba5 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
class SdOptimization:
@@ -39,6 +43,9 @@ class SdOptimization:
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
+
class SdOptimizationXformers(SdOptimization):
name = "xformers"
@@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
class SdOptimizationSdpNoMem(SdOptimization):
@@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
class SdOptimizationSdp(SdOptimizationSdpNoMem):
@@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
class SdOptimizationSubQuad(SdOptimization):
@@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
class SdOptimizationV1(SdOptimization):
@@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization):
cmd_opt = "opt_split_attention_v1"
priority = 10
-
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
class SdOptimizationInvokeAI(SdOptimization):
@@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
class SdOptimizationDoggettx(SdOptimization):
@@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization):
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+ sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def list_optimizers(res):
@@ -155,7 +173,7 @@ def get_available_vram():
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
+def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
@@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
# taken from https://github.com/Doggettx/stable-diffusion and modified
-def split_cross_attention_forward(self, x, context=None, mask=None):
+def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
@@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
+
def einsum_op_slice_0(q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
@@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size):
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
return r
+
def einsum_op_slice_1(q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
@@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size):
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
return r
+
def einsum_op_mps_v1(q, k, v):
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
@@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v):
slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
+
def einsum_op_mps_v2(q, k, v):
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
+
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
@@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
def einsum_op_cuda(q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
@@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v):
# Divide factor of safety as there's copying and fragmentation
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
def einsum_op(q, k, v):
if q.device.type == 'cuda':
return einsum_op_cuda(q, k, v)
@@ -328,7 +354,8 @@ def einsum_op(q, k, v):
# Tested on i7 with 8MB L3 cache.
return einsum_op_tensor_mem(q, k, v, 32)
-def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
h = self.heads
q = self.to_q(x)
@@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
-def sub_quad_attention_forward(self, x, context=None, mask=None):
+def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
h = self.heads
@@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
return x
+
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
@@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
return None
-def xformers_attention_forward(self, x, context=None, mask=None):
+def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
@@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None):
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
+
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
-def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
batch_size, sequence_length, inner_dim = x.shape
if mask is not None:
@@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
hidden_states = self.to_out[1](hidden_states)
return hidden_states
-def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return scaled_dot_product_attention_forward(self, x, context, mask)
+
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x):
return h3
+
def xformers_attnblock_forward(self, x):
try:
h_ = x
@@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x):
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+
def sdp_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x):
out = self.proj_out(out)
return x + out
+
def sdp_no_mem_attnblock_forward(self, x):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return sdp_attnblock_forward(self, x)
+
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index ca1daf45..2101f1a0 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
if isinstance(cond, dict):
for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+ if isinstance(cond[y], list):
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+ else:
+ cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
@@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
+
+CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
+CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
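
The two CondFunc hooks above route the SDXL (sgm) UNet through the same fp16-upcast wrapper used for ldm models, and the modified apply_model now accepts conditioning values that are plain tensors as well as lists. A minimal sketch of that casting logic, with dtype_unet standing in for modules.devices.dtype_unet and the key names chosen only to exercise both branches:

import torch

dtype_unet = torch.float16  # stand-in for modules.devices.dtype_unet

def cast_cond_to_unet_dtype(cond: dict) -> dict:
    # mirrors the list-vs-tensor handling added to apply_model above
    for key, value in cond.items():
        if isinstance(value, list):
            cond[key] = [v.to(dtype_unet) if isinstance(v, torch.Tensor) else v for v in value]
        elif isinstance(value, torch.Tensor):
            cond[key] = value.to(dtype_unet)
    return cond

cond = {
    "c_crossattn": [torch.zeros(1, 77, 768)],   # SD1/SD2 style: list of tensors
    "crossattn": torch.zeros(1, 77, 2048),      # SDXL style: plain tensor
}
cond = cast_cond_to_unet_dtype(cond)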
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6ff5d17d..fb31a793 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -23,7 +23,8 @@ model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
checkpoints_list = {}
-checkpoint_alisases = {}
+checkpoint_aliases = {}
+checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()
@@ -66,7 +67,7 @@ class CheckpointInfo:
def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
- checkpoint_alisases[id] = self
+ checkpoint_aliases[id] = self
def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
@@ -112,7 +113,7 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
- checkpoint_alisases.clear()
+ checkpoint_aliases.clear()
cmd_ckpt = shared.cmd_opts.ckpt
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
@@ -136,7 +137,7 @@ def list_models():
def get_closet_checkpoint_match(search_string):
- checkpoint_info = checkpoint_alisases.get(search_string, None)
+ checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -166,7 +167,7 @@ def select_checkpoint():
"""Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
- checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
+ checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -247,7 +248,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+ if not shared.opts.disable_mmap_load_safetensors:
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+ else:
+ pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+ pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -283,6 +289,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+ model.is_sdxl = hasattr(model, 'conditioner')
+ model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
+ model.is_sd1 = not model.is_sdxl and not model.is_sd2
+
+ if model.is_sdxl:
+ sd_models_xl.extend_sdxl(model)
+
model.load_state_dict(state_dict, strict=False)
del state_dict
timer.record("apply weights to model")
@@ -313,7 +326,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
timer.record("apply half()")
- devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -328,7 +341,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.sd_checkpoint_info = checkpoint_info
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
- model.logvar = model.logvar.to(devices.device) # fix for training
+ if hasattr(model, 'logvar'):
+ model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
@@ -385,10 +399,11 @@ def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
- if shared.cmd_opts.no_half:
- sd_config.model.params.unet_config.params.use_fp16 = False
- elif shared.cmd_opts.upcast_sampling:
- sd_config.model.params.unet_config.params.use_fp16 = True
+ if hasattr(sd_config.model.params, 'unet_config'):
+ if shared.cmd_opts.no_half:
+ sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
@@ -401,6 +416,8 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
@@ -435,6 +452,15 @@ class SdModelData:
model_data = SdModelData()
+def get_empty_cond(sd_model):
+ if hasattr(sd_model, 'conditioner'):
+ d = sd_model.get_learned_conditioning([""])
+ return d['crossattn']
+ else:
+ return sd_model.cond_stage_model([""])
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -455,7 +481,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+ clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
timer.record("find config")
@@ -468,7 +494,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
- with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
sd_model = instantiate_from_config(sd_config.model)
except Exception:
pass
@@ -507,7 +533,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
- sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+ sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
timer.record("calculate empty prompt")
@@ -585,7 +611,6 @@ def unload_model_weights(sd_model=None, info=None):
sd_model = None
gc.collect()
devices.torch_gc()
- torch.cuda.empty_cache()
print(f"Unloaded weights {timer.summary()}.")
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 9bfe1237..8266fa39 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -6,12 +6,15 @@ from modules import shared, paths, sd_disable_initialization
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -68,7 +71,11 @@ def guess_model_config_from_state_dict(sd, filename):
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
- if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+ if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
+ return config_sdxl
+ if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
+ return config_sdxl_refiner
+ elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
return config_depth_model
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
return config_unclip
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
new file mode 100644
index 00000000..40559208
--- /dev/null
+++ b/modules/sd_models_xl.py
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+import torch
+
+import sgm.models.diffusion
+import sgm.modules.diffusionmodules.denoiser_scaling
+import sgm.modules.diffusionmodules.discretizer
+from modules import devices, shared, prompt_parser
+
+
+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
+ for embedder in self.conditioner.embedders:
+ embedder.ucg_rate = 0.0
+
+ width = getattr(batch, 'width', 1024)
+ height = getattr(batch, 'height', 1024)
+ is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
+ aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
+
+ devices_args = dict(device=devices.device, dtype=devices.dtype)
+
+ sdxl_conds = {
+ "txt": batch,
+ "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+ "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
+ "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+ "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
+ }
+
+ force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
+ c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
+
+ return c
+
+
+def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+ return self.model(x, t, cond)
+
+
+def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
+ return x
+
+
+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
+
+
+def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
+ res = []
+
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
+ encoded = embedder.encode_embedding_init_text(init_text, nvpt)
+ res.append(encoded)
+
+ return torch.cat(res, dim=1)
+
+
+def process_texts(self, texts):
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
+ return embedder.process_texts(texts)
+
+
+def get_target_prompt_token_count(self, token_count):
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
+ return embedder.get_target_prompt_token_count(token_count)
+
+
+# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
+sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.process_texts = process_texts
+sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
+
+
+def extend_sdxl(model):
+ """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
+
+ dtype = next(model.model.diffusion_model.parameters()).dtype
+ model.model.diffusion_model.dtype = dtype
+ model.model.conditioning_key = 'crossattn'
+ model.cond_stage_key = 'txt'
+ # model.cond_stage_model will be set in sd_hijack
+
+ model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
+
+ discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+
+ model.conditioner.wrapped = torch.nn.Module()
+
+
+sgm.modules.attention.print = lambda *args: None
+sgm.modules.diffusionmodules.model.print = lambda *args: None
+sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
+sgm.modules.encoders.modules.print = lambda *args: None
+
+# this gets the code to load the vanilla attention that we override
+sgm.modules.attention.SDP_IS_AVAILABLE = True
+sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
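
get_learned_conditioning above pulls width, height and is_negative_prompt off the prompt batch with getattr, so the batch is expected to be a list of prompt strings that also carries those attributes (prompt_parser.SdConditioning plays that role in the webui). A hypothetical stand-in to illustrate the shape of that input:

class PromptBatch(list):
    # illustrative only; the real class is prompt_parser.SdConditioning
    def __init__(self, prompts, width=1024, height=1024, is_negative_prompt=False):
        super().__init__(prompts)
        self.width = width
        self.height = height
        self.is_negative_prompt = is_negative_prompt

batch = PromptBatch(["a photo of a cat"], width=832, height=1216)
# getattr(batch, 'width', 1024) -> 832; a plain list would fall back to the 1024 default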
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index f22aad8f..bea2684c 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -28,6 +28,9 @@ def create_sampler(name, model):
assert config is not None, f'bad sampler name: {name}'
+ if model.is_sdxl and config.options.get("no_sdxl", False):
+ raise Exception(f"Sampler {config.name} is not supported for SDXL")
+
sampler = config.constructor(model)
sampler.config = config
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index bdae8b40..4a8396f9 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc
samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
- sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
- sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
+ sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
+ sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
]
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 71581b76..5552a8dc 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -53,6 +53,28 @@ k_diffusion_scheduler = {
}
+def catenate_conds(conds):
+ if not isinstance(conds[0], dict):
+ return torch.cat(conds)
+
+ return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+ if not isinstance(cond, dict):
+ return cond[a:b]
+
+ return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+ if not isinstance(tensor, dict):
+ return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+ tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+ return tensor
+
+
class CFGDenoiser(torch.nn.Module):
"""
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
@@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module):
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
else:
image_uncond = image_cond
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+ if isinstance(uncond, dict):
+ make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+ else:
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module):
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
- tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+ tensor = pad_cond(tensor, -num_repeats, empty)
self.padded_cond_uncond = True
elif num_repeats > 0:
- uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+ uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
- cond_in = torch.cat([tensor, uncond, uncond])
+ cond_in = catenate_conds([tensor, uncond, uncond])
elif skip_uncond:
cond_in = tensor
else:
- cond_in = torch.cat([tensor, uncond])
+ cond_in = catenate_conds([tensor, uncond])
if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
+ x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module):
b = min(a + batch_size, tensor.shape[0])
if not is_edit_model:
- c_crossattn = [tensor[a:b]]
+ c_crossattn = subscript_cond(tensor, a, b)
else:
c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
if not skip_uncond:
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
denoised_image_indexes = [x[0][0] for x in conds_list]
if skip_uncond:
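
The helpers at the top of this hunk let CFGDenoiser treat SD1/SD2 conds (plain tensors) and SDXL conds (dicts of tensors) uniformly. A tiny self-contained demonstration, with made-up shapes:

import torch

def catenate_conds(conds):
    if not isinstance(conds[0], dict):
        return torch.cat(conds)
    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}

def subscript_cond(cond, a, b):
    if not isinstance(cond, dict):
        return cond[a:b]
    return {key: vec[a:b] for key, vec in cond.items()}

tensor_cond = torch.zeros(2, 77, 768)                      # SD1/SD2: plain tensor
dict_cond = {"crossattn": torch.zeros(2, 77, 2048),        # SDXL: dict of tensors
             "vector": torch.zeros(2, 2816)}

assert catenate_conds([tensor_cond, tensor_cond]).shape[0] == 4
assert catenate_conds([dict_cond, dict_cond])["crossattn"].shape[0] == 4
assert subscript_cond(dict_cond, 0, 1)["vector"].shape[0] == 1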
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index e2f00468..86bd658a 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -2,9 +2,9 @@ import os
import torch
from torch import nn
-from modules import devices, paths
+from modules import devices, paths, shared
-sd_vae_approx_model = None
+sd_vae_approx_models = {}
class VAEApprox(nn.Module):
@@ -31,30 +31,55 @@ class VAEApprox(nn.Module):
return x
+def download_model(model_path, model_url):
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+ print(f'Downloading VAEApprox model to: {model_path}')
+ torch.hub.download_url_to_file(model_url, model_path)
+
+
def model():
- global sd_vae_approx_model
+ model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+ loaded_model = sd_vae_approx_models.get(model_name)
- if sd_vae_approx_model is None:
- model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
- sd_vae_approx_model = VAEApprox()
+ if loaded_model is None:
+ model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
if not os.path.exists(model_path):
- model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
- sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
- sd_vae_approx_model.eval()
- sd_vae_approx_model.to(devices.device, devices.dtype)
+ model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
+
+ if not os.path.exists(model_path):
+ model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+ download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
+
+ loaded_model = VAEApprox()
+ loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+ loaded_model.eval()
+ loaded_model.to(devices.device, devices.dtype)
+ sd_vae_approx_models[model_name] = loaded_model
- return sd_vae_approx_model
+ return loaded_model
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
- coefs = torch.tensor([
- [0.298, 0.207, 0.208],
- [0.187, 0.286, 0.173],
- [-0.158, 0.189, 0.264],
- [-0.184, -0.271, -0.473],
- ]).to(sample.device)
+ if shared.sd_model.is_sdxl:
+ coeffs = [
+ [ 0.3448, 0.4168, 0.4395],
+ [-0.1953, -0.0290, 0.0250],
+ [ 0.1074, 0.0886, -0.0163],
+ [-0.3730, -0.2499, -0.2088],
+ ]
+ else:
+ coeffs = [
+ [ 0.298, 0.207, 0.208],
+ [ 0.187, 0.286, 0.173],
+ [-0.158, 0.189, 0.264],
+ [-0.184, -0.271, -0.473],
+ ]
+
+ coefs = torch.tensor(coeffs).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
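
For reference, cheap_approximation projects each of the four latent channels to RGB with a fixed 4x3 matrix, so a (4, H, W) latent becomes a (3, H, W) preview; the SDXL branch only swaps in a different matrix. A rough sketch with the SD1/SD2 coefficients from the hunk above:

import torch

coefs = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])

sample = torch.randn(4, 64, 64)                 # one latent, 4 channels
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
print(x_sample.shape)                           # torch.Size([3, 64, 64])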
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 5e8496e8..5bf7c76e 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -8,9 +8,9 @@ import os
import torch
import torch.nn as nn
-from modules import devices, paths_internal
+from modules import devices, paths_internal, shared
-sd_vae_taesd = None
+sd_vae_taesd_models = {}
def conv(n_in, n_out, **kwargs):
@@ -61,9 +61,7 @@ class TAESD(nn.Module):
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
-def download_model(model_path):
- model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
-
+def download_model(model_path, model_url):
if not os.path.exists(model_path):
os.makedirs(os.path.dirname(model_path), exist_ok=True)
@@ -72,17 +70,19 @@ def download_model(model_path):
def model():
- global sd_vae_taesd
+ model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+ loaded_model = sd_vae_taesd_models.get(model_name)
- if sd_vae_taesd is None:
- model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
- download_model(model_path)
+ if loaded_model is None:
+ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+ download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
if os.path.exists(model_path):
- sd_vae_taesd = TAESD(model_path)
- sd_vae_taesd.eval()
- sd_vae_taesd.to(devices.device, devices.dtype)
+ loaded_model = TAESD(model_path)
+ loaded_model.eval()
+ loaded_model.to(devices.device, devices.dtype)
+ sd_vae_taesd_models[model_name] = loaded_model
else:
raise FileNotFoundError('TAESD model not found')
- return sd_vae_taesd.decoder
+ return loaded_model.decoder
diff --git a/modules/shared.py b/modules/shared.py
index a0862055..aa72c9c8 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,14 +1,17 @@
import datetime
import json
import os
+import re
import sys
import threading
import time
+import logging
import gradio as gr
import torch
import tqdm
+import launch
import modules.interrogate
import modules.memmon
import modules.styles
@@ -18,11 +21,13 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional
+log = logging.getLogger(__name__)
+
demo = None
parser = cmd_args.parser
-script_loading.preload_extensions(extensions_dir, parser)
+script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
@@ -144,12 +149,15 @@ class State:
def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
+ log.info("Received restart request")
def skip(self):
self.skipped = True
+ log.info("Received skip request")
def interrupt(self):
self.interrupted = True
+ log.info("Received interrupt request")
def nextjob(self):
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
@@ -173,7 +181,7 @@ class State:
return obj
- def begin(self):
+ def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
@@ -187,10 +195,13 @@ class State:
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
-
+ self.job = job
devices.torch_gc()
+ log.info("Starting job %s", job)
def end(self):
+ duration = time.time() - self.time_start
+ log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
@@ -311,6 +322,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+ "font": OptionInfo("", "Font for image grids that have text"),
+ "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
@@ -376,6 +391,7 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
+ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -411,12 +427,20 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
+ "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
}))
+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+ "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+ "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
+ "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
+ "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
+}))
+
options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
- "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+ "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
@@ -451,12 +475,15 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
- "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
- "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
+ "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
+ "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
+ "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
+ "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
}))
@@ -470,7 +497,6 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
- "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
@@ -481,6 +507,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+ "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
@@ -493,6 +520,7 @@ options_templates.update(options_section(('ui', "User interface"), {
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+ "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
@@ -817,8 +845,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
+def natural_sort_key(s, regex=re.compile('([0-9]+)')):
+ return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
+
+
def listfiles(dirname):
- filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
@@ -843,8 +875,11 @@ def walk_files(path, allowed_extensions=None):
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
- for root, _, files in os.walk(path, followlinks=True):
- for filename in files:
+ items = list(os.walk(path, followlinks=True))
+ items = sorted(items, key=lambda x: natural_sort_key(x[0]))
+
+ for root, _, files in items:
+ for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
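
natural_sort_key above splits names on digit runs so that numeric parts compare as integers; listfiles and walk_files now use it instead of plain lowercase ordering. A quick illustration:

import re

def natural_sort_key(s, regex=re.compile('([0-9]+)')):
    return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]

names = ["img10.png", "img2.png", "IMG1.png"]
print(sorted(names))                          # ['IMG1.png', 'img10.png', 'img2.png']
print(sorted(names, key=natural_sort_key))    # ['IMG1.png', 'img2.png', 'img10.png']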
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
index 734a4b6f..45823eb1 100644
--- a/modules/textual_inversion/logging.py
+++ b/modules/textual_inversion/logging.py
@@ -2,11 +2,51 @@ import datetime
import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_shared = {
+ "batch_size",
+ "clip_grad_mode",
+ "clip_grad_value",
+ "create_image_every",
+ "data_root",
+ "gradient_step",
+ "initial_step",
+ "latent_sampling_method",
+ "learn_rate",
+ "log_directory",
+ "model_hash",
+ "model_name",
+ "num_of_dataset_images",
+ "steps",
+ "template_file",
+ "training_height",
+ "training_width",
+}
+saved_params_ti = {
+ "embedding_name",
+ "num_vectors_per_token",
+ "save_embedding_every",
+ "save_image_with_stored_embedding",
+}
+saved_params_hypernet = {
+ "activation_func",
+ "add_layer_norm",
+ "hypernetwork_name",
+ "layer_structure",
+ "save_hypernetwork_every",
+ "use_dropout",
+ "weight_init",
+}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+saved_params_previews = {
+ "preview_cfg_scale",
+ "preview_height",
+ "preview_negative_prompt",
+ "preview_prompt",
+ "preview_sampler_index",
+ "preview_seed",
+ "preview_steps",
+ "preview_width",
+}
def save_settings_to_file(log_directory, all_params):
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 0d4c3f84..dbd856bd 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru
from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index bb6f211c..6166c76f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,5 +1,6 @@
import os
from collections import namedtuple
+from contextlib import closing
import torch
import tqdm
@@ -12,7 +13,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -48,6 +49,8 @@ class Embedding:
self.sd_checkpoint_name = None
self.optimizer_state_dict = None
self.filename = None
+ self.hash = None
+ self.shorthash = None
def save(self, filename):
embedding_data = {
@@ -81,6 +84,10 @@ class Embedding:
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
class DirWithTextualInversionEmbeddings:
def __init__(self, path):
@@ -198,6 +205,7 @@ class EmbeddingDatabase:
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
embedding.filename = path
+ embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
@@ -248,7 +256,7 @@ class EmbeddingDatabase:
self.word_embeddings.update(sorted_word_embeddings)
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
- if self.previously_displayed_embeddings != displayed_embeddings:
+ if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if self.skipped_embeddings:
@@ -584,8 +592,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
+ with closing(p):
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 2e7d202d..29d94e8c 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,13 +1,15 @@
+from contextlib import closing
+
import modules.scripts
from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.shared import opts, cmd_opts
import modules.shared as shared
from modules.ui import plaintext_to_html
+import gradio as gr
-
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img(
@@ -48,15 +50,16 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
+ p.user = request.username
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
- processed = modules.scripts.scripts_txt2img.run(p, *args)
-
- if processed is None:
- processed = processing.process_images(p)
+ with closing(p):
+ processed = modules.scripts.scripts_txt2img.run(p, *args)
- p.close()
+ if processed is None:
+ processed = processing.process_images(p)
shared.total_tqdm.clear()
@@ -67,4 +70,4 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
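
Both this file and textual_inversion.py above now wrap processing in contextlib.closing, which guarantees p.close() runs even if script execution or process_images raises. A minimal illustration with a stand-in object:

from contextlib import closing

class DummyProcessing:
    # stand-in for StableDiffusionProcessing; only close() matters here
    def close(self):
        print("closed")

with closing(DummyProcessing()) as p:
    pass  # process_images(p) would go here; close() is called on exit either way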
diff --git a/modules/ui.py b/modules/ui.py
index e2e3b6da..07ecee7b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -83,8 +83,7 @@ detect_image_size_symbol = '\U0001F4D0' # 📐
up_down_symbol = '\u2195\ufe0f' # ↕️
-def plaintext_to_html(text):
- return ui_common.plaintext_to_html(text)
+plaintext_to_html = ui_common.plaintext_to_html
def send_gradio_gallery_to_image(x):
@@ -155,7 +154,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
img = Image.open(image)
filename = os.path.basename(image)
left, _ = os.path.splitext(filename)
- print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
return [gr.update(), None]
@@ -733,6 +732,10 @@ def create_ui():
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ with gr.Accordion("PNG info", open=False):
+ img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
+ img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+ img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
@@ -773,7 +776,7 @@ def create_ui():
selected_scale_tab = gr.State(value=0)
with gr.Tabs():
- with gr.Tab(label="Resize to") as tab_scale_to:
+ with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@@ -782,7 +785,7 @@ def create_ui():
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
- with gr.Tab(label="Resize by") as tab_scale_by:
+ with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
with FormRow():
@@ -934,6 +937,9 @@ def create_ui():
img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir,
override_settings,
+ img2img_batch_use_png_info,
+ img2img_batch_png_info_props,
+ img2img_batch_png_info_dir,
] + custom_inputs,
outputs=[
img2img_gallery,
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 57c2d0ad..11eb2a4b 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -29,9 +29,10 @@ def update_generation_info(generation_info, html_info, img_index):
return html_info, gr.update()
-def plaintext_to_html(text):
- text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
- return text
+def plaintext_to_html(text, classname=None):
+ content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
+
+ return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
def save_files(js_data, images, do_make_zip, index):
@@ -157,7 +158,7 @@ Requested path was: {f}
with gr.Group():
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
if tabname == 'txt2img' or tabname == 'img2img':
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index c7e0a866..f3e4fba7 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,5 +1,5 @@
import json
-import os.path
+import os
import threading
import time
from datetime import datetime
@@ -138,7 +138,10 @@ def extension_table():
<table id="extensions">
<thead>
<tr>
- <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+ <th>
+ <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
+ <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
+ </th>
<th>URL</th>
<th>Branch</th>
<th>Version</th>
@@ -170,7 +173,7 @@ def extension_table():
code += f"""
<tr>
- <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+ <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
<td>{remote}</td>
<td>{ext.branch}</td>
<td>{version_link}</td>
@@ -421,9 +424,19 @@ sort_ordering = [
(False, lambda x: x.get('name', 'z')),
(True, lambda x: x.get('name', 'z')),
(False, lambda x: 'z'),
+ (True, lambda x: x.get('commit_time', '')),
+ (True, lambda x: x.get('created_at', '')),
+ (True, lambda x: x.get('stars', 0)),
]
+def get_date(info: dict, key):
+ try:
+ return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
+ except (ValueError, TypeError):
+ return ''
+
+
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -448,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
name = ext.get("name", "noname")
+ stars = int(ext.get("stars", 0))
added = ext.get('added', 'unknown')
+ update_time = get_date(ext, 'commit_time')
+ create_time = get_date(ext, 'created_at')
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
@@ -475,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
code += f"""
<tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
- <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
+ <td>{html.escape(description)}<p class="info">
+ <span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
<td>{install_code}</td>
</tr>
@@ -496,14 +513,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
def preload_extensions_git_metadata():
- t0 = time.time()
for extension in extensions.extensions:
extension.read_info_from_repo()
- print(
- f"preload_extensions_git_metadata for "
- f"{len(extensions.extensions)} extensions took "
- f"{time.time() - t0:.2f}s"
- )
def create_ui():
@@ -553,13 +564,14 @@ def create_ui():
with gr.TabItem("Available", id="available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
- available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
+ extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
+ available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
- sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
+ sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", "update time", "create time", "stars"], type="index")
with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False)
@@ -568,9 +580,9 @@ def create_ui():
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags, sort_column],
- outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
)
install_extension_button.click(
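The change above also makes the default extension index URL overridable before launch through the WEBUI_EXTENSIONS_INDEX environment variable. A hedged sketch of pointing the "Available" tab at a different index (the mirror URL below is illustrative, not a real endpoint):

# Set before starting the UI; the value pre-fills the "Extension index URL" textbox.
import os
os.environ["WEBUI_EXTENSIONS_INDEX"] = "https://example.com/my-extensions-index.json"

# or from a shell, before launching:
#   WEBUI_EXTENSIONS_INDEX=https://example.com/my-extensions-index.json ./webui.sh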
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index a7d3bc79..49612298 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -2,14 +2,16 @@ import os.path
import urllib.parse
from pathlib import Path
-from modules import shared
+from modules import shared, ui_extra_networks_user_metadata, errors
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr
import json
import html
+from fastapi.exceptions import HTTPException
from modules.generation_parameters_copypaste import image_from_url_text
+from modules.ui_components import ToolButton
extra_pages = []
allowed_dirs = set()
@@ -26,12 +28,15 @@ def register_page(page):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
+ if not os.path.isfile(filename):
+ raise HTTPException(status_code=404, detail="File not found")
+
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg", ".jpeg", ".webp"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+ if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -48,25 +53,71 @@ def get_metadata(page: str = "", item: str = ""):
if metadata is None:
return JSONResponse({})
- return JSONResponse({"metadata": metadata})
+ return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
+
+
+def get_single_card(page: str = "", tabname: str = "", name: str = ""):
+ from starlette.responses import JSONResponse
+
+ page = next(iter([x for x in extra_pages if x.name == page]), None)
+
+ try:
+ item = page.create_item(name, enable_filter=False)
+ page.items[name] = item
+ except Exception as e:
+ errors.display(e, "creating item for extra network")
+ item = page.items.get(name)
+
+ page.read_user_metadata(item)
+ item_html = page.create_html_for_item(item, tabname)
+
+ return JSONResponse({"html": item_html})
def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
+ app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
+
+
+def quote_js(s):
+ s = s.replace('\\', '\\\\')
+ s = s.replace('"', '\\"')
+ return f'"{s}"'
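quote_js takes over from json.dumps for strings embedded into inline JavaScript handlers: it escapes backslashes and double quotes and wraps the result in double quotes, and (unlike json.dumps with its default ensure_ascii=True) leaves non-ASCII characters intact, matching the ensure_ascii=False choice made for the metadata endpoint above. A minimal sketch of the difference:

# quote_js vs json.dumps for strings placed inside generated onclick attributes.
import json

def quote_js(s):
    s = s.replace('\\', '\\\\')
    s = s.replace('"', '\\"')
    return f'"{s}"'

print(quote_js('my "model" v1\\final'))  # -> "my \"model\" v1\\final"
print(quote_js('модель'))                # non-ASCII kept as-is: "модель"
print(json.dumps('модель'))              # -> "\u043c\u043e\u0434\u0435\u043b\u044c"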
class ExtraNetworksPage:
def __init__(self, title):
self.title = title
self.name = title.lower()
+ self.id_page = self.name.replace(" ", "_")
self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False
self.metadata = {}
+ self.items = {}
def refresh(self):
pass
+ def read_user_metadata(self, item):
+ filename = item.get("filename", None)
+ basename, ext = os.path.splitext(filename)
+ metadata_filename = basename + '.json'
+
+ metadata = {}
+ try:
+ if os.path.isfile(metadata_filename):
+ with open(metadata_filename, "r", encoding="utf8") as file:
+ metadata = json.load(file)
+ except Exception as e:
+ errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+ desc = metadata.get("description", None)
+ if desc is not None:
+ item["description"] = desc
+
+ item["user_metadata"] = metadata
+
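read_user_metadata looks for a .json sidecar with the same basename as the network file, copies its "description" onto the card, and stores the whole dict under item["user_metadata"]; this is the same sidecar the new metadata editor writes back. A hedged sketch of creating one by hand, using a hypothetical hypernetwork path (the base editor only reads and writes "description" and "notes"; other keys are preserved):

# Hypothetical sidecar next to models/hypernetworks/anime_style.pt
import json, os

network_path = "models/hypernetworks/anime_style.pt"            # illustrative path
metadata_filename = os.path.splitext(network_path)[0] + ".json"  # -> anime_style.json

example = {
    "description": "Shown on the card in the extra networks tab",
    "notes": "Free-form notes, visible only in the user metadata editor",
}
with open(metadata_filename, "w", encoding="utf8") as file:
    json.dump(example, file)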
def link_preview(self, filename):
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
mtime = os.path.getmtime(filename)
@@ -83,15 +134,14 @@ class ExtraNetworksPage:
return ""
def create_html(self, tabname):
- view = shared.opts.extra_networks_default_view
items_html = ''
self.metadata = {}
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for root, dirs, _ in os.walk(parentdir, followlinks=True):
- for dirname in dirs:
+ for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
+ for dirname in sorted(dirs, key=shared.natural_sort_key):
x = os.path.join(root, dirname)
if not os.path.isdir(x):
@@ -119,11 +169,15 @@ class ExtraNetworksPage:
</button>
""" for subdir in subdirs])
- for item in self.list_items():
+ self.items = {x["name"]: x for x in self.list_items()}
+ for item in self.items.values():
metadata = item.get("metadata")
if metadata:
self.metadata[item["name"]] = metadata
+ if "user_metadata" not in item:
+ self.read_user_metadata(item)
+
items_html += self.create_html_for_item(item, tabname)
if items_html == '':
@@ -133,16 +187,19 @@ class ExtraNetworksPage:
self_name_id = self.name.replace(" ", "_")
res = f"""
-<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
+<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
{subdirs_html}
</div>
-<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
+<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
{items_html}
</div>
"""
return res
+ def create_item(self, name, index=None):
+ raise NotImplementedError()
+
def list_items(self):
raise NotImplementedError()
@@ -158,7 +215,7 @@ class ExtraNetworksPage:
onclick = item.get("onclick", None)
if onclick is None:
- onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
@@ -166,7 +223,9 @@ class ExtraNetworksPage:
metadata_button = ""
metadata = item.get("metadata")
if metadata:
- metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
+ metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
+
+ edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
local_path = ""
filename = item.get("filename", "")
@@ -190,16 +249,17 @@ class ExtraNetworksPage:
args = {
"background_image": background_image,
- "style": f"'display: none; {height}{width}'",
+ "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
"prompt": item.get("prompt", None),
- "tabname": json.dumps(tabname),
- "local_preview": json.dumps(item["local_preview"]),
+ "tabname": quote_js(tabname),
+ "local_preview": quote_js(item["local_preview"]),
"name": item["name"],
- "description": (item.get("description") or ""),
+ "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
"card_clicked": onclick,
- "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
+ "edit_button": edit_button,
"search_only": " search_only" if search_only else "",
"sort_keys": sort_keys,
}
@@ -247,6 +307,9 @@ class ExtraNetworksPage:
pass
return None
+ def create_user_metadata_editor(self, ui, tabname):
+ return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)
+
def initialize():
extra_pages.clear()
@@ -297,23 +360,26 @@ def create_ui(container, button, tabname):
ui = ExtraNetworksUi()
ui.pages = []
ui.pages_contents = []
+ ui.user_metadata_editors = []
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs"):
for page in ui.stored_extra_pages:
- page_id = page.title.lower().replace(" ", "_")
-
- with gr.Tab(page.title, id=page_id):
- elem_id = f"{tabname}_{page_id}_cards_html"
+ with gr.Tab(page.title, id=page.id_page):
+ elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
- page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
+ page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
+
+ editor = page.create_user_metadata_editor(ui, tabname)
+ editor.create_ui()
+ ui.user_metadata_editors.append(editor)
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
- gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
+ ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
@@ -363,6 +429,8 @@ def path_is_parent(parent_path, child_path):
def setup_ui(ui, gallery):
def save_preview(index, images, filename):
+ # this function is here for backwards compatibility and likely will be removed soon
+
if len(images) == 0:
print("There is no image in gallery to save as a preview.")
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
@@ -394,3 +462,7 @@ def setup_ui(ui, gallery):
outputs=[*ui.pages]
)
+ for editor in ui.user_metadata_editors:
+ editor.setup_ui(gallery)
+
+
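The /sd_extra_networks/get-single-card route registered in add_pages_to_demo lets the frontend rebuild one card's HTML after its metadata changes instead of regenerating the whole page. A hedged sketch of querying it directly; host, port, page and card names are assumptions for illustration (a default local install on port 7860 and a textual inversion card):

# Fetch refreshed HTML for a single extra-networks card.
import requests

resp = requests.get(
    "http://127.0.0.1:7860/sd_extra_networks/get-single-card",
    params={"page": "textual inversion", "tabname": "txt2img", "name": "my-embedding"},
)
print(resp.json()["html"][:200])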
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 8b9ab71b..76780cfd 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -1,8 +1,8 @@
import html
-import json
import os
from modules import shared, ui_extra_networks, sd_models
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -12,21 +12,23 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.refresh_checkpoints()
+ def create_item(self, name, index=None):
+ checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
+ path, ext = os.path.splitext(checkpoint.filename)
+ return {
+ "name": checkpoint.name_for_extra,
+ "filename": checkpoint.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
+ "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
+ }
+
def list_items(self):
- checkpoint: sd_models.CheckpointInfo
- for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()):
- path, ext = os.path.splitext(checkpoint.filename)
- yield {
- "name": checkpoint.name_for_extra,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
- "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
- "local_preview": f"{path}.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
-
- }
+ for index, name in enumerate(sd_models.checkpoints_list):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
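The checkpoints page above, and the hypernetwork and textual inversion pages below, all follow the same refactor: the card dict is built in create_item(name) and list_items only enumerates names and delegates, which is what allows get_single_card to rebuild a single card by name. A hedged sketch of the shape a custom page would follow; the class, registry, and field set are illustrative, not part of this diff:

# Sketch of the list_items / create_item split used by the pages in this change.
import os
from modules import shared, ui_extra_networks

my_items = {"my-network": "models/my-networks/my-network.safetensors"}  # hypothetical registry

class MyExtraNetworksPage(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('My Networks')

    def create_item(self, name, index=None):
        filename = my_items[name]
        path, ext = os.path.splitext(filename)
        return {
            "name": name,
            "filename": filename,
            "preview": self.find_preview(path),
            "local_preview": f"{path}.preview.{shared.opts.samples_format}",
            "sort_keys": {"default": index, **self.get_sort_keys(filename)},
        }

    def list_items(self):
        for index, name in enumerate(my_items):
            yield self.create_item(name, index)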
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 7c19b532..e53ccb42 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -1,7 +1,7 @@
-import json
import os
from modules import shared, ui_extra_networks
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
@@ -11,21 +11,24 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.reload_hypernetworks()
+ def create_item(self, name, index=None):
+ full_path = shared.hypernetworks[name]
+ path, ext = os.path.splitext(full_path)
+
+ return {
+ "name": name,
+ "filename": full_path,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(path),
+ "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
+ }
+
def list_items(self):
- for index, (name, path) in enumerate(shared.hypernetworks.items()):
- path, ext = os.path.splitext(path)
-
- yield {
- "name": name,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(path),
- "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": f"{path}.preview.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
-
- }
+ for index, name in enumerate(shared.hypernetworks):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return [shared.cmd_opts.hypernetwork_dir]
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index 58a61c55..d1794e50 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -1,7 +1,7 @@
-import json
import os
from modules import ui_extra_networks, sd_hijack, shared
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
@@ -12,20 +12,24 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+ def create_item(self, name, index=None):
+ embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
+
+ path, ext = os.path.splitext(embedding.filename)
+ return {
+ "name": name,
+ "filename": embedding.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(embedding.filename),
+ "prompt": quote_js(embedding.name),
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
+ }
+
def list_items(self):
- for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()):
- path, ext = os.path.splitext(embedding.filename)
- yield {
- "name": embedding.name,
- "filename": embedding.filename,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(embedding.filename),
- "prompt": json.dumps(embedding.name),
- "local_preview": f"{path}.preview.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
-
- }
+ for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
new file mode 100644
index 00000000..63d4b503
--- /dev/null
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -0,0 +1,195 @@
+import datetime
+import html
+import json
+import os.path
+
+import gradio as gr
+
+from modules import generation_parameters_copypaste, images, sysinfo, errors
+
+
+class UserMetadataEditor:
+
+ def __init__(self, ui, tabname, page):
+ self.ui = ui
+ self.tabname = tabname
+ self.page = page
+ self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
+
+ self.box = None
+
+ self.edit_name_input = None
+ self.button_edit = None
+
+ self.edit_name = None
+ self.edit_description = None
+ self.edit_notes = None
+ self.html_filedata = None
+ self.html_preview = None
+ self.html_status = None
+
+ self.button_cancel = None
+ self.button_replace_preview = None
+ self.button_save = None
+
+ def get_user_metadata(self, name):
+ item = self.page.items.get(name, {})
+
+ user_metadata = item.get('user_metadata', None)
+ if user_metadata is None:
+ user_metadata = {}
+ item['user_metadata'] = user_metadata
+
+ return user_metadata
+
+ def create_extra_default_items_in_left_column(self):
+ pass
+
+ def create_default_editor_elems(self):
+ with gr.Row():
+ with gr.Column(scale=2):
+ self.edit_name = gr.HTML(elem_classes="extra-network-name")
+ self.edit_description = gr.Textbox(label="Description", lines=4)
+ self.html_filedata = gr.HTML()
+
+ self.create_extra_default_items_in_left_column()
+
+ with gr.Column(scale=1, min_width=0):
+ self.html_preview = gr.HTML()
+
+ def create_default_buttons(self):
+
+ with gr.Row(elem_classes="edit-user-metadata-buttons"):
+ self.button_cancel = gr.Button('Cancel')
+ self.button_replace_preview = gr.Button('Replace preview', variant='primary')
+ self.button_save = gr.Button('Save', variant='primary')
+
+ self.html_status = gr.HTML(elem_classes="edit-user-metadata-status")
+
+ self.button_cancel.click(fn=None, _js="closePopup")
+
+ def get_card_html(self, name):
+ item = self.page.items.get(name, {})
+
+ preview_url = item.get("preview", None)
+
+ if not preview_url:
+ filename, _ = os.path.splitext(item["filename"])
+ preview_url = self.page.find_preview(filename)
+ item["preview"] = preview_url
+
+ if preview_url:
+ preview = f'''
+ <div class='card standalone-card-preview'>
+ <img src="{html.escape(preview_url)}" class="preview">
+ </div>
+ '''
+ else:
+ preview = "<div class='card standalone-card-preview'></div>"
+
+ return preview
+
+ def get_metadata_table(self, name):
+ item = self.page.items.get(name, {})
+ try:
+ filename = item["filename"]
+
+ stats = os.stat(filename)
+ params = [
+ ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
+ ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
+ ]
+
+ return params
+ except Exception as e:
+ errors.display(e, f"reading info for {name}")
+ return []
+
+ def put_values_into_components(self, name):
+ user_metadata = self.get_user_metadata(name)
+
+ try:
+ params = self.get_metadata_table(name)
+ except Exception as e:
+ errors.display(e, f"reading metadata info for {name}")
+ params = []
+
+ table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
+
+ return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
+
+ def write_user_metadata(self, name, metadata):
+ item = self.page.items.get(name, {})
+ filename = item.get("filename", None)
+ basename, ext = os.path.splitext(filename)
+
+ with open(basename + '.json', "w", encoding="utf8") as file:
+ json.dump(metadata, file)
+
+ def save_user_metadata(self, name, desc, notes):
+ user_metadata = self.get_user_metadata(name)
+ user_metadata["description"] = desc
+ user_metadata["notes"] = notes
+
+ self.write_user_metadata(name, user_metadata)
+
+ def setup_save_handler(self, button, func, components):
+ button\
+ .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\
+ .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[])
+
+ def create_editor(self):
+ self.create_default_editor_elems()
+
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+ self.create_default_buttons()
+
+ self.button_edit\
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+ self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])
+
+ def create_ui(self):
+ with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box:
+ self.box = box
+
+ self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name")
+ self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button")
+
+ self.create_editor()
+
+ def save_preview(self, index, gallery, name):
+ if len(gallery) == 0:
+ return self.get_card_html(name), "There is no image in gallery to save as a preview."
+
+ item = self.page.items.get(name, {})
+
+ index = int(index)
+ index = 0 if index < 0 else index
+ index = len(gallery) - 1 if index >= len(gallery) else index
+
+ img_info = gallery[index if index >= 0 else 0]
+ image = generation_parameters_copypaste.image_from_url_text(img_info)
+ geninfo, items = images.read_info_from_image(image)
+
+ images.save_image_with_geninfo(image, geninfo, item["local_preview"])
+
+ return self.get_card_html(name), ''
+
+ def setup_ui(self, gallery):
+ self.button_replace_preview.click(
+ fn=self.save_preview,
+ _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
+ inputs=[self.edit_name_input, gallery, self.edit_name_input],
+ outputs=[self.html_preview, self.html_status]
+ ).then(
+ fn=None,
+ _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}",
+ inputs=[self.edit_name_input],
+ outputs=[]
+ )
+
+
+
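A page can supply its own editor by overriding create_user_metadata_editor and subclassing UserMetadataEditor; create_extra_default_items_in_left_column is the hook for extra fields in the left column. A hedged sketch with a hypothetical extra read-only row (not part of this diff):

# Page-specific editor adding one extra element via the left-column hook.
import gradio as gr
from modules import ui_extra_networks_user_metadata

class MyPageUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
    def create_extra_default_items_in_left_column(self):
        # rendered below the file info table in the editor's left column
        self.extra_info = gr.HTML('')

# and on the corresponding ExtraNetworksPage subclass:
#     def create_user_metadata_editor(self, ui, tabname):
#         return MyPageUserMetadataEditor(ui, tabname, self)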
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 0c560b30..a6076bf3 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -260,13 +260,20 @@ class UiSettings:
component = self.component_dict[k]
info = opts.data_labels[k]
- change_handler = component.release if hasattr(component, 'release') else component.change
- change_handler(
- fn=lambda value, k=k: self.run_settings_single(value, key=k),
- inputs=[component],
- outputs=[component, self.text_settings],
- show_progress=info.refresh is not None,
- )
+ if isinstance(component, gr.Textbox):
+ methods = [component.submit, component.blur]
+ elif hasattr(component, 'release'):
+ methods = [component.release]
+ else:
+ methods = [component.change]
+
+ for method in methods:
+ method(
+ fn=lambda value, k=k: self.run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, self.text_settings],
+ show_progress=info.refresh is not None,
+ )
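The block above changes how Textbox-backed settings apply: instead of firing on every change event, the handler is bound to submit (Enter) and blur, so the value is saved when editing finishes. A minimal standalone sketch of the same idea, outside the settings UI; component and function names are illustrative:

# Apply a text value on Enter or when the field loses focus, not on every keystroke.
import gradio as gr

def apply_value(value):
    print("saving:", value)
    return value

with gr.Blocks() as demo:
    box = gr.Textbox(label="some setting")
    for method in [box.submit, box.blur]:
        method(fn=apply_value, inputs=[box], outputs=[box])

demo.launch()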
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(