aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG.md25
-rw-r--r--README.md3
-rw-r--r--extensions-builtin/Lora/network.py1
-rw-r--r--extensions-builtin/Lora/networks.py5
-rw-r--r--javascript/badScaleChecker.js108
-rw-r--r--modules/api/api.py6
-rw-r--r--modules/api/models.py6
-rw-r--r--modules/cmd_args.py1
-rw-r--r--modules/extensions.py10
-rw-r--r--modules/img2img.py2
-rw-r--r--modules/processing.py41
-rw-r--r--modules/scripts.py32
-rw-r--r--modules/sd_disable_initialization.py106
-rw-r--r--modules/sd_hijack_clip.py9
-rw-r--r--modules/sd_models.py16
-rw-r--r--modules/sd_samplers_extra.py74
-rw-r--r--modules/sd_samplers_kdiffusion.py8
-rw-r--r--modules/ui_extra_networks.py2
-rw-r--r--webui.py6
19 files changed, 306 insertions, 155 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 63a2c7d3..b18c6867 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,28 @@
+## 1.5.1
+
+### Minor:
+ * support parsing text encoder blocks in some new LoRAs
+ * delete scale checker script due to user demand
+
+### Extensions and API:
+ * add postprocess_batch_list script callback
+
+### Bug Fixes:
+ * fix TI training for SD1
+ * fix reload altclip model error
+ * prepend the pythonpath instead of overriding it
+ * fix typo in SD_WEBUI_RESTARTING
+ * if txt2img/img2img raises an exception, finally call state.end()
+ * fix composable diffusion weight parsing
+ * restyle Startup profile for black users
+ * fix webui not launching with --nowebui
+ * catch exception for non git extensions
+ * fix some options missing from /sdapi/v1/options
+ * fix for extension update status always saying "unknown"
+ * fix display of extra network cards that have `<>` in the name
+ * update lora extension to work with python 3.8
+
+
## 1.5.0
### Features:
diff --git a/README.md b/README.md
index b796d150..2fd6e425 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
- Now without any bad letters!
- Load checkpoints in safetensors format
-- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
+- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen
@@ -169,5 +169,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf
+- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 8ecfa29a..0a18d69e 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import os
from collections import namedtuple
import enum
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index af8188e3..17cbe1bb 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -163,6 +163,11 @@ def load_network(name, network_on_disk):
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ # some SD1 Loras also have correct compvis keys
+ if sd_module is None:
+ key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
if sd_module is None:
keys_failed_to_match[key_network] = key
continue
diff --git a/javascript/badScaleChecker.js b/javascript/badScaleChecker.js
deleted file mode 100644
index 625ad309..00000000
--- a/javascript/badScaleChecker.js
+++ /dev/null
@@ -1,108 +0,0 @@
-(function() {
- var ignore = localStorage.getItem("bad-scale-ignore-it") == "ignore-it";
-
- function getScale() {
- var ratio = 0,
- screen = window.screen,
- ua = navigator.userAgent.toLowerCase();
-
- if (window.devicePixelRatio !== undefined) {
- ratio = window.devicePixelRatio;
- } else if (~ua.indexOf('msie')) {
- if (screen.deviceXDPI && screen.logicalXDPI) {
- ratio = screen.deviceXDPI / screen.logicalXDPI;
- }
- } else if (window.outerWidth !== undefined && window.innerWidth !== undefined) {
- ratio = window.outerWidth / window.innerWidth;
- }
-
- return ratio == 0 ? 0 : Math.round(ratio * 100);
- }
-
- var showing = false;
-
- var div = document.createElement("div");
- div.style.position = "fixed";
- div.style.top = "0px";
- div.style.left = "0px";
- div.style.width = "100vw";
- div.style.backgroundColor = "firebrick";
- div.style.textAlign = "center";
- div.style.zIndex = 99;
-
- var b = document.createElement("b");
- b.innerHTML = 'Bad Scale: ??% ';
-
- div.appendChild(b);
-
- var note1 = document.createElement("p");
- note1.innerHTML = "Change your browser or your computer settings!";
- note1.title = 'Just make sure "computer-scale" * "browser-scale" = 100% ,\n' +
- "you can keep your computer-scale and only change this page's scale,\n" +
- "for example: your computer-scale is 125%, just use [\"CTRL\"+\"-\"] to make your browser-scale of this page to 80%.";
- div.appendChild(note1);
-
- var note2 = document.createElement("p");
- note2.innerHTML = " Otherwise, it will cause this page to not function properly!";
- note2.title = "When you click \"Copy image to: [inpaint sketch]\" in some img2img's tab,\n" +
- "if scale<100% the canvas will be invisible,\n" +
- "else if scale>100% this page will take large amount of memory and CPU performance.";
- div.appendChild(note2);
-
- var btn = document.createElement("button");
- btn.innerHTML = "Click here to ignore";
-
- div.appendChild(btn);
-
- function tryShowTopBar(scale) {
- if (showing) return;
-
- b.innerHTML = 'Bad Scale: ' + scale + '% ';
-
- var updateScaleTimer = setInterval(function() {
- var newScale = getScale();
- b.innerHTML = 'Bad Scale: ' + newScale + '% ';
- if (newScale == 100) {
- var p = div.parentNode;
- if (p != null) p.removeChild(div);
- showing = false;
- clearInterval(updateScaleTimer);
- check();
- }
- }, 999);
-
- btn.onclick = function() {
- clearInterval(updateScaleTimer);
- var p = div.parentNode;
- if (p != null) p.removeChild(div);
- ignore = true;
- showing = false;
- localStorage.setItem("bad-scale-ignore-it", "ignore-it");
- };
-
- document.body.appendChild(div);
- }
-
- function check() {
- if (!ignore) {
- var timer = setInterval(function() {
- var scale = getScale();
- if (scale != 100 && !ignore) {
- tryShowTopBar(scale);
- clearInterval(timer);
- }
- if (ignore) {
- clearInterval(timer);
- }
- }, 999);
- }
- }
-
- if (document.readyState != "complete") {
- document.onreadystatechange = function() {
- if (document.readyState != "complete") check();
- };
- } else {
- check();
- }
-})();
diff --git a/modules/api/api.py b/modules/api/api.py
index 09166df2..908c4514 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -344,6 +344,7 @@ class Api:
processed = process_images(p)
finally:
shared.state.end()
+ shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -403,6 +404,7 @@ class Api:
processed = process_images(p)
finally:
shared.state.end()
+ shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -729,9 +731,9 @@ class Api:
cuda = {'error': f'{err}'}
return models.MemoryResponse(ram=ram, cuda=cuda)
- def launch(self, server_name, port):
+ def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
def kill_webui(self):
restart.stop_program()
diff --git a/modules/api/models.py b/modules/api/models.py
index bf97b1a3..800c9b93 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -208,11 +208,9 @@ class PreprocessResponse(BaseModel):
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
- optType = opts.typemap.get(type(metadata.default), type(metadata.default))
+ optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any
- if metadata.default is None:
- pass
- elif metadata is not None:
+ if metadata is not None:
fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
else:
fields.update({key: (Optional[optType], Field())})
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index dd5fadc4..cb4ec5f7 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
diff --git a/modules/extensions.py b/modules/extensions.py
index c561159a..3ad5ed53 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -56,10 +56,12 @@ class Extension:
self.do_read_info_from_repo()
return self.to_dict()
-
- d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
- self.from_dict(d)
- self.status = 'unknown'
+ try:
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
+ self.from_dict(d)
+ except FileNotFoundError:
+ pass
+ self.status = 'unknown' if self.status == '' else self.status
def do_read_info_from_repo(self):
repo = None
diff --git a/modules/img2img.py b/modules/img2img.py
index a811e7a4..132cd100 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -20,7 +20,7 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
processing.fix_seed(p)
- images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
+ images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
is_inpaint_batch = False
if inpaint_mask_dir:
diff --git a/modules/processing.py b/modules/processing.py
index a74a5302..b0992ee1 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -600,8 +600,12 @@ def program_version():
return res
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
- index = position_in_batch + iteration * p.batch_size
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
+ if index is None:
+ index = position_in_batch + iteration * p.batch_size
+
+ if all_negative_prompts is None:
+ all_negative_prompts = p.all_negative_prompts
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
enable_hr = getattr(p, 'enable_hr', False)
@@ -617,12 +621,12 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Sampler": p.sampler_name,
"CFG scale": p.cfg_scale,
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
- "Seed": all_seeds[index],
+ "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
- "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
+ "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
@@ -642,7 +646,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
- negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
+ negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -716,9 +720,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
- def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
-
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -806,6 +807,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
+ batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+ p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
+ x_samples_ddim = batch_params.images
+
+ def infotext(index=0, use_main_prompt=False):
+ return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+
for i, x_sample in enumerate(x_samples_ddim):
p.batch_index = i
@@ -814,7 +825,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
@@ -831,15 +842,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
- images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
- text = infotext(n, i)
+ text = infotext(i)
infotexts.append(text)
if opts.enable_pnginfo:
image.info["parameters"] = text
@@ -850,10 +861,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
- images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.save_mask_composite:
- images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask:
output_images.append(image_mask)
@@ -894,7 +905,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p,
images_list=output_images,
seed=p.all_seeds[0],
- info=infotext(),
+ info=infotexts[0],
comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
diff --git a/modules/scripts.py b/modules/scripts.py
index f34240a0..5b4edcac 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -16,6 +16,11 @@ class PostprocessImageArgs:
self.image = image
+class PostprocessBatchListArgs:
+ def __init__(self, images):
+ self.images = images
+
+
class Script:
name = None
"""script's internal name derived from title"""
@@ -156,6 +161,25 @@ class Script:
pass
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
+ """
+ Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
+ This is useful when you want to update the entire batch instead of individual images.
+
+ You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
+ If the number of images is different from the batch size when returning,
+ then the script has the responsibility to also update the following attributes in the processing object (p):
+ - p.prompts
+ - p.negative_prompts
+ - p.seeds
+ - p.subseeds
+
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ """
+
+ pass
+
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
@@ -536,6 +560,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_batch_list(p, pp, *script_args, **kwargs)
+ except Exception:
+ errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 9fc89dc6..695c5736 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -3,8 +3,31 @@ import open_clip
import torch
import transformers.utils.hub
+from modules import shared
-class DisableInitialization:
+
+class ReplaceHelper:
+ def __init__(self):
+ self.replaced = []
+
+ def replace(self, obj, field, func):
+ original = getattr(obj, field, None)
+ if original is None:
+ return None
+
+ self.replaced.append((obj, field, original))
+ setattr(obj, field, func)
+
+ return original
+
+ def restore(self):
+ for obj, field, original in self.replaced:
+ setattr(obj, field, original)
+
+ self.replaced.clear()
+
+
+class DisableInitialization(ReplaceHelper):
"""
When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working
@@ -21,7 +44,7 @@ class DisableInitialization:
"""
def __init__(self, disable_clip=True):
- self.replaced = []
+ super().__init__()
self.disable_clip = disable_clip
def replace(self, obj, field, func):
@@ -86,8 +109,81 @@ class DisableInitialization:
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
- for obj, field, original in self.replaced:
- setattr(obj, field, original)
+ self.restore()
- self.replaced.clear()
+class InitializeOnMeta(ReplaceHelper):
+ """
+ Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
+ which results in those parameters having no values and taking no memory. model.to() will be broken and
+ will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
+
+ Usage:
+ ```
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
+ ```
+ """
+
+ def __enter__(self):
+ if shared.cmd_opts.disable_model_loading_ram_optimization:
+ return
+
+ def set_device(x):
+ x["device"] = "meta"
+ return x
+
+ linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
+ conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
+ mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
+ self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.restore()
+
+
+class LoadStateDictOnMeta(ReplaceHelper):
+ """
+ Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
+ As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
+ Meant to be used together with InitializeOnMeta above.
+
+ Usage:
+ ```
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
+ model.load_state_dict(state_dict, strict=False)
+ ```
+ """
+
+ def __init__(self, state_dict, device):
+ super().__init__()
+ self.state_dict = state_dict
+ self.device = device
+
+ def __enter__(self):
+ if shared.cmd_opts.disable_model_loading_ram_optimization:
+ return
+
+ sd = self.state_dict
+ device = self.device
+
+ def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+ params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+
+ for name, param in params:
+ if param.is_meta:
+ self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+
+ original(self, state_dict, prefix, *args, **kwargs)
+
+ for name, _ in params:
+ key = prefix + name
+ if key in sd:
+ del sd[key]
+
+ linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
+ conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
+ mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.restore()
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 5443e609..16a5500e 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -270,12 +270,17 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens)
+ pooled = getattr(z, 'pooled', None)
+
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z *= (original_mean / new_mean)
+ z = z * (original_mean / new_mean)
+
+ if pooled is not None:
+ z.pooled = pooled
return z
diff --git a/modules/sd_models.py b/modules/sd_models.py
index fb31a793..acb1e817 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -460,7 +460,6 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
-
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
- sd_model = instantiate_from_config(sd_config.model)
- except Exception:
- pass
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
+
+ except Exception as e:
+ errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
- sd_model = instantiate_from_config(sd_config.model)
+
+ with sd_disable_initialization.InitializeOnMeta():
+ sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py
new file mode 100644
index 00000000..1b981ca8
--- /dev/null
+++ b/modules/sd_samplers_extra.py
@@ -0,0 +1,74 @@
+import torch
+import tqdm
+import k_diffusion.sampling
+
+
+@torch.no_grad()
+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
+ """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
+ Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
+ If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
+ """
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ step_id = 0
+ from k_diffusion.sampling import to_d, get_sigmas_karras
+
+ def heun_step(x, old_sigma, new_sigma, second_order=True):
+ nonlocal step_id
+ denoised = model(x, old_sigma * s_in, **extra_args)
+ d = to_d(x, old_sigma, denoised)
+ if callback is not None:
+ callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
+ dt = new_sigma - old_sigma
+ if new_sigma == 0 or not second_order:
+ # Euler method
+ x = x + d * dt
+ else:
+ # Heun's method
+ x_2 = x + d * dt
+ denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
+ d_2 = to_d(x_2, new_sigma, denoised_2)
+ d_prime = (d + d_2) / 2
+ x = x + d_prime * dt
+ step_id += 1
+ return x
+
+ steps = sigmas.shape[0] - 1
+ if restart_list is None:
+ if steps >= 20:
+ restart_steps = 9
+ restart_times = 1
+ if steps >= 36:
+ restart_steps = steps // 4
+ restart_times = 2
+ sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
+ restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
+ else:
+ restart_list = {}
+
+ restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
+
+ step_list = []
+ for i in range(len(sigmas) - 1):
+ step_list.append((sigmas[i], sigmas[i + 1]))
+ if i + 1 in restart_list:
+ restart_steps, restart_times, restart_max = restart_list[i + 1]
+ min_idx = i + 1
+ max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
+ if max_idx < min_idx:
+ sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
+ while restart_times > 0:
+ restart_times -= 1
+ step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
+
+ last_sigma = None
+ for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
+ if last_sigma is None:
+ last_sigma = old_sigma
+ elif last_sigma < old_sigma:
+ x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
+ x = heun_step(x, old_sigma, new_sigma)
+ last_sigma = new_sigma
+
+ return x
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 5552a8dc..e0da3425 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -2,7 +2,7 @@ from collections import deque
import torch
import inspect
import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common
+from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
from modules.shared import opts, state
import modules.shared as shared
@@ -30,12 +30,14 @@ samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+ ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
]
+
samplers_data_k_diffusion = [
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
for label, funcname, aliases, options in samplers_k_diffusion
- if hasattr(k_diffusion.sampling, funcname)
+ if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
]
sampler_extra_params = {
@@ -270,7 +272,7 @@ class KDiffusionSampler:
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
- self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 49612298..f2752f10 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -253,7 +253,7 @@ class ExtraNetworksPage:
"prompt": item.get("prompt", None),
"tabname": quote_js(tabname),
"local_preview": quote_js(item["local_preview"]),
- "name": item["name"],
+ "name": html.escape(item["name"]),
"description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
"card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
diff --git a/webui.py b/webui.py
index 2314735f..2dc4f1aa 100644
--- a/webui.py
+++ b/webui.py
@@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False):
if modules.sd_hijack.current_optimizer is None:
modules.sd_hijack.apply_optimizations()
- Thread(target=load_model).start()
+ devices.first_time_calculation()
- Thread(target=devices.first_time_calculation).start()
+ Thread(target=load_model).start()
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
@@ -374,7 +374,7 @@ def api_only():
api.launch(
server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
port=cmd_opts.port if cmd_opts.port else 7861,
- root_path = f"/{cmd_opts.subpath}"
+ root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
)