aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/processing.py3
-rw-r--r--modules/sd_models.py24
-rw-r--r--modules/ui_extra_networks.py11
3 files changed, 35 insertions, 3 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 06e7a440..59717b4c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -583,6 +583,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1:
state.job_count = p.n_iter
+ extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@@ -712,7 +713,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- if not p.disable_extra_networks:
+ if not p.disable_extra_networks and extra_network_data:
extra_networks.deactivate(p, extra_network_data)
devices.torch_gc()
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 93959f55..5f57ec0c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -210,6 +210,30 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
+def read_metadata_from_safetensors(filename):
+ import json
+
+ with open(filename, mode="rb") as file:
+ metadata_len = file.read(8)
+ metadata_len = int.from_bytes(metadata_len, "little")
+ json_start = file.read(2)
+
+ assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
+ json_data = json_start + file.read(metadata_len-2)
+ json_obj = json.loads(json_data)
+
+ res = {}
+ for k, v in json_obj.get("__metadata__", {}).items():
+ res[k] = v
+ if isinstance(v, str) and v[0] == '{':
+ try:
+ res[k] = json.loads(v)
+ except Exception as e:
+ pass
+
+ return res
+
+
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 01df5e90..cdfd6f2a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -30,8 +30,8 @@ def add_pages_to_demo(app):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
+ if ext not in (".png", ".jpg", ".webp"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -124,6 +124,12 @@ class ExtraNetworksPage:
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ metadata_button = ""
+ metadata = item.get("metadata")
+ if metadata:
+ metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
+ metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
+
args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item.get("prompt", None),
@@ -134,6 +140,7 @@ class ExtraNetworksPage:
"card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
+ "metadata_button": metadata_button,
}
return self.card_page.format(**args)