aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/warns_merge_master.yml19
-rw-r--r--extensions-builtin/Lora/lora.py12
-rw-r--r--extensions-builtin/Lora/ui_edit_user_metadata.py184
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py50
-rw-r--r--html/extra-networks-card.html10
-rw-r--r--javascript/extraNetworks.js64
-rw-r--r--javascript/hints.js2
-rw-r--r--modules/api/api.py33
-rw-r--r--modules/cache.py97
-rw-r--r--modules/call_queue.py18
-rw-r--r--modules/extensions.py26
-rw-r--r--modules/hashes.py38
-rw-r--r--modules/img2img.py2
-rw-r--r--modules/launch_utils.py10
-rw-r--r--modules/processing.py7
-rw-r--r--modules/sd_hijack.py5
-rw-r--r--modules/sd_hijack_clip.py15
-rw-r--r--modules/shared.py2
-rw-r--r--modules/textual_inversion/textual_inversion.py11
-rw-r--r--modules/txt2img.py2
-rw-r--r--modules/ui.py3
-rw-r--r--modules/ui_common.py9
-rw-r--r--modules/ui_extensions.py11
-rw-r--r--modules/ui_extra_networks.py94
-rw-r--r--modules/ui_extra_networks_checkpoints.py33
-rw-r--r--modules/ui_extra_networks_hypernets.py33
-rw-r--r--modules/ui_extra_networks_textual_inversion.py32
-rw-r--r--modules/ui_extra_networks_user_metadata.py190
-rw-r--r--style.css96
29 files changed, 911 insertions, 197 deletions
diff --git a/.github/workflows/warns_merge_master.yml b/.github/workflows/warns_merge_master.yml
new file mode 100644
index 00000000..ae2aab6b
--- /dev/null
+++ b/.github/workflows/warns_merge_master.yml
@@ -0,0 +1,19 @@
+name: Pull requests can't target master branch
+
+"on":
+ pull_request:
+ types:
+ - opened
+ - synchronize
+ - reopened
+ branches:
+ - master
+
+jobs:
+ check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Warning marge into master
+ run: |
+ echo -e "::warning::This pull request directly merge into \"master\" branch, normally development happens on \"dev\" branch."
+ exit 1
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index cd46e6c7..c8710922 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -3,7 +3,7 @@ import re
import torch
from typing import Union
-from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -78,9 +78,16 @@ class LoraOnDisk:
self.metadata = {}
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+ def read_metadata():
+ metadata = sd_models.read_metadata_from_safetensors(filename)
+ metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
+
+ return metadata
+
if self.is_safetensors:
try:
- self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ #self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
except Exception as e:
errors.display(e, f"reading lora {filename}")
@@ -91,7 +98,6 @@ class LoraOnDisk:
self.metadata = m
- self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
self.hash = None
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
new file mode 100644
index 00000000..6db63b09
--- /dev/null
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -0,0 +1,184 @@
+import html
+import random
+
+import gradio as gr
+import re
+
+from modules import ui_extra_networks_user_metadata
+
+
+def is_non_comma_tagset(tags):
+ average_tag_length = sum(len(x) for x in tags.keys()) / len(tags)
+
+ return average_tag_length >= 16
+
+
+re_word = re.compile(r"[-_\w']+")
+re_comma = re.compile(r" *, *")
+
+
+def build_tags(metadata):
+ tags = {}
+
+ for _, tags_dict in metadata.get("ss_tag_frequency", {}).items():
+ for tag, tag_count in tags_dict.items():
+ tag = tag.strip()
+ tags[tag] = tags.get(tag, 0) + int(tag_count)
+
+ if tags and is_non_comma_tagset(tags):
+ new_tags = {}
+
+ for text, text_count in tags.items():
+ for word in re.findall(re_word, text):
+ if len(word) < 3:
+ continue
+
+ new_tags[word] = new_tags.get(word, 0) + text_count
+
+ tags = new_tags
+
+ ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True)
+
+ return [(tag, tags[tag]) for tag in ordered_tags]
+
+
+class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+ def __init__(self, ui, tabname, page):
+ super().__init__(ui, tabname, page)
+
+ self.taginfo = None
+ self.edit_activation_text = None
+ self.slider_preferred_weight = None
+ self.edit_notes = None
+
+ def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes):
+ user_metadata = self.get_user_metadata(name)
+ user_metadata["description"] = desc
+ user_metadata["activation text"] = activation_text
+ user_metadata["preferred weight"] = preferred_weight
+ user_metadata["notes"] = notes
+
+ self.write_user_metadata(name, user_metadata)
+
+ def get_metadata_table(self, name):
+ table = super().get_metadata_table(name)
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+
+ keys = [
+ ('ss_sd_model_name', "Model:"),
+ ('ss_resolution', "Resolution:"),
+ ('ss_clip_skip', "Clip skip:"),
+ ]
+
+ for key, label in keys:
+ value = metadata.get(key, None)
+ if value is not None and str(value) != "None":
+ table.append((label, html.escape(value)))
+
+ image_count = 0
+ for _, params in metadata.get("ss_dataset_dirs", {}).items():
+ image_count += int(params.get("img_count", 0))
+
+ if image_count:
+ table.append(("Dataset size:", image_count))
+
+ return table
+
+ def put_values_into_components(self, name):
+ user_metadata = self.get_user_metadata(name)
+ values = super().put_values_into_components(name)
+
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+
+ tags = build_tags(metadata)
+ gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]
+
+ return [
+ *values[0:4],
+ gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
+ user_metadata.get('activation text', ''),
+ float(user_metadata.get('preferred weight', 0.0)),
+ user_metadata.get('notes', ''),
+ gr.update(visible=True if tags else False),
+ gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
+ ]
+
+ def generate_random_prompt(self, name):
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+ tags = build_tags(metadata)
+
+ return self.generate_random_prompt_from_tags(tags)
+
+ def generate_random_prompt_from_tags(self, tags):
+ max_count = None
+ res = []
+ for tag, count in tags:
+ if not max_count:
+ max_count = count
+
+ v = random.random() * max_count
+ if count > v:
+ res.append(tag)
+
+ return ", ".join(sorted(res))
+
+ def create_editor(self):
+ self.create_default_editor_elems()
+
+ self.taginfo = gr.HighlightedText(label="Tags")
+ self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
+ self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
+
+ with gr.Row() as row_random_prompt:
+ with gr.Column(scale=8):
+ random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
+
+ with gr.Column(scale=1, min_width=120):
+ generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+ generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False)
+
+ def select_tag(activation_text, evt: gr.SelectData):
+ tag = evt.value[0]
+
+ words = re.split(re_comma, activation_text)
+ if tag in words:
+ words = [x for x in words if x != tag and x.strip()]
+ return ", ".join(words)
+
+ return activation_text + ", " + tag if activation_text else tag
+
+ self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False)
+
+ self.create_default_buttons()
+
+ viewed_components = [
+ self.edit_name,
+ self.edit_description,
+ self.html_filedata,
+ self.html_preview,
+ self.taginfo,
+ self.edit_activation_text,
+ self.slider_preferred_weight,
+ self.edit_notes,
+ row_random_prompt,
+ random_prompt,
+ ]
+
+ self.button_edit\
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+ edited_components = [
+ self.edit_description,
+ self.edit_activation_text,
+ self.slider_preferred_weight,
+ self.edit_notes,
+ ]
+
+ self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index da49790b..b2bc1810 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -1,8 +1,9 @@
-import json
import os
import lora
from modules import shared, ui_extra_networks
+from modules.ui_extra_networks import quote_js
+from ui_edit_user_metadata import LoraUserMetadataEditor
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
@@ -12,25 +13,42 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
lora.list_available_loras()
- def list_items(self):
- for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()):
- path, ext = os.path.splitext(lora_on_disk.filename)
+ def create_item(self, name, index=None):
+ lora_on_disk = lora.available_loras.get(name)
+
+ path, ext = os.path.splitext(lora_on_disk.filename)
+
+ alias = lora_on_disk.get_alias()
- alias = lora_on_disk.get_alias()
+ # in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string
+ item = {
+ "name": name,
+ "filename": lora_on_disk.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(lora_on_disk.filename),
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "metadata": lora_on_disk.metadata,
+ "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
+ }
- yield {
- "name": name,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(lora_on_disk.filename),
- "prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": f"{path}.{shared.opts.samples_format}",
- "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
- "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
+ self.read_user_metadata(item)
+ activation_text = item["user_metadata"].get("activation text")
+ preferred_weight = item["user_metadata"].get("preferred weight", 0.0)
+ item["prompt"] = quote_js(f"<lora:{alias}:") + " + " + (str(preferred_weight) if preferred_weight else "opts.extra_networks_default_multiplier") + " + " + quote_js(">")
- }
+ if activation_text:
+ item["prompt"] += " + " + quote_js(" " + activation_text)
+
+ return item
+
+ def list_items(self):
+ for index, name in enumerate(lora.available_loras):
+ item = self.create_item(name, index)
+ yield item
def allowed_directories_for_previews(self):
return [shared.cmd_opts.lora_dir]
+ def create_user_metadata_editor(self, ui, tabname):
+ return LoraUserMetadataEditor(ui, tabname, self)
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index 68a84c3a..eb8b1a67 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -1,11 +1,11 @@
-<div class='card' style={style} onclick={card_clicked} {sort_keys}>
+<div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
{background_image}
- {metadata_button}
+ <div class="button-row">
+ {edit_button}
+ {metadata_button}
+ </div>
<div class='actions'>
<div class='additional'>
- <ul>
- <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
- </ul>
<span style="display:none" class='search_term{search_only}'>{search_term}</span>
</div>
<span class='name'>{name}</span>
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index b87bca3e..e453094a 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -113,7 +113,7 @@ function setupExtraNetworks() {
onUiLoaded(setupExtraNetworks);
-var re_extranet = /<([^:]+:[^:]+):[\d.]+>/;
+var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/;
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
@@ -121,15 +121,22 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
var replaced = false;
var newTextareaText;
if (m) {
+ var extraTextAfterNet = m[2];
var partToSearch = m[1];
- newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found) {
+ var foundAtPosition = -1;
+ newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
m = found.match(re_extranet);
if (m[1] == partToSearch) {
replaced = true;
+ foundAtPosition = pos;
return "";
}
return found;
});
+
+ if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
+ newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
+ }
} else {
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
if (found == text) {
@@ -182,19 +189,20 @@ function extraNetworksSearchButton(tabs_id, event) {
var globalPopup = null;
var globalPopupInner = null;
+function closePopup() {
+ if (!globalPopup) return;
+
+ globalPopup.style.display = "none";
+}
function popup(contents) {
if (!globalPopup) {
globalPopup = document.createElement('div');
- globalPopup.onclick = function() {
- globalPopup.style.display = "none";
- };
+ globalPopup.onclick = closePopup;
globalPopup.classList.add('global-popup');
var close = document.createElement('div');
close.classList.add('global-popup-close');
- close.onclick = function() {
- globalPopup.style.display = "none";
- };
+ close.onclick = closePopup;
close.title = "Close";
globalPopup.appendChild(close);
@@ -263,3 +271,43 @@ function extraNetworksRequestMetadata(event, extraPage, cardName) {
event.stopPropagation();
}
+
+var extraPageUserMetadataEditors = {};
+
+function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
+ var id = tabname + '_' + extraPage + '_edit_user_metadata';
+
+ var editor = extraPageUserMetadataEditors[id];
+ if (!editor) {
+ editor = {};
+ editor.page = gradioApp().getElementById(id);
+ editor.nameTextarea = gradioApp().querySelector("#" + id + "_name" + ' textarea');
+ editor.button = gradioApp().querySelector("#" + id + "_button");
+ extraPageUserMetadataEditors[id] = editor;
+ }
+
+ editor.nameTextarea.value = cardName;
+ updateInput(editor.nameTextarea);
+
+ editor.button.click();
+
+ popup(editor.page);
+
+ event.stopPropagation();
+}
+
+function extraNetworksRefreshSingleCard(page, tabname, name) {
+ requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) {
+ if (data && data.html) {
+ var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function
+
+ var newDiv = document.createElement('DIV');
+ newDiv.innerHTML = data.html;
+ var newCard = newDiv.firstElementChild;
+
+ newCard.style = '';
+ card.parentElement.insertBefore(newCard, card);
+ card.parentElement.removeChild(card);
+ }
+ });
+}
diff --git a/javascript/hints.js b/javascript/hints.js
index dc75ce31..41201b2f 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -84,8 +84,6 @@ var titles = {
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
- "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
-
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
diff --git a/modules/api/api.py b/modules/api/api.py
index 11045292..2a4cd8a2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,5 +1,6 @@
import base64
import io
+import os
import time
import datetime
import uvicorn
@@ -98,14 +99,16 @@ def encode_pil_to_base64(image):
def api_middleware(app: FastAPI):
- rich_available = True
+ rich_available = False
try:
- import anyio # importing just so it can be placed on silent list
- import starlette # importing just so it can be placed on silent list
- from rich.console import Console
- console = Console()
+ if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
+ import anyio # importing just so it can be placed on silent list
+ import starlette # importing just so it can be placed on silent list
+ from rich.console import Console
+ console = Console()
+ rich_available = True
except Exception:
- rich_available = False
+ pass
@app.middleware("http")
async def log_and_time(req: Request, call_next):
@@ -116,14 +119,14 @@ def api_middleware(app: FastAPI):
endpoint = req.scope.get('path', 'err')
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
- t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
- code = res.status_code,
- ver = req.scope.get('http_version', '0.0'),
- cli = req.scope.get('client', ('0:0.0.0', 0))[0],
- prot = req.scope.get('scheme', 'err'),
- method = req.scope.get('method', 'err'),
- endpoint = endpoint,
- duration = duration,
+ t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code=res.status_code,
+ ver=req.scope.get('http_version', '0.0'),
+ cli=req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot=req.scope.get('scheme', 'err'),
+ method=req.scope.get('method', 'err'),
+ endpoint=endpoint,
+ duration=duration,
))
return res
@@ -134,7 +137,7 @@ def api_middleware(app: FastAPI):
"body": vars(e).get('body', ''),
"errors": str(e),
}
- if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
+ if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
print(message)
diff --git a/modules/cache.py b/modules/cache.py
new file mode 100644
index 00000000..07180602
--- /dev/null
+++ b/modules/cache.py
@@ -0,0 +1,97 @@
+import json
+import os.path
+import threading
+
+from modules.paths import data_path, script_path
+
+cache_filename = os.path.join(data_path, "cache.json")
+cache_data = None
+cache_lock = threading.Lock()
+
+
+def dump_cache():
+ """
+ Saves all cache data to a file.
+ """
+
+ with cache_lock:
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+
+def cache(subsection):
+ """
+ Retrieves or initializes a cache for a specific subsection.
+
+ Parameters:
+ subsection (str): The subsection identifier for the cache.
+
+ Returns:
+ dict: The cache data for the specified subsection.
+ """
+
+ global cache_data
+
+ if cache_data is None:
+ with cache_lock:
+ if cache_data is None:
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ try:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+ except Exception:
+ os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
+ print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
+ cache_data = {}
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def cached_data_for_file(subsection, title, filename, func):
+ """
+ Retrieves or generates data for a specific file, using a caching mechanism.
+
+ Parameters:
+ subsection (str): The subsection of the cache to use.
+ title (str): The title of the data entry in the subsection of the cache.
+ filename (str): The path to the file to be checked for modifications.
+ func (callable): A function that generates the data if it is not available in the cache.
+
+ Returns:
+ dict or None: The cached or generated data, or None if data generation fails.
+
+ The `cached_data_for_file` function implements a caching mechanism for data stored in files.
+ It checks if the data associated with the given `title` is present in the cache and compares the
+ modification time of the file with the cached modification time. If the file has been modified,
+ the cache is considered invalid and the data is regenerated using the provided `func`.
+ Otherwise, the cached data is returned.
+
+ If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
+ or cached data is returned as a dictionary.
+ """
+
+ existing_cache = cache(subsection)
+ ondisk_mtime = os.path.getmtime(filename)
+
+ entry = existing_cache.get(title)
+ if entry:
+ cached_mtime = existing_cache[title].get("mtime", 0)
+ if ondisk_mtime > cached_mtime:
+ entry = None
+
+ if not entry:
+ entry = func()
+ if entry is None:
+ return None
+
+ entry['mtime'] = ondisk_mtime
+ existing_cache[title] = entry
+
+ dump_cache()
+
+ return entry
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 3b94f8a4..61aa240f 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -85,9 +85,9 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
- elapsed_text = f"{elapsed_s:.2f}s"
+ elapsed_text = f"{elapsed_s:.1f} sec."
if elapsed_m > 0:
- elapsed_text = f"{elapsed_m}m "+elapsed_text
+ elapsed_text = f"{elapsed_m} min. "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
@@ -95,14 +95,22 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
- sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+ sys_pct = sys_peak/max(sys_total, 1) * 100
- vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+ toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
+ toltip_r = "Reserved: total amout of video memory allocated by the Torch library "
+ toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity"
+
+ text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
+ text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
+ text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
+
+ vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
else:
vram_html = ''
# last item is always HTML
- res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
return tuple(res)
diff --git a/modules/extensions.py b/modules/extensions.py
index abc6e2b1..c561159a 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,7 +1,7 @@
import os
import threading
-from modules import shared, errors
+from modules import shared, errors, cache
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -21,6 +21,7 @@ def active():
class Extension:
lock = threading.Lock()
+ cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
@@ -36,15 +37,29 @@ class Extension:
self.remote = None
self.have_info_from_repo = False
+ def to_dict(self):
+ return {x: getattr(self, x) for x in self.cached_fields}
+
+ def from_dict(self, d):
+ for field in self.cached_fields:
+ setattr(self, field, d[field])
+
def read_info_from_repo(self):
if self.is_builtin or self.have_info_from_repo:
return
- with self.lock:
- if self.have_info_from_repo:
- return
+ def read_from_repo():
+ with self.lock:
+ if self.have_info_from_repo:
+ return
+
+ self.do_read_info_from_repo()
+
+ return self.to_dict()
- self.do_read_info_from_repo()
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
+ self.from_dict(d)
+ self.status = 'unknown'
def do_read_info_from_repo(self):
repo = None
@@ -58,7 +73,6 @@ class Extension:
self.remote = None
else:
try:
- self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
commit = repo.head.commit
self.commit_date = commit.committed_date
diff --git a/modules/hashes.py b/modules/hashes.py
index ec1187fe..b7a33b42 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -1,43 +1,11 @@
import hashlib
-import json
import os.path
-import filelock
-
from modules import shared
-from modules.paths import data_path, script_path
-
-
-cache_filename = os.path.join(data_path, "cache.json")
-cache_data = None
-
-
-def dump_cache():
- with filelock.FileLock(f"{cache_filename}.lock"):
- with open(cache_filename, "w", encoding="utf8") as file:
- json.dump(cache_data, file, indent=4)
-
-
-def cache(subsection):
- global cache_data
-
- if cache_data is None:
- with filelock.FileLock(f"{cache_filename}.lock"):
- if not os.path.isfile(cache_filename):
- cache_data = {}
- else:
- try:
- with open(cache_filename, "r", encoding="utf8") as file:
- cache_data = json.load(file)
- except Exception:
- os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
- print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
- cache_data = {}
-
- s = cache_data.get(subsection, {})
- cache_data[subsection] = s
+import modules.cache
- return s
+dump_cache = modules.cache.dump_cache
+cache = modules.cache.cache
def calculate_sha256(filename):
diff --git a/modules/img2img.py b/modules/img2img.py
index 664e2688..a811e7a4 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -240,4 +240,4 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 0e0dbca4..ff77cbfd 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -69,10 +69,12 @@ def git_tag():
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
try:
- from pathlib import Path
- changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
- with changelog_md.open(encoding="utf-8") as file:
- return next((line.strip() for line in file if line.strip()), "<none>")
+
+ changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md")
+ with open(changelog_md, "r", encoding="utf-8") as file:
+ line = next((line.strip() for line in file if line.strip()), "<none>")
+ line = line.replace("## ", "")
+ return line
except Exception:
return "<none>"
diff --git a/modules/processing.py b/modules/processing.py
index cd568a20..49441e77 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -732,9 +732,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds()
- if len(model_hijack.comments) > 0:
- for comment in model_hijack.comments:
- comments[comment] = 1
+ for comment in model_hijack.comments:
+ comments[comment] = 1
+
+ p.extra_generation_params.update(model_hijack.extra_generation_params)
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 3b6f95ce..6b5aae4b 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -147,7 +147,6 @@ def undo_weighted_forward(sd_model):
class StableDiffusionModelHijack:
fixes = None
- comments = []
layers = None
circular_enabled = False
clip = None
@@ -156,6 +155,9 @@ class StableDiffusionModelHijack:
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
def __init__(self):
+ self.extra_generation_params = {}
+ self.comments = []
+
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
@@ -236,6 +238,7 @@ class StableDiffusionModelHijack:
def clear_comments(self):
self.comments = []
+ self.extra_generation_params = {}
def get_prompt_lengths(self, text):
if self.clip is None:
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 3b5a7666..c1d780a3 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -229,9 +229,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.process_tokens(tokens, multipliers)
zs.append(z)
- if len(used_embeddings) > 0:
- embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
- self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+ if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
+ hashes = []
+ for name, embedding in used_embeddings.items():
+ shorthash = embedding.shorthash
+ if not shorthash:
+ continue
+
+ name = name.replace(":", "").replace(",", "")
+ hashes.append(f"{name}: {shorthash}")
+
+ if hashes:
+ self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
return torch.hstack(zs)
diff --git a/modules/shared.py b/modules/shared.py
index 48478a68..427dcc50 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -472,6 +472,8 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
+ "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
+ "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
}))
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cbe975b7..6166c76f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -13,7 +13,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -49,6 +49,8 @@ class Embedding:
self.sd_checkpoint_name = None
self.optimizer_state_dict = None
self.filename = None
+ self.hash = None
+ self.shorthash = None
def save(self, filename):
embedding_data = {
@@ -82,6 +84,10 @@ class Embedding:
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
class DirWithTextualInversionEmbeddings:
def __init__(self, path):
@@ -199,6 +205,7 @@ class EmbeddingDatabase:
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
embedding.filename = path
+ embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
@@ -249,7 +256,7 @@ class EmbeddingDatabase:
self.word_embeddings.update(sorted_word_embeddings)
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
- if self.previously_displayed_embeddings != displayed_embeddings:
+ if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if self.skipped_embeddings:
diff --git a/modules/txt2img.py b/modules/txt2img.py
index d0be2e73..29d94e8c 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -70,4 +70,4 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
diff --git a/modules/ui.py b/modules/ui.py
index 39d226ad..07ecee7b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -83,8 +83,7 @@ detect_image_size_symbol = '\U0001F4D0' # 📐
up_down_symbol = '\u2195\ufe0f' # ↕️
-def plaintext_to_html(text):
- return ui_common.plaintext_to_html(text)
+plaintext_to_html = ui_common.plaintext_to_html
def send_gradio_gallery_to_image(x):
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 57c2d0ad..11eb2a4b 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -29,9 +29,10 @@ def update_generation_info(generation_info, html_info, img_index):
return html_info, gr.update()
-def plaintext_to_html(text):
- text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
- return text
+def plaintext_to_html(text, classname=None):
+ content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
+
+ return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
def save_files(js_data, images, do_make_zip, index):
@@ -157,7 +158,7 @@ Requested path was: {f}
with gr.Group():
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
if tabname == 'txt2img' or tabname == 'img2img':
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index dff522ef..f3e4fba7 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,5 +1,5 @@
import json
-import os.path
+import os
import threading
import time
from datetime import datetime
@@ -513,14 +513,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
def preload_extensions_git_metadata():
- t0 = time.time()
for extension in extensions.extensions:
extension.read_info_from_repo()
- print(
- f"preload_extensions_git_metadata for "
- f"{len(extensions.extensions)} extensions took "
- f"{time.time() - t0:.2f}s"
- )
def create_ui():
@@ -570,7 +564,8 @@ def create_ui():
with gr.TabItem("Available", id="available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
- available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
+ extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
+ available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 693cafb6..760fba43 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -2,12 +2,13 @@ import os.path
import urllib.parse
from pathlib import Path
-from modules import shared
+from modules import shared, ui_extra_networks_user_metadata, errors
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr
import json
import html
+from fastapi.exceptions import HTTPException
from modules.generation_parameters_copypaste import image_from_url_text
@@ -26,6 +27,9 @@ def register_page(page):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
+ if not os.path.isfile(filename):
+ raise HTTPException(status_code=404, detail="File not found")
+
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
@@ -51,22 +55,66 @@ def get_metadata(page: str = "", item: str = ""):
return JSONResponse({"metadata": metadata})
+def get_single_card(page: str = "", tabname: str = "", name: str = ""):
+ from starlette.responses import JSONResponse
+
+ page = next(iter([x for x in extra_pages if x.name == page]), None)
+
+ try:
+ item = page.create_item(name)
+ except Exception as e:
+ errors.display(e, "creating item for extra network")
+ item = page.items.get(name)
+
+ item_html = page.create_html_for_item(item, tabname)
+
+ return JSONResponse({"html": item_html})
+
+
def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
+ app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
+
+
+def quote_js(s):
+ s = s.replace('\\', '\\\\')
+ s = s.replace('"', '\\"')
+ return f'"{s}"'
class ExtraNetworksPage:
def __init__(self, title):
self.title = title
self.name = title.lower()
+ self.id_page = self.name.replace(" ", "_")
self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False
self.metadata = {}
+ self.items = {}
def refresh(self):
pass
+ def read_user_metadata(self, item):
+ filename = item.get("filename", None)
+ basename, ext = os.path.splitext(filename)
+ metadata_filename = basename + '.json'
+
+ metadata = {}
+ try:
+ if os.path.isfile(metadata_filename):
+ with open(metadata_filename, "r", encoding="utf8") as file:
+ metadata = json.load(file)
+ except Exception as e:
+ errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+ desc = metadata.get("description", None)
+ if desc is not None:
+ item["description"] = desc
+
+ item["user_metadata"] = metadata
+
def link_preview(self, filename):
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
mtime = os.path.getmtime(filename)
@@ -119,11 +167,15 @@ class ExtraNetworksPage:
</button>
""" for subdir in subdirs])
- for item in self.list_items():
+ self.items = {x["name"]: x for x in self.list_items()}
+ for item in self.items.values():
metadata = item.get("metadata")
if metadata:
self.metadata[item["name"]] = metadata
+ if "user_metadata" not in item:
+ self.read_user_metadata(item)
+
items_html += self.create_html_for_item(item, tabname)
if items_html == '':
@@ -143,6 +195,9 @@ class ExtraNetworksPage:
return res
+ def create_item(self, name, index=None):
+ raise NotImplementedError()
+
def list_items(self):
raise NotImplementedError()
@@ -158,7 +213,7 @@ class ExtraNetworksPage:
onclick = item.get("onclick", None)
if onclick is None:
- onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
@@ -166,7 +221,9 @@ class ExtraNetworksPage:
metadata_button = ""
metadata = item.get("metadata")
if metadata:
- metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
+ metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
+
+ edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
local_path = ""
filename = item.get("filename", "")
@@ -192,14 +249,15 @@ class ExtraNetworksPage:
"background_image": background_image,
"style": f"'display: none; {height}{width}'",
"prompt": item.get("prompt", None),
- "tabname": json.dumps(tabname),
- "local_preview": json.dumps(item["local_preview"]),
+ "tabname": quote_js(tabname),
+ "local_preview": quote_js(item["local_preview"]),
"name": item["name"],
"description": (item.get("description") or ""),
"card_clicked": onclick,
- "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
+ "edit_button": edit_button,
"search_only": " search_only" if search_only else "",
"sort_keys": sort_keys,
}
@@ -247,6 +305,9 @@ class ExtraNetworksPage:
pass
return None
+ def create_user_metadata_editor(self, ui, tabname):
+ return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)
+
def initialize():
extra_pages.clear()
@@ -297,19 +358,22 @@ def create_ui(container, button, tabname):
ui = ExtraNetworksUi()
ui.pages = []
ui.pages_contents = []
+ ui.user_metadata_editors = []
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs"):
for page in ui.stored_extra_pages:
- page_id = page.title.lower().replace(" ", "_")
-
- with gr.Tab(page.title, id=page_id):
- elem_id = f"{tabname}_{page_id}_cards_html"
+ with gr.Tab(page.title, id=page.id_page):
+ elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
- page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
+ page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
+
+ editor = page.create_user_metadata_editor(ui, tabname)
+ editor.create_ui()
+ ui.user_metadata_editors.append(editor)
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
@@ -363,6 +427,8 @@ def path_is_parent(parent_path, child_path):
def setup_ui(ui, gallery):
def save_preview(index, images, filename):
+ # this function is here for backwards compatibility and likely will be removed soon
+
if len(images) == 0:
print("There is no image in gallery to save as a preview.")
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
@@ -394,3 +460,7 @@ def setup_ui(ui, gallery):
outputs=[*ui.pages]
)
+ for editor in ui.user_metadata_editors:
+ editor.setup_ui(gallery)
+
+
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 8b9ab71b..e73b5b1f 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -1,8 +1,8 @@
import html
-import json
import os
from modules import shared, ui_extra_networks, sd_models
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -12,21 +12,24 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.refresh_checkpoints()
+ def create_item(self, name, index=None):
+ checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name)
+ path, ext = os.path.splitext(checkpoint.filename)
+ return {
+ "name": checkpoint.name_for_extra,
+ "filename": checkpoint.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
+ "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
+
+ }
+
def list_items(self):
- checkpoint: sd_models.CheckpointInfo
- for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()):
- path, ext = os.path.splitext(checkpoint.filename)
- yield {
- "name": checkpoint.name_for_extra,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
- "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
- "local_preview": f"{path}.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
-
- }
+ for index, name in enumerate(sd_models.checkpoints_list):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 7c19b532..e53ccb42 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -1,7 +1,7 @@
-import json
import os
from modules import shared, ui_extra_networks
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
@@ -11,21 +11,24 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.reload_hypernetworks()
+ def create_item(self, name, index=None):
+ full_path = shared.hypernetworks[name]
+ path, ext = os.path.splitext(full_path)
+
+ return {
+ "name": name,
+ "filename": full_path,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(path),
+ "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
+ }
+
def list_items(self):
- for index, (name, path) in enumerate(shared.hypernetworks.items()):
- path, ext = os.path.splitext(path)
-
- yield {
- "name": name,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(path),
- "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": f"{path}.preview.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
-
- }
+ for index, name in enumerate(shared.hypernetworks):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return [shared.cmd_opts.hypernetwork_dir]
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index 58a61c55..d1794e50 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -1,7 +1,7 @@
-import json
import os
from modules import ui_extra_networks, sd_hijack, shared
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
@@ -12,20 +12,24 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+ def create_item(self, name, index=None):
+ embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
+
+ path, ext = os.path.splitext(embedding.filename)
+ return {
+ "name": name,
+ "filename": embedding.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(embedding.filename),
+ "prompt": quote_js(embedding.name),
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
+ }
+
def list_items(self):
- for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()):
- path, ext = os.path.splitext(embedding.filename)
- yield {
- "name": embedding.name,
- "filename": embedding.filename,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(embedding.filename),
- "prompt": json.dumps(embedding.name),
- "local_preview": f"{path}.preview.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
-
- }
+ for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
new file mode 100644
index 00000000..01ff4e4b
--- /dev/null
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -0,0 +1,190 @@
+import datetime
+import html
+import json
+import os.path
+
+import gradio as gr
+
+from modules import generation_parameters_copypaste, images, sysinfo, errors
+
+
+class UserMetadataEditor:
+
+ def __init__(self, ui, tabname, page):
+ self.ui = ui
+ self.tabname = tabname
+ self.page = page
+ self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
+
+ self.box = None
+
+ self.edit_name_input = None
+ self.button_edit = None
+
+ self.edit_name = None
+ self.edit_description = None
+ self.edit_notes = None
+ self.html_filedata = None
+ self.html_preview = None
+ self.html_status = None
+
+ self.button_cancel = None
+ self.button_replace_preview = None
+ self.button_save = None
+
+ def get_user_metadata(self, name):
+ item = self.page.items.get(name, {})
+
+ user_metadata = item.get('user_metadata', None)
+ if user_metadata is None:
+ user_metadata = {}
+ item['user_metadata'] = user_metadata
+
+ return user_metadata
+
+ def create_default_editor_elems(self):
+ with gr.Row():
+ with gr.Column(scale=2):
+ self.edit_name = gr.HTML(elem_classes="extra-network-name")
+ self.edit_description = gr.Textbox(label="Description", lines=4)
+ self.html_filedata = gr.HTML()
+
+ with gr.Column(scale=1, min_width=0):
+ self.html_preview = gr.HTML()
+
+ def create_default_buttons(self):
+
+ with gr.Row(elem_classes="edit-user-metadata-buttons"):
+ self.button_cancel = gr.Button('Cancel')
+ self.button_replace_preview = gr.Button('Replace preview', variant='primary')
+ self.button_save = gr.Button('Save', variant='primary')
+
+ self.html_status = gr.HTML(elem_classes="edit-user-metadata-status")
+
+ self.button_cancel.click(fn=None, _js="closePopup")
+
+ def get_card_html(self, name):
+ item = self.page.items.get(name, {})
+
+ preview_url = item.get("preview", None)
+
+ if not preview_url:
+ filename, _ = os.path.splitext(item["filename"])
+ preview_url = self.page.find_preview(filename)
+ item["preview"] = preview_url
+
+ if preview_url:
+ preview = f'''
+ <div class='card standalone-card-preview'>
+ <img src="{html.escape(preview_url)}" class="preview">
+ </div>
+ '''
+ else:
+ preview = "<div class='card standalone-card-preview'></div>"
+
+ return preview
+
+ def get_metadata_table(self, name):
+ item = self.page.items.get(name, {})
+ try:
+ filename = item["filename"]
+
+ stats = os.stat(filename)
+ params = [
+ ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
+ ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
+ ]
+
+ return params
+ except Exception as e:
+ errors.display(e, f"reading info for {name}")
+ return []
+
+ def put_values_into_components(self, name):
+ user_metadata = self.get_user_metadata(name)
+
+ try:
+ params = self.get_metadata_table(name)
+ except Exception as e:
+ errors.display(e, f"reading metadata info for {name}")
+ params = []
+
+ table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
+
+ return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''),
+
+ def write_user_metadata(self, name, metadata):
+ item = self.page.items.get(name, {})
+ filename = item.get("filename", None)
+ basename, ext = os.path.splitext(filename)
+
+ with open(basename + '.json', "w", encoding="utf8") as file:
+ json.dump(metadata, file)
+
+ def save_user_metadata(self, name, desc, notes):
+ user_metadata = self.get_user_metadata(name)
+ user_metadata["description"] = desc
+ user_metadata["notes"] = notes
+
+ self.write_user_metadata(name, user_metadata)
+
+ def setup_save_handler(self, button, func, components):
+ button\
+ .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\
+ .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[])
+
+ def create_editor(self):
+ self.create_default_editor_elems()
+
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+ self.create_default_buttons()
+
+ self.button_edit\
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+ self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])
+
+ def create_ui(self):
+ with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box:
+ self.box = box
+
+ self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name")
+ self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button")
+
+ self.create_editor()
+
+ def save_preview(self, index, gallery, name):
+ if len(gallery) == 0:
+ return self.get_card_html(name), "There is no image in gallery to save as a preview."
+
+ item = self.page.items.get(name, {})
+
+ index = int(index)
+ index = 0 if index < 0 else index
+ index = len(gallery) - 1 if index >= len(gallery) else index
+
+ img_info = gallery[index if index >= 0 else 0]
+ image = generation_parameters_copypaste.image_from_url_text(img_info)
+ geninfo, items = images.read_info_from_image(image)
+
+ images.save_image_with_geninfo(image, geninfo, item["local_preview"])
+
+ return self.get_card_html(name), ''
+
+ def setup_ui(self, gallery):
+ self.button_replace_preview.click(
+ fn=self.save_preview,
+ _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
+ inputs=[self.edit_name_input, gallery, self.edit_name_input],
+ outputs=[self.html_preview, self.html_status]
+ ).then(
+ fn=None,
+ _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}",
+ inputs=[self.edit_name_input],
+ outputs=[]
+ )
+
+
+
diff --git a/style.css b/style.css
index 5073f0f0..4e22cfd6 100644
--- a/style.css
+++ b/style.css
@@ -227,20 +227,39 @@ button.custom-button{
align-self: end;
}
-.performance {
+.html-log .comments{
+ padding-top: 0.5em;
+}
+
+.html-log .comments:empty{
+ padding-top: 0;
+}
+
+.html-log .performance {
font-size: 0.85em;
color: #444;
+ display: flex;
}
-.performance p{
+.html-log .performance p{
display: inline-block;
}
-.performance .time {
- margin-right: 0;
+.html-log .performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr {
+ margin-bottom: 0;
+ color: var(--block-title-text-color);
+}
+
+.html-log .performance p.time {
+}
+
+.html-log .performance p.vram {
+ margin-left: auto;
}
-.performance .vram {
+.html-log .performance .measurement{
+ color: var(--body-text-color);
+ font-weight: bold;
}
#txt2img_generate, #img2img_generate {
@@ -531,6 +550,9 @@ table.popup-table .link{
background-color: rgba(20, 20, 20, 0.95);
}
+.global-popup *{
+ box-sizing: border-box;
+}
.global-popup-close:before {
content: "×";
@@ -796,32 +818,42 @@ footer {
}
-.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{
- content: "🛈";
-}
-.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
+.extra-network-cards .card .button-row, .extra-network-thumbs .card .button-row{
display: none;
position: absolute;
color: white;
right: 0;
}
-.extra-network-cards .card .metadata-button {
+.extra-network-cards .card:hover .button-row, .extra-network-thumbs .card:hover .button-row{
+ display: flex;
+}
+
+.extra-network-cards .card .card-button, .extra-network-thumbs .card .card-button{
+ color: white;
+}
+
+.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{
+ content: "🛈";
+}
+
+.extra-network-cards .card .edit-button:before, .extra-network-thumbs .card .edit-button:before{
+ content: "🛠";
+}
+
+.extra-network-cards .card .card-button {
text-shadow: 2px 2px 3px black;
padding: 0.25em;
font-size: 22pt;
width: 1.5em;
}
-.extra-network-thumbs .card .metadata-button {
+.extra-network-thumbs .card .card-button {
text-shadow: 1px 1px 2px black;
padding: 0;
font-size: 16pt;
width: 1em;
top: -0.25em;
}
-.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
- display: inline-block;
-}
-.extra-network-cards .card .metadata-button:hover, .extra-network-thumbs .card .metadata-button:hover{
+.extra-network-cards .card .card-button:hover, .extra-network-thumbs .card .card-button:hover{
color: red;
}
@@ -842,7 +874,7 @@ footer {
position: relative;
}
-.extra-network-thumbs .card .preview{
+.extra-network-thumbs .card .preview, .standalone-card-preview.card .preview{
position: absolute;
object-fit: cover;
width: 100%;
@@ -886,7 +918,7 @@ footer {
word-break: break-all;
}
-.extra-network-cards .card{
+.extra-network-cards .card, .standalone-card-preview.card{
display: inline-block;
margin: 0.5em;
width: 16em;
@@ -970,3 +1002,33 @@ footer {
width: 100%;
height:100%;
}
+
+div.block.gradio-box.edit-user-metadata {
+ width: 56em;
+ background: var(--body-background-fill);
+ padding: 2em !important;
+}
+
+.edit-user-metadata .extra-network-name{
+ font-size: 18pt;
+ color: var(--body-text-color);
+}
+
+.edit-user-metadata .file-metadata{
+ color: var(--body-text-color);
+}
+
+.edit-user-metadata .file-metadata th{
+ text-align: left;
+}
+
+.edit-user-metadata .wrap.translucent{
+ background: var(--body-background-fill);
+}
+.edit-user-metadata .gradio-highlightedtext span{
+ word-break: break-word;
+}
+
+.edit-user-metadata-buttons{
+ margin-top: 1.5em;
+}