aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py88
-rw-r--r--modules/api/models.py4
-rw-r--r--modules/call_queue.py3
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/codeformer_model.py4
-rw-r--r--modules/devices.py7
-rw-r--r--modules/esrgan_model.py23
-rw-r--r--modules/generation_parameters_copypaste.py29
-rw-r--r--modules/gfpgan_model.py2
-rw-r--r--modules/hypernetworks/hypernetwork.py24
-rw-r--r--modules/images.py28
-rw-r--r--modules/img2img.py5
-rw-r--r--modules/launch_utils.py4
-rw-r--r--modules/mac_specific.py23
-rw-r--r--modules/modelloader.py31
-rw-r--r--modules/paths.py14
-rw-r--r--modules/processing.py16
-rw-r--r--modules/realesrgan_model.py33
-rw-r--r--modules/scripts.py3
-rw-r--r--modules/sd_models.py7
-rw-r--r--modules/shared.py2
-rw-r--r--modules/textual_inversion/logging.py48
-rw-r--r--modules/txt2img.py6
-rw-r--r--modules/ui.py6
-rw-r--r--modules/ui_extensions.py11
25 files changed, 230 insertions, 193 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 2e49526e..41adaef7 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases
from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
@@ -30,13 +30,7 @@ from modules import devices
from typing import Dict, List, Any
import piexif
import piexif.helper
-
-
-def upscaler_to_index(name: str):
- try:
- return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
- except Exception as e:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
+from contextlib import closing
def script_name_to_index(name, scripts):
@@ -84,6 +78,8 @@ def encode_pil_to_base64(image):
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+ if image.mode == "RGBA":
+ image = image.convert("RGB")
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
@@ -209,6 +205,11 @@ class Api:
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+ if shared.cmd_opts.add_stop_route:
+ self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
+
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -324,19 +325,19 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
-
- shared.state.begin()
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+
+ shared.state.begin()
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -380,20 +381,21 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
- p.init_images = [decode_base64_to_image(x) for x in init_images]
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
+ with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+
+ shared.state.begin()
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
- shared.state.begin()
- if selectable_scripts is not None:
- p.script_args = script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -517,6 +519,10 @@ class Api:
return options
def set_config(self, req: Dict[str, Any]):
+ checkpoint_name = req.get("sd_model_checkpoint", None)
+ if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases:
+ raise RuntimeError(f"model {checkpoint_name!r} not found")
+
for k, v in req.items():
shared.opts.set(k, v)
@@ -715,3 +721,15 @@ class Api:
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
+
+ def kill_webui(self):
+ restart.stop_program()
+
+ def restart_webui(self):
+ if restart.is_restartable():
+ restart.restart_program()
+ return Response(status_code=501)
+
+ def stop_webui(request):
+ shared.state.server_command = "stop"
+ return Response("Stopping.")
diff --git a/modules/api/models.py b/modules/api/models.py
index b3a745f0..b5683071 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
-class ArtistItem(BaseModel):
- name: str = Field(title="Name")
- score: float = Field(title="Score")
- category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 1b5e5273..69bf63d2 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,3 +1,4 @@
+from functools import wraps
import html
import threading
import time
@@ -18,6 +19,7 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
+ @wraps(func)
def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id
@@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+ @wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index de905caa..624dcb4f 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
-parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api')
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index d974e4b8..f293acf5 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -15,7 +15,6 @@ model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-have_codeformer = False
codeformer = None
@@ -123,9 +122,6 @@ def setup_model(dirname):
return restored_img
- global have_codeformer
- have_codeformer = True
-
global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
diff --git a/modules/devices.py b/modules/devices.py
index 1ed6ffdc..620ed1a6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -15,13 +15,6 @@ def has_mps() -> bool:
else:
return mac_specific.has_mps
-def extract_device_id(args, name):
- for x in range(len(args)):
- if name in args[x]:
- return args[x + 1]
-
- return None
-
def get_cuda_device_string():
from modules import shared
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 2fced999..02a1727d 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,15 +1,13 @@
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index dd30a1b5..a3448be9 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
return img, w, h
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
- """
- hypernet_name = hypernet_name.lower()
- if hypernet_hash is not None:
- # Try to match the hash in the name
- for hypernet_key in shared.hypernetworks.keys():
- result = re_hypernet_hash.search(hypernet_key)
- if result is not None and result[1] == hypernet_hash:
- return hypernet_key
- else:
- # Fall back to a hypernet with the same name
- for hypernet_key in shared.hypernetworks.keys():
- if hypernet_key.lower().startswith(hypernet_name):
- return hypernet_key
-
- return None
-
-
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
@@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-settings_map = {}
-
-
-
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'),
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 6ecd295c..8e0f13bd 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -25,7 +25,7 @@ def gfpgann():
return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
+ if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 5d12b449..51941c11 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
shared.loaded_hypernetworks.append(hypernetwork)
-def find_closest_hypernetwork_name(search: str):
- if not search:
- return None
- search = search.lower()
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
- if not applicable:
- return None
- applicable = sorted(applicable, key=lambda name: len(name))
- return applicable[0]
-
-
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
@@ -446,18 +435,6 @@ def statistics(data):
return total_information, recent_information
-def report_statistics(loss_info:dict):
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
- for key in keys:
- try:
- print("Loss statistics for file " + key)
- info, recent = statistics(list(loss_info[key]))
- print(info)
- print(recent)
- except Exception as e:
- print(e)
-
-
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.leave = False
pbar.close()
hypernetwork.eval()
- #report_statistics(loss_dict)
sd_hijack_checkpoint.remove()
diff --git a/modules/images.py b/modules/images.py
index 7bbfc3e0..cdfa6719 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -372,8 +372,8 @@ class FilenameGenerator:
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
-
}
default_time_format = '%Y%m%d%H%M%S'
@@ -497,13 +497,23 @@ def get_next_sequence_number(path, basename):
return result + 1
-def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
+ """
+ Saves image to filename, including geninfo as text information for generation info.
+ For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
+ For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
+ """
+
if extension is None:
extension = os.path.splitext(filename)[1]
image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png':
+ existing_pnginfo = existing_pnginfo or {}
+ if opts.enable_pnginfo:
+ existing_pnginfo[pnginfo_section_name] = geninfo
+
if opts.enable_pnginfo:
pnginfo_data = PngImagePlugin.PngInfo()
for k, v in (existing_pnginfo or {}).items():
@@ -622,7 +632,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
"""
temp_file_path = f"{filename_without_extension}.tmp"
- save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
os.replace(temp_file_path, filename_without_extension + extension)
@@ -639,12 +649,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height
-
+ resize_to = None
if oversize and ratio > 1:
- image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
+ resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
elif oversize:
- image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
+ resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
+ if resize_to is not None:
+ try:
+ # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
+ image = image.resize(resize_to, LANCZOS)
+ except:
+ image = image.resize(resize_to)
try:
_atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e:
diff --git a/modules/img2img.py b/modules/img2img.py
index 2c497020..b07d7f2f 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -3,6 +3,7 @@ from pathlib import Path
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import gradio as gr
from modules import sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict
@@ -97,7 +98,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -180,6 +181,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.scripts = modules.scripts.scripts_img2img
p.script_args = args
+ p.user = request.username
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 609a181e..97539e68 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -147,10 +147,10 @@ def git_clone(url, dir, name, commithash=None):
return
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return
- run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index d74c6b95..735847f5 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -4,16 +4,21 @@ from modules.sd_hijack_utils import CondFunc
from packaging import version
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
+# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
+# use check `getattr` and try it for compatibility.
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
+# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
- try:
- torch.zeros(1).to(torch.device("mps"))
- return True
- except Exception:
- return False
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+ else:
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_mps = check_for_mps()
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 75f01247..098bcb79 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import shutil
import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+def load_file_from_url(
+ url: str,
+ *,
+ model_dir: str,
+ progress: bool = True,
+ file_name: str | None = None,
+) -> str:
+ """Download a file from `url` into `model_dir`, using the file present if possible.
+
+ Returns the path to the downloaded file.
+ """
+ os.makedirs(model_dir, exist_ok=True)
+ if not file_name:
+ parts = urlparse(url)
+ file_name = os.path.basename(parts.path)
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ from torch.hub import download_url_to_file
+ download_url_to_file(url, cached_file, progress=progress)
+ return cached_file
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
- from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, places[0], True, download_name)
- output.append(dl)
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str):
- if "http" in file:
+ if file.startswith("http"):
file = urlparse(file).path
file = os.path.basename(file)
diff --git a/modules/paths.py b/modules/paths.py
index 5171df4f..bada804e 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d
-
-
-class Prioritize:
- def __init__(self, name):
- self.name = name
- self.path = None
-
- def __enter__(self):
- self.path = sys.path.copy()
- sys.path = [paths[self.name]] + sys.path
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.path = self.path
- self.path = None
diff --git a/modules/processing.py b/modules/processing.py
index 8da73884..85e104aa 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -184,6 +184,8 @@ class StableDiffusionProcessing:
self.uc = None
self.c = None
+ self.user = None
+
@property
def sd_model(self):
return shared.sd_model
@@ -549,7 +551,7 @@ def program_version():
return res
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
+ "User": p.user if opts.add_user_name_to_info else None,
}
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
- return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+ return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
- def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
- text = infotext()
+ text = infotext(use_main_prompt=True)
infotexts.insert(0, text)
if opts.enable_pnginfo:
grid.info["parameters"] = text
@@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, p.extra_network_data)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 2d27b321..0700b853 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -2,7 +2,6 @@ import os
import numpy as np
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable:
return img
- info = self.load_model(path)
- if not os.path.exists(info.local_data_path):
- print(f"Unable to load RealESRGAN model: {info.name}")
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image
def load_model(self, path):
- try:
- info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
- if info is None:
- print(f"Unable to find model info: {path}")
- return None
-
- if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
- return info
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
- return None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
def load_models(self, _):
return get_realesrgan_models(self)
diff --git a/modules/scripts.py b/modules/scripts.py
index 99bf836a..49e4a611 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,6 +1,7 @@
import os
import re
import sys
+import inspect
from collections import namedtuple
import gradio as gr
@@ -249,7 +250,7 @@ def load_scripts():
def register_scripts_from_module(module):
for script_class in module.__dict__.values():
- if type(script_class) != type:
+ if not inspect.isclass(script_class):
continue
if issubclass(script_class, Script):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6ff5d17d..f65f4e36 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -247,7 +247,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+ if not shared.opts.disable_mmap_load_safetensors:
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+ else:
+ pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+ pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
diff --git a/modules/shared.py b/modules/shared.py
index a0862055..203ee1b9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -376,6 +376,7 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
+ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -493,6 +494,7 @@ options_templates.update(options_section(('ui', "User interface"), {
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+ "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
index 734a4b6f..45823eb1 100644
--- a/modules/textual_inversion/logging.py
+++ b/modules/textual_inversion/logging.py
@@ -2,11 +2,51 @@ import datetime
import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_shared = {
+ "batch_size",
+ "clip_grad_mode",
+ "clip_grad_value",
+ "create_image_every",
+ "data_root",
+ "gradient_step",
+ "initial_step",
+ "latent_sampling_method",
+ "learn_rate",
+ "log_directory",
+ "model_hash",
+ "model_name",
+ "num_of_dataset_images",
+ "steps",
+ "template_file",
+ "training_height",
+ "training_width",
+}
+saved_params_ti = {
+ "embedding_name",
+ "num_vectors_per_token",
+ "save_embedding_every",
+ "save_image_with_stored_embedding",
+}
+saved_params_hypernet = {
+ "activation_func",
+ "add_layer_norm",
+ "hypernetwork_name",
+ "layer_structure",
+ "save_hypernetwork_every",
+ "use_dropout",
+ "weight_init",
+}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+saved_params_previews = {
+ "preview_cfg_scale",
+ "preview_height",
+ "preview_negative_prompt",
+ "preview_prompt",
+ "preview_sampler_index",
+ "preview_seed",
+ "preview_steps",
+ "preview_width",
+}
def save_settings_to_file(log_directory, all_params):
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 2e7d202d..6aa79f23 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic
from modules.shared import opts, cmd_opts
import modules.shared as shared
from modules.ui import plaintext_to_html
+import gradio as gr
-
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img(
@@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
+ p.user = request.username
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/ui.py b/modules/ui.py
index e2e3b6da..d4dc7910 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
img = Image.open(image)
filename = os.path.basename(image)
left, _ = os.path.splitext(filename)
- print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
return [gr.update(), None]
@@ -773,7 +773,7 @@ def create_ui():
selected_scale_tab = gr.State(value=0)
with gr.Tabs():
- with gr.Tab(label="Resize to") as tab_scale_to:
+ with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@@ -782,7 +782,7 @@ def create_ui():
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
- with gr.Tab(label="Resize by") as tab_scale_by:
+ with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
with FormRow():
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index c7e0a866..278bf5e4 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -138,7 +138,10 @@ def extension_table():
<table id="extensions">
<thead>
<tr>
- <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+ <th>
+ <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
+ <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
+ </th>
<th>URL</th>
<th>Branch</th>
<th>Version</th>
@@ -170,7 +173,7 @@ def extension_table():
code += f"""
<tr>
- <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+ <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
<td>{remote}</td>
<td>{ext.branch}</td>
<td>{version_link}</td>
@@ -568,9 +571,9 @@ def create_ui():
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags, sort_column],
- outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
)
install_extension_button.click(