aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/ui.py8
-rw-r--r--modules/ui_extra_networks.py37
-rw-r--r--modules/ui_extra_networks_checkpoints.py38
-rw-r--r--modules/ui_extra_networks_hypernets.py2
-rw-r--r--modules/ui_extra_networks_textual_inversion.py2
5 files changed, 81 insertions, 6 deletions
diff --git a/modules/ui.py b/modules/ui.py
index 4e082408..f1195692 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1560,6 +1560,14 @@ def create_ui():
outputs=[component, text_settings],
)
+ button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
+ button_set_checkpoint.click(
+ fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+ inputs=[component_dict['sd_model_checkpoint'], dummy_component],
+ outputs=[component_dict['sd_model_checkpoint'], text_settings],
+ )
+
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
def get_settings_values():
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index c6ff889a..5730c879 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,4 +1,6 @@
import os.path
+import urllib.parse
+from pathlib import Path
from modules import shared
import gradio as gr
@@ -8,12 +10,31 @@ import html
from modules.generation_parameters_copypaste import image_from_url_text
extra_pages = []
+allowed_dirs = set()
def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
extra_pages.append(page)
+ allowed_dirs.clear()
+ allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
+
+
+def add_pages_to_demo(app):
+ def fetch_file(filename: str = ""):
+ from starlette.responses import FileResponse
+
+ if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
+ raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
+
+ if os.path.splitext(filename)[1].lower() != ".png":
+ raise ValueError(f"File cannot be fetched: {filename}. Only png.")
+
+ # would profit from returning 304
+ return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+
+ app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
class ExtraNetworksPage:
@@ -26,6 +47,9 @@ class ExtraNetworksPage:
def refresh(self):
pass
+ def link_preview(self, filename):
+ return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
+
def create_html(self, tabname):
view = shared.opts.extra_networks_default_view
items_html = ''
@@ -54,13 +78,17 @@ class ExtraNetworksPage:
def create_html_for_item(self, item, tabname):
preview = item.get("preview", None)
+ onclick = item.get("onclick", None)
+ if onclick is None:
+ onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+
args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
- "prompt": item["prompt"],
+ "prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
- "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
+ "card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
}
@@ -143,7 +171,7 @@ def path_is_parent(parent_path, child_path):
parent_path = os.path.abspath(parent_path)
child_path = os.path.abspath(child_path)
- return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+ return child_path.startswith(parent_path)
def setup_ui(ui, gallery):
@@ -173,7 +201,8 @@ def setup_ui(ui, gallery):
ui.button_save_preview.click(
fn=save_preview,
- _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+ _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
outputs=[*ui.pages]
)
+
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
new file mode 100644
index 00000000..c66cb830
--- /dev/null
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -0,0 +1,38 @@
+import html
+import json
+import os
+import urllib.parse
+
+from modules import shared, ui_extra_networks, sd_models
+
+
+class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Checkpoints')
+
+ def refresh(self):
+ shared.refresh_checkpoints()
+
+ def list_items(self):
+ for name, checkpoint1 in sd_models.checkpoints_list.items():
+ checkpoint: sd_models.CheckpointInfo = checkpoint1
+ path, ext = os.path.splitext(checkpoint.filename)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = self.link_preview(file)
+ break
+
+ yield {
+ "name": checkpoint.model_name,
+ "filename": path,
+ "preview": preview,
+ "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.ckpt_dir, sd_models.model_path]
+
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 65d000cf..8c15f8eb 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -19,7 +19,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
preview = None
for file in previews:
if os.path.isfile(file):
- preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ preview = self.link_preview(file)
break
yield {
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index dbd23d2d..a9d3064b 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -19,7 +19,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
preview = None
if os.path.isfile(preview_file):
- preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+ preview = self.link_preview(preview_file)
yield {
"name": embedding.name,