aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/cmd_args.py1
-rw-r--r--modules/config_states.py3
-rw-r--r--modules/images.py4
-rw-r--r--modules/initialize.py4
-rw-r--r--modules/launch_utils.py4
-rw-r--r--modules/options.py3
-rw-r--r--modules/paths.py2
-rw-r--r--modules/paths_internal.py1
-rw-r--r--modules/restart.py4
-rw-r--r--modules/sd_hijack.py17
-rw-r--r--modules/sd_models.py24
-rw-r--r--modules/sd_unet.py4
-rw-r--r--modules/shared_state.py2
-rw-r--r--modules/ui_extensions.py2
-rw-r--r--modules/ui_extra_networks.py6
-rw-r--r--modules/ui_gradio_extensions.py6
16 files changed, 61 insertions, 26 deletions
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 5be879dd..4e602a84 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -117,3 +117,4 @@ parser.add_argument('--api-server-stop', action='store_true', help='enable serve
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
+parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
diff --git a/modules/config_states.py b/modules/config_states.py
index b766aef1..651793c7 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -4,7 +4,6 @@ Supports saving and restoring webui and extensions from a known working set of c
import os
import json
-import time
import tqdm
from datetime import datetime
@@ -38,7 +37,7 @@ def list_config_states():
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states:
- timestamp = time.asctime(time.gmtime(cs["created_at"]))
+ timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
name = cs.get("name", "Config")
full_name = f"{name}: {timestamp}"
all_config_states[full_name] = cs
diff --git a/modules/images.py b/modules/images.py
index 592dfacf..daf4eebe 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -561,6 +561,8 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
})
piexif.insert(exif_bytes, filename)
+ elif extension.lower() == ".gif":
+ image.save(filename, format=image_format, comment=geninfo)
else:
image.save(filename, format=image_format, quality=opts.jpeg_quality)
@@ -739,6 +741,8 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
if exif_comment:
items['exif comment'] = exif_comment
geninfo = exif_comment
+ elif "comment" in items: # for gif
+ geninfo = items["comment"].decode('utf8', errors="ignore")
for field in IGNORED_INFO_KEYS:
items.pop(field, None)
diff --git a/modules/initialize.py b/modules/initialize.py
index f24f7637..ac95fc6f 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -151,8 +151,8 @@ def initialize_rest(*, reload_script_modules=False):
from modules import devices
devices.first_time_calculation()
-
- Thread(target=load_model).start()
+ if not shared.cmd_opts.skip_load_model_at_start:
+ Thread(target=load_model).start()
from modules import shared_items
shared_items.reload_hypernetworks()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 6e54d063..8cdbafa5 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -64,7 +64,7 @@ Use --skip-python-version-check to suppress this warning.
@lru_cache()
def commit_hash():
try:
- return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
+ return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
except Exception:
return "<none>"
@@ -72,7 +72,7 @@ def commit_hash():
@lru_cache()
def git_tag():
try:
- return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
+ return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
try:
diff --git a/modules/options.py b/modules/options.py
index e75916d2..ab40aff7 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -210,7 +210,8 @@ class Options:
def add_option(self, key, info):
self.data_labels[key] = info
- self.data[key] = info.default
+ if key not in self.data:
+ self.data[key] = info.default
def reorder(self):
"""reorder settings so that all items related to section always go together"""
diff --git a/modules/paths.py b/modules/paths.py
index 25052339..187b9496 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,6 +1,6 @@
import os
import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
import modules.safe # noqa: F401
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 005a9b0a..89131a54 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -8,6 +8,7 @@ import shlex
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)
+cwd = os.getcwd()
modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)
diff --git a/modules/restart.py b/modules/restart.py
index 18eacaf3..2dd6493b 100644
--- a/modules/restart.py
+++ b/modules/restart.py
@@ -14,7 +14,9 @@ def is_restartable() -> bool:
def restart_program() -> None:
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
- (Path(script_path) / "tmp" / "restart").touch()
+ tmpdir = Path(script_path) / "tmp"
+ tmpdir.mkdir(parents=True, exist_ok=True)
+ (tmpdir / "restart").touch()
stop_program()
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 592f0055..22a1eb5c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,7 +2,7 @@ import torch
from torch.nn.functional import silu
from types import MethodType
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -10,6 +10,7 @@ from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hija
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
+import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
@@ -37,6 +38,8 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
+ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
+sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
@@ -239,10 +242,13 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
- if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
- ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+ if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
+ sd_unet.original_forward = sgm_original_forward
+ else:
+ sd_unet.original_forward = None
- ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
@@ -279,7 +285,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
- ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+ sd_unet.original_forward = None
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee..5ef7aa13 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,16 +7,17 @@ import threading
import torch
import re
import safetensors.torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
+import numpy as np
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -129,9 +130,12 @@ except Exception:
def setup_model():
+ """called once at startup to do various one-time tasks related to SD models"""
+
os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload()
+ patch_given_betas()
def checkpoint_tiles(use_short=False):
@@ -309,6 +313,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if checkpoint_info in checkpoints_loaded:
# use checkpoint cache
print(f"Loading weights [{sd_model_hash}] from cache")
+ # move to end as latest
+ checkpoints_loaded.move_to_end(checkpoint_info)
return checkpoints_loaded[checkpoint_info]
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
@@ -453,6 +459,20 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+def patch_given_betas():
+ import ldm.models.diffusion.ddpm
+
+ def patched_register_schedule(*args, **kwargs):
+ """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
+
+ if isinstance(args[1], ListConfig):
+ args = (args[0], np.array(args[1]), *args[2:])
+
+ original_register_schedule(*args, **kwargs)
+
+ original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
+
+
def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"):
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
index 5525cfbc..6a7bc9e2 100644
--- a/modules/sd_unet.py
+++ b/modules/sd_unet.py
@@ -1,11 +1,11 @@
import torch.nn
-import ldm.modules.diffusionmodules.openaimodel
from modules import script_callbacks, shared, devices
unet_options = []
current_unet_option = None
current_unet = None
+original_forward = None
def list_unets():
@@ -88,5 +88,5 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
- return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+ return original_forward(self, x, timesteps, context, *args, **kwargs)
diff --git a/modules/shared_state.py b/modules/shared_state.py
index d272ee5b..a68789cc 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -103,6 +103,7 @@ class State:
def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
+ self.time_start = time.time()
self.job_count = -1
self.processing_has_refined_job_count = False
self.job_no = 0
@@ -114,7 +115,6 @@ class State:
self.skipped = False
self.interrupted = False
self.textinfo = None
- self.time_start = time.time()
self.job = job
devices.torch_gc()
log.info("Starting job %s", job)
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 2e8c1d6d..c0a73b57 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -197,7 +197,7 @@ def update_config_states_table(state_name):
config_state = config_states.all_config_states[state_name]
config_name = config_state.get("name", "Config")
- created_date = time.asctime(time.gmtime(config_state["created_at"]))
+ created_date = datetime.fromtimestamp(config_state["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
filepath = config_state.get("filepath", "<unknown>")
try:
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 21eed6a1..3eee371b 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -213,9 +213,9 @@ class ExtraNetworksPage:
metadata_button = ""
metadata = item.get("metadata")
if metadata:
- metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
+ metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
- edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
+ edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
local_path = ""
filename = item.get("filename", "")
@@ -235,7 +235,7 @@ class ExtraNetworksPage:
if search_only and shared.opts.extra_networks_hidden_models == "Never":
return ""
- sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
+ sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
args = {
"background_image": background_image,
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
index b824b113..0d368f8b 100644
--- a/modules/ui_gradio_extensions.py
+++ b/modules/ui_gradio_extensions.py
@@ -2,12 +2,12 @@ import os
import gradio as gr
from modules import localization, shared, scripts
-from modules.paths import script_path, data_path
+from modules.paths import script_path, data_path, cwd
def webpath(fn):
- if fn.startswith(script_path):
- web_path = os.path.relpath(fn, script_path).replace('\\', '/')
+ if fn.startswith(cwd):
+ web_path = os.path.relpath(fn, cwd)
else:
web_path = os.path.abspath(fn)