aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-11-02 00:25:08 +0700
committerGitHub <noreply@github.com>2022-11-02 00:25:08 +0700
commitf8c6468d42e1202f7aeaeb961ab003aa0a2daf99 (patch)
treea2542ce9bd8bba1e8aa93acd510a12ca8a0b344f /modules
parent7c8c3715f552378cf81ad28f26fad92b37bd153d (diff)
parent198a1ffcfc963a3d74674fad560e87dbebf7949f (diff)
Merge branch 'master' into vae-picker
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py16
-rw-r--r--modules/extensions.py83
-rw-r--r--modules/extras.py2
-rw-r--r--modules/generation_parameters_copypaste.py5
-rw-r--r--modules/images.py5
-rw-r--r--modules/img2img.py1
-rw-r--r--modules/interrogate.py4
-rw-r--r--modules/lowvram.py21
-rw-r--r--modules/processing.py3
-rw-r--r--modules/safe.py2
-rw-r--r--modules/script_callbacks.py19
-rw-r--r--modules/scripts.py21
-rw-r--r--modules/sd_hijack.py4
-rw-r--r--modules/sd_models.py14
-rw-r--r--modules/sd_samplers.py28
-rw-r--r--modules/shared.py18
-rw-r--r--modules/textual_inversion/textual_inversion.py10
-rw-r--r--modules/textual_inversion/ui.py7
-rw-r--r--modules/ui.py16
-rw-r--r--modules/ui_extensions.py268
20 files changed, 490 insertions, 57 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 6c06d449..bb87d795 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,6 +1,8 @@
+import base64
+import io
import time
import uvicorn
-from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
+from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
from modules import devices
@@ -29,6 +31,12 @@ def setUpscalers(req: dict):
return reqDict
+def encode_pil_to_base64(image):
+ buffer = io.BytesIO()
+ image.save(buffer, format="png")
+ return base64.b64encode(buffer.getvalue())
+
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -40,6 +48,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -176,6 +185,11 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
+ def interruptapi(self):
+ shared.state.interrupt()
+
+ return {}
+
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
diff --git a/modules/extensions.py b/modules/extensions.py
new file mode 100644
index 00000000..897af96e
--- /dev/null
+++ b/modules/extensions.py
@@ -0,0 +1,83 @@
+import os
+import sys
+import traceback
+
+import git
+
+from modules import paths, shared
+
+
+extensions = []
+extensions_dir = os.path.join(paths.script_path, "extensions")
+
+
+def active():
+ return [x for x in extensions if x.enabled]
+
+
+class Extension:
+ def __init__(self, name, path, enabled=True):
+ self.name = name
+ self.path = path
+ self.enabled = enabled
+ self.status = ''
+ self.can_update = False
+
+ repo = None
+ try:
+ if os.path.exists(os.path.join(path, ".git")):
+ repo = git.Repo(path)
+ except Exception:
+ print(f"Error reading github repository info from {path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ if repo is None or repo.bare:
+ self.remote = None
+ else:
+ self.remote = next(repo.remote().urls, None)
+ self.status = 'unknown'
+
+ def list_files(self, subdir, extension):
+ from modules import scripts
+
+ dirpath = os.path.join(self.path, subdir)
+ if not os.path.isdir(dirpath):
+ return []
+
+ res = []
+ for filename in sorted(os.listdir(dirpath)):
+ res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
+
+ res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+ return res
+
+ def check_updates(self):
+ repo = git.Repo(self.path)
+ for fetch in repo.remote().fetch("--dry-run"):
+ if fetch.flags != fetch.HEAD_UPTODATE:
+ self.can_update = True
+ self.status = "behind"
+ return
+
+ self.can_update = False
+ self.status = "latest"
+
+ def pull(self):
+ repo = git.Repo(self.path)
+ repo.remotes.origin.pull()
+
+
+def list_extensions():
+ extensions.clear()
+
+ if not os.path.isdir(extensions_dir):
+ return
+
+ for dirname in sorted(os.listdir(extensions_dir)):
+ path = os.path.join(extensions_dir, dirname)
+ if not os.path.isdir(path):
+ continue
+
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
+ extensions.append(extension)
diff --git a/modules/extras.py b/modules/extras.py
index 4d51088b..8e2ab35c 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
info_hash=hash(info),
- args_hash=hash(upscale_args))
+ args_hash=hash((upscale_args, upscale_first)))
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index df70c728..985ec95e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -17,6 +17,11 @@ paste_fields = {}
bind_list = []
+def reset():
+ paste_fields.clear()
+ bind_list.clear()
+
+
def quote(text):
if ',' not in str(text):
return text
diff --git a/modules/images.py b/modules/images.py
index a0728553..ae705cbd 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -510,8 +510,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo()
- for k, v in params.pnginfo.items():
- pnginfo_data.add_text(k, str(v))
+ if opts.enable_pnginfo:
+ for k, v in params.pnginfo.items():
+ pnginfo_data.add_text(k, str(v))
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
diff --git a/modules/img2img.py b/modules/img2img.py
index efda26e1..35c5df9b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
filename = f"{left}-{n}{right}"
if not save_normally:
+ os.makedirs(output_dir, exist_ok=True)
processed_image.save(os.path.join(output_dir, filename))
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 65b05d34..9769aa34 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -56,9 +56,9 @@ class InterrogateModels:
import clip
if self.running_on_cpu:
- model, preprocess = clip.load(clip_model_name, device="cpu")
+ model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
else:
- model, preprocess = clip.load(clip_model_name)
+ model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
model.eval()
model = model.to(devices.device_interrogate)
diff --git a/modules/lowvram.py b/modules/lowvram.py
index f327c3df..a4652cb1 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
- def first_stage_model_encode_wrap(self, encoder, x):
- send_me_to_gpu(self, None)
- return encoder(x)
- def first_stage_model_decode_wrap(self, decoder, z):
- send_me_to_gpu(self, None)
- return decoder(z)
+ first_stage_model = sd_model.first_stage_model
+ first_stage_model_encode = sd_model.first_stage_model.encode
+ first_stage_model_decode = sd_model.first_stage_model.decode
+
+ def first_stage_model_encode_wrap(x):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_encode(x)
+
+ def first_stage_model_decode_wrap(z):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
- sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
diff --git a/modules/processing.py b/modules/processing.py
index b1df4918..57d3a523 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.postprocess(p, res)
+ p.sd_model = None
+ p.sampler = None
+
return res
diff --git a/modules/safe.py b/modules/safe.py
index 399165a1..348a24fc 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name)
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6ea58d61..ce264690 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -3,6 +3,8 @@ import traceback
from collections import namedtuple
import inspect
+from fastapi import FastAPI
+from gradio import Blocks
def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
@@ -25,6 +27,7 @@ class ImageSaveParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
@@ -40,6 +43,14 @@ def clear_callbacks():
callbacks_image_saved.clear()
+def app_started_callback(demo: Blocks, app: FastAPI):
+ for c in callbacks_app_started:
+ try:
+ c.callback(demo, app)
+ except Exception:
+ report_exception(c, 'app_started_callback')
+
+
def model_loaded_callback(sd_model):
for c in callbacks_model_loaded:
try:
@@ -69,7 +80,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in callbacks_image_saved:
+ for c in callbacks_before_image_saved:
try:
c.callback(params)
except Exception:
@@ -91,6 +102,12 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
+def on_app_started(callback):
+ """register a function to be called when the webui started, the gradio `Block` component and
+ fastapi `FastAPI` object are passed as the arguments"""
+ add_callback(callbacks_app_started, callback)
+
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
diff --git a/modules/scripts.py b/modules/scripts.py
index 96e44bfd..533db45c 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -7,7 +7,7 @@ import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks
+from modules import shared, paths, script_callbacks, extensions
AlwaysVisible = object()
@@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension):
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
- extdir = os.path.join(paths.script_path, "extensions")
- if os.path.exists(extdir):
- for dirname in sorted(os.listdir(extdir)):
- dirpath = os.path.join(extdir, dirname)
- scriptdirpath = os.path.join(dirpath, scriptdirname)
-
- if not os.path.isdir(scriptdirpath):
- continue
-
- for filename in sorted(os.listdir(scriptdirpath)):
- scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
+ for ext in extensions.active():
+ scripts_list += ext.list_files(scriptdirname, extension)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension):
def list_files_with_name(filename):
res = []
- dirs = [paths.script_path]
-
- extdir = os.path.join(paths.script_path, "extensions")
- if os.path.exists(extdir):
- dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
+ dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
for dirpath in dirs:
if not os.path.isdir(dirpath):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 0f10828e..bc49d235 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+ self.layers = None
+ self.circular_enabled = False
+ self.clip = None
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 850f7b7b..6ab85b65 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,6 +1,7 @@
import collections
import os.path
import sys
+import gc
from collections import namedtuple
import torch
import re
@@ -214,6 +215,12 @@ def load_model(checkpoint_info=None):
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
+ if shared.sd_model:
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ gc.collect()
+ devices.torch_gc()
+
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
@@ -227,6 +234,7 @@ def load_model(checkpoint_info=None):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
+
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -246,14 +254,18 @@ def load_model(checkpoint_info=None):
return sd_model
-def reload_model_weights(sd_model, info=None, force=False):
+def reload_model_weights(sd_model=None, info=None, force=False):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+
+ if not sd_model:
+ sd_model = shared.sd_model
if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
return shared.sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 3670b57d..8772db56 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,5 +1,6 @@
from collections import namedtuple
import numpy as np
+from math import floor
import torch
import tqdm
from PIL import Image
@@ -205,17 +206,22 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
+ def adjust_steps_if_invalid(self, p, num_steps):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ valid_step = 999 / (1000 // num_steps)
+ if valid_step == floor(valid_step):
+ return int(valid_step) + 1
+
+ return num_steps
+
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
-
+ steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
- # existing code fails with certain step counts, like 9
- try:
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
@@ -239,18 +245,14 @@ class VanillaStableDiffusionSampler:
self.last_latent = x
self.step = 0
- steps = steps or p.steps
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- # existing code fails with certain step counts, like 9
- try:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
- except Exception:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
diff --git a/modules/shared.py b/modules/shared.py
index 06440ac4..cbef5c43 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -40,7 +40,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
-parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
@@ -51,6 +51,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
@@ -97,6 +98,8 @@ restricted_opts = {
"outdir_save",
}
+cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
+
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
@@ -132,6 +135,7 @@ class State:
current_image = None
current_image_sampling_step = 0
textinfo = None
+ need_restart = False
def skip(self):
self.skipped = True
@@ -285,11 +289,12 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+ "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
@@ -355,6 +360,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
+options_templates.update(options_section((None, "Hidden options"), {
+ "disabled_extensions": OptionInfo([], "Disable those extensions"),
+}))
+
+options_templates.update()
+
class Options:
data = None
@@ -366,8 +377,9 @@ class Options:
def __setattr__(self, key, value):
if self.data is not None:
- if key in self.data:
+ if key in self.data or key in self.data_labels:
self.data[key] = value
+ return
return super(Options, self).__setattr__(key, value)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e0babb46..0aeb0459 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+ unload = shared.opts.unload_models_when_training
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
+
+ shared.sd_model.first_stage_model.to(devices.device)
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
do_not_save_grid=True,
@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing.process_images(p)
image = processed.images[0]
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index e712284d..d679e6f4 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -25,8 +25,10 @@ def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+ apply_optimizations = shared.opts.training_xattention_optimizations
try:
- sd_hijack.undo_optimizations()
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception:
raise
finally:
- sd_hijack.apply_optimizations()
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
diff --git a/modules/ui.py b/modules/ui.py
index 5055ca64..2c15abb7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
-from modules import sd_hijack, sd_models, localization, script_callbacks
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
+ parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
@@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call):
column = None
with gr.Row(elem_id="settings").style(equal_height=False):
for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
- if previous_section != item.section:
+ if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
if column is not None:
column.__exit__()
@@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call):
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
+ elif section_must_be_skipped:
+ components.append(dummy_component)
else:
component = create_setting_component(k)
component_dict[k] = component
@@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call):
def request_restart():
shared.state.interrupt()
- settings_interface.gradio_ref.do_restart = True
+ shared.state.need_restart = True
restart_gradio.click(
+
fn=request_restart,
inputs=[],
outputs=[],
@@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
+ extensions_interface = ui_extensions.create_ui()
+ interfaces += [(extensions_interface, "Extensions", "extensions")]
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
- settings_interface.gradio_ref = demo
-
parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind()
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
new file mode 100644
index 00000000..ab807722
--- /dev/null
+++ b/modules/ui_extensions.py
@@ -0,0 +1,268 @@
+import json
+import os.path
+import shutil
+import sys
+import time
+import traceback
+
+import git
+
+import gradio as gr
+import html
+
+from modules import extensions, shared, paths
+
+
+available_extensions = {"extensions": []}
+
+
+def check_access():
+ assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
+
+
+def apply_and_restart(disable_list, update_list):
+ check_access()
+
+ disabled = json.loads(disable_list)
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
+
+ update = json.loads(update_list)
+ assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
+
+ update = set(update)
+
+ for ext in extensions.extensions:
+ if ext.name not in update:
+ continue
+
+ try:
+ ext.pull()
+ except Exception:
+ print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ shared.opts.disabled_extensions = disabled
+ shared.opts.save(shared.config_filename)
+
+ shared.state.interrupt()
+ shared.state.need_restart = True
+
+
+def check_updates():
+ check_access()
+
+ for ext in extensions.extensions:
+ if ext.remote is None:
+ continue
+
+ try:
+ ext.check_updates()
+ except Exception:
+ print(f"Error checking updates for {ext.name}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ return extension_table()
+
+
+def extension_table():
+ code = f"""<!-- {time.time()} -->
+ <table id="extensions">
+ <thead>
+ <tr>
+ <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+ <th>URL</th>
+ <th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
+ </tr>
+ </thead>
+ <tbody>
+ """
+
+ for ext in extensions.extensions:
+ if ext.can_update:
+ ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
+ else:
+ ext_status = ext.status
+
+ code += f"""
+ <tr>
+ <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+ <td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
+ <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
+ </tr>
+ """
+
+ code += """
+ </tbody>
+ </table>
+ """
+
+ return code
+
+
+def normalize_git_url(url):
+ if url is None:
+ return ""
+
+ url = url.replace(".git", "")
+ return url
+
+
+def install_extension_from_url(dirname, url):
+ check_access()
+
+ assert url, 'No URL specified'
+
+ if dirname is None or dirname == "":
+ *parts, last_part = url.split('/')
+ last_part = normalize_git_url(last_part)
+
+ dirname = last_part
+
+ target_dir = os.path.join(extensions.extensions_dir, dirname)
+ assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
+
+ normalized_url = normalize_git_url(url)
+ assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
+
+ tmpdir = os.path.join(paths.script_path, "tmp", dirname)
+
+ try:
+ shutil.rmtree(tmpdir, True)
+
+ repo = git.Repo.clone_from(url, tmpdir)
+ repo.remote().fetch()
+
+ os.rename(tmpdir, target_dir)
+
+ extensions.list_extensions()
+ return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
+ finally:
+ shutil.rmtree(tmpdir, True)
+
+
+def install_extension_from_index(url):
+ ext_table, message = install_extension_from_url(None, url)
+
+ return refresh_available_extensions_from_data(), ext_table, message
+
+
+def refresh_available_extensions(url):
+ global available_extensions
+
+ import urllib.request
+ with urllib.request.urlopen(url) as response:
+ text = response.read()
+
+ available_extensions = json.loads(text)
+
+ return url, refresh_available_extensions_from_data(), ''
+
+
+def refresh_available_extensions_from_data():
+ extlist = available_extensions["extensions"]
+ installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+
+ code = f"""<!-- {time.time()} -->
+ <table id="available_extensions">
+ <thead>
+ <tr>
+ <th>Extension</th>
+ <th>Description</th>
+ <th>Action</th>
+ </tr>
+ </thead>
+ <tbody>
+ """
+
+ for ext in extlist:
+ name = ext.get("name", "noname")
+ url = ext.get("url", None)
+ description = ext.get("description", "")
+
+ if url is None:
+ continue
+
+ existing = installed_extension_urls.get(normalize_git_url(url), None)
+
+ install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
+
+ code += f"""
+ <tr>
+ <td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
+ <td>{html.escape(description)}</td>
+ <td>{install_code}</td>
+ </tr>
+ """
+
+ code += """
+ </tbody>
+ </table>
+ """
+
+ return code
+
+
+def create_ui():
+ import modules.ui
+
+ with gr.Blocks(analytics_enabled=False) as ui:
+ with gr.Tabs(elem_id="tabs_extensions") as tabs:
+ with gr.TabItem("Installed"):
+
+ with gr.Row():
+ apply = gr.Button(value="Apply and restart UI", variant="primary")
+ check = gr.Button(value="Check for updates")
+ extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
+ extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+
+ extensions_table = gr.HTML(lambda: extension_table())
+
+ apply.click(
+ fn=apply_and_restart,
+ _js="extensions_apply",
+ inputs=[extensions_disabled_list, extensions_update_list],
+ outputs=[],
+ )
+
+ check.click(
+ fn=check_updates,
+ _js="extensions_check",
+ inputs=[],
+ outputs=[extensions_table],
+ )
+
+ with gr.TabItem("Available"):
+ with gr.Row():
+ refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
+ available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
+ extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
+ install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+
+ install_result = gr.HTML()
+ available_extensions_table = gr.HTML()
+
+ refresh_available_extensions_button.click(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
+ inputs=[available_extensions_index],
+ outputs=[available_extensions_index, available_extensions_table, install_result],
+ )
+
+ install_extension_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
+ inputs=[extension_to_install],
+ outputs=[available_extensions_table, extensions_table, install_result],
+ )
+
+ with gr.TabItem("Install from URL"):
+ install_url = gr.Text(label="URL for extension's git repository")
+ install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
+ install_button = gr.Button(value="Install", variant="primary")
+ install_result = gr.HTML(elem_id="extension_install_result")
+
+ install_button.click(
+ fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ inputs=[install_dirname, install_url],
+ outputs=[extensions_table, install_result],
+ )
+
+ return ui