aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/LDSR/ldsr_model_arch.py (renamed from modules/ldsr_model_arch.py)0
-rw-r--r--extensions-builtin/LDSR/preload.py6
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py (renamed from modules/ldsr_model.py)14
-rw-r--r--extensions-builtin/LDSR/sd_hijack_autoencoder.py286
-rw-r--r--extensions-builtin/ScuNET/preload.py6
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py (renamed from modules/scunet_model.py)6
-rw-r--r--extensions-builtin/ScuNET/scunet_model_arch.py (renamed from modules/scunet_model_arch.py)0
-rw-r--r--extensions-builtin/SwinIR/preload.py6
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py (renamed from modules/swinir_model.py)33
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch.py (renamed from modules/swinir_model_arch.py)0
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch_v2.py (renamed from modules/swinir_model_arch_v2.py)0
-rw-r--r--javascript/hints.js2
-rw-r--r--javascript/progressbar.js16
-rw-r--r--modules/api/api.py7
-rw-r--r--modules/deepbooru.py2
-rw-r--r--modules/devices.py33
-rw-r--r--modules/extensions.py22
-rw-r--r--modules/generation_parameters_copypaste.py4
-rw-r--r--modules/hypernetworks/hypernetwork.py7
-rw-r--r--modules/img2img.py30
-rw-r--r--modules/interrogate.py16
-rw-r--r--modules/modelloader.py20
-rw-r--r--modules/processing.py6
-rw-r--r--modules/safe.py18
-rw-r--r--modules/sd_hijack.py7
-rw-r--r--modules/sd_samplers.py22
-rw-r--r--modules/shared.py16
-rw-r--r--modules/textual_inversion/autocrop.py6
-rw-r--r--modules/textual_inversion/dataset.py4
-rw-r--r--modules/textual_inversion/textual_inversion.py5
-rw-r--r--modules/ui.py16
-rw-r--r--modules/ui_extensions.py10
-rw-r--r--scripts/prompt_matrix.py2
-rw-r--r--scripts/xy_grid.py2
-rw-r--r--webui.py5
35 files changed, 521 insertions, 114 deletions
diff --git a/modules/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 90e0a2f0..90e0a2f0 100644
--- a/modules/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
diff --git a/extensions-builtin/LDSR/preload.py b/extensions-builtin/LDSR/preload.py
new file mode 100644
index 00000000..d746007c
--- /dev/null
+++ b/extensions-builtin/LDSR/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
diff --git a/modules/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index 8c4db44a..1cef29a4 100644
--- a/modules/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -5,8 +5,9 @@ import traceback
from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
-from modules.ldsr_model_arch import LDSR
-from modules import shared
+from ldsr_model_arch import LDSR
+from modules import shared, script_callbacks
+import sd_hijack_autoencoder
class UpscalerLDSR(Upscaler):
@@ -52,3 +53,12 @@ class UpscalerLDSR(Upscaler):
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
+
+
+def on_ui_settings():
+ import gradio as gr
+
+ shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
new file mode 100644
index 00000000..8e03c7f8
--- /dev/null
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -0,0 +1,286 @@
+# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
+# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
+# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
+
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.util import instantiate_from_config
+
+import ldm.models.autoencoder
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ batch_resize_range=None,
+ scheduler_config=None,
+ lr_g_factor=1.0,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ use_ema=False
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.n_embed = n_embed
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap,
+ sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ self.batch_resize_range = batch_resize_range
+ if self.batch_resize_range is not None:
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.scheduler_config = scheduler_config
+ self.lr_g_factor = lr_g_factor
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ print(f"Unexpected Keys: {unexpected}")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input, return_pred_indices=False):
+ quant, diff, (_,_,ind) = self.encode(input)
+ dec = self.decode(quant)
+ if return_pred_indices:
+ return dec, diff, ind
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ if self.batch_resize_range is not None:
+ lower_size = self.batch_resize_range[0]
+ upper_size = self.batch_resize_range[1]
+ if self.global_step <= 4:
+ # do the first few batches with max size to avoid later oom
+ new_resize = upper_size
+ else:
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
+ if new_resize != x.shape[2]:
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
+ x = x.detach()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ # https://github.com/pytorch/pytorch/issues/37142
+ # try not to fool the heuristics
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train",
+ predicted_indices=ind)
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, suffix=""):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log(f"val{suffix}/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log(f"val{suffix}/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
+ del log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr_d = self.learning_rate
+ lr_g = self.lr_g_factor*self.learning_rate
+ print("lr_d", lr_d)
+ print("lr_g", lr_g)
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr_g, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr_d, betas=(0.5, 0.9))
+
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ {
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ ]
+ return [opt_ae, opt_disc], scheduler
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if only_inputs:
+ log["inputs"] = x
+ return log
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ if plot_ema:
+ with self.ema_scope():
+ xrec_ema, _ = self(x)
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+ log["reconstructions_ema"] = xrec_ema
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class VQModelInterface(VQModel):
+ def __init__(self, embed_dim, *args, **kwargs):
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ self.embed_dim = embed_dim
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, h, force_not_quantize=False):
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+setattr(ldm.models.autoencoder, "VQModel", VQModel)
+setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
diff --git a/extensions-builtin/ScuNET/preload.py b/extensions-builtin/ScuNET/preload.py
new file mode 100644
index 00000000..f12c5b90
--- /dev/null
+++ b/extensions-builtin/ScuNET/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
diff --git a/modules/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 52360241..e0fbf3a3 100644
--- a/modules/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
-from modules.scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet as net
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -49,7 +49,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
if model is None:
return img
- device = devices.device_scunet
+ device = devices.get_device_for('scunet')
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
@@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
- device = devices.device_scunet
+ device = devices.get_device_for('scunet')
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
diff --git a/modules/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py
index 43ca8d36..43ca8d36 100644
--- a/modules/scunet_model_arch.py
+++ b/extensions-builtin/ScuNET/scunet_model_arch.py
diff --git a/extensions-builtin/SwinIR/preload.py b/extensions-builtin/SwinIR/preload.py
new file mode 100644
index 00000000..567e44bc
--- /dev/null
+++ b/extensions-builtin/SwinIR/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
diff --git a/modules/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index facd262d..782769e2 100644
--- a/modules/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -7,15 +7,14 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
-from modules import modelloader, devices
+from modules import modelloader, devices, script_callbacks, shared
from modules.shared import cmd_opts, opts
-from modules.swinir_model_arch import SwinIR as net
-from modules.swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch import SwinIR as net
+from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
-precision_scope = (
- torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-)
+
+device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler):
@@ -42,7 +41,7 @@ class UpscalerSwinIR(Upscaler):
model = self.load_model(model_file)
if model is None:
return img
- model = model.to(devices.device_swinir)
+ model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
@@ -94,8 +93,6 @@ class UpscalerSwinIR(Upscaler):
model.load_state_dict(pretrained_model[params], strict=True)
else:
model.load_state_dict(pretrained_model, strict=True)
- if not cmd_opts.no_half:
- model = model.half()
return model
@@ -111,8 +108,8 @@ def upscale(
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_swinir)
- with torch.no_grad(), precision_scope("cuda"):
+ img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
+ with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
@@ -139,8 +136,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
- W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
+ W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
@@ -159,3 +156,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
output = E.div_(W)
return output
+
+
+def on_ui_settings():
+ import gradio as gr
+
+ shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
+ shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/modules/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
index 863f42db..863f42db 100644
--- a/modules/swinir_model_arch.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch.py
diff --git a/modules/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
index 0e28ae6e..0e28ae6e 100644
--- a/modules/swinir_model_arch_v2.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
diff --git a/javascript/hints.js b/javascript/hints.js
index ac417ff6..57db35be 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -94,6 +94,8 @@ titles = {
"Add difference": "Result = A + (B - C) * M",
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
+
+ "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc."
}
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 43d1d1ce..d58737c4 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -92,14 +92,26 @@ function check_gallery(id_gallery){
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
// automatically re-open previously selected index (if exists)
activeElement = gradioApp().activeElement;
+ let scrollX = window.scrollX;
+ let scrollY = window.scrollY;
galleryButtons[prevSelectedIndex].click();
showGalleryImage();
+ // When the gallery button is clicked, it gains focus and scrolls itself into view
+ // We need to scroll back to the previous position
+ setTimeout(function (){
+ window.scrollTo(scrollX, scrollY);
+ }, 50);
+
if(activeElement){
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
- // if somenoe has a better solution please by all means
- setTimeout(function() { activeElement.focus() }, 1);
+ // if someone has a better solution please by all means
+ setTimeout(function (){
+ activeElement.focus({
+ preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
+ })
+ }, 1);
}
}
})
diff --git a/modules/api/api.py b/modules/api/api.py
index 1de3f98f..54ee7cb0 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -152,7 +152,10 @@ class Api:
)
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
- p = StableDiffusionProcessingImg2Img(**vars(populate))
+
+ args = vars(populate)
+ args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+ p = StableDiffusionProcessingImg2Img(**args)
imgs = []
for img in init_images:
@@ -170,7 +173,7 @@ class Api:
b64images = list(map(encode_pil_to_base64, processed.images))
- if (not img2imgreq.include_init_images):
+ if not img2imgreq.include_init_images:
img2imgreq.init_images = None
img2imgreq.mask = None
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 31ec7e17..dfc83357 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -21,7 +21,7 @@ class DeepDanbooru:
files = modelloader.load_models(
model_path=os.path.join(paths.models_path, "torch_deepdanbooru"),
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
- ext_filter=".pt",
+ ext_filter=[".pt"],
download_name='model-resnet_custom_v3.pt',
)
diff --git a/modules/devices.py b/modules/devices.py
index f00079c6..f8cffae1 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -44,6 +44,15 @@ def get_optimal_device():
return cpu
+def get_device_for(task):
+ from modules import shared
+
+ if task in shared.cmd_opts.use_cpu:
+ return cpu
+
+ return get_optimal_device()
+
+
def torch_gc():
if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
@@ -53,37 +62,35 @@ def torch_gc():
def enable_tf32():
if torch.cuda.is_available():
+
+ # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+ # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+ if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+ torch.backends.cudnn.benchmark = True
+
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
+
errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
def randn(seed, shape):
- # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
- if device.type == 'mps':
- generator = torch.Generator(device=cpu)
- generator.manual_seed(seed)
- noise = torch.randn(shape, generator=generator, device=cpu).to(device)
- return noise
-
torch.manual_seed(seed)
+ if device.type == 'mps':
+ return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
- # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
- generator = torch.Generator(device=cpu)
- noise = torch.randn(shape, generator=generator, device=cpu).to(device)
- return noise
-
+ return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
diff --git a/modules/extensions.py b/modules/extensions.py
index db9c4200..b522125c 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -8,6 +8,7 @@ from modules import paths, shared
extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
+extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
def active():
@@ -15,12 +16,13 @@ def active():
class Extension:
- def __init__(self, name, path, enabled=True):
+ def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
self.path = path
self.enabled = enabled
self.status = ''
self.can_update = False
+ self.is_builtin = is_builtin
repo = None
try:
@@ -79,11 +81,19 @@ def list_extensions():
if not os.path.isdir(extensions_dir):
return
- for dirname in sorted(os.listdir(extensions_dir)):
- path = os.path.join(extensions_dir, dirname)
- if not os.path.isdir(path):
- continue
+ paths = []
+ for dirname in [extensions_dir, extensions_builtin_dir]:
+ if not os.path.isdir(dirname):
+ return
- extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
+ for extension_dirname in sorted(os.listdir(dirname)):
+ path = os.path.join(dirname, extension_dirname)
+ if not os.path.isdir(path):
+ continue
+
+ paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
+
+ for dirname, path, is_builtin in paths:
+ extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
extensions.append(extension)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 01980dca..44fe1a6c 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -184,6 +184,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
res[k] = v
+ # Missing CLIP skip means it was set to 1 (the default)
+ if "Clip skip" not in res:
+ res["Clip skip"] = "1"
+
return res
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8466887f..c406ffb3 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -433,7 +433,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@@ -495,7 +498,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.state.interrupted:
break
- with torch.autocast("cuda"):
+ with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)
@@ -612,10 +615,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return hypernetwork, filename
diff --git a/modules/img2img.py b/modules/img2img.py
index 7e58994a..830cfa15 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -4,7 +4,7 @@ import sys
import traceback
import numpy as np
-from PIL import Image, ImageOps, ImageChops
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance
from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
@@ -40,7 +40,7 @@ def process_batch(p, input_dir, output_dir, args):
img = Image.open(image)
# Use the EXIF orientation of photos taken by smartphones.
- img = ImageOps.exif_transpose(img)
+ img = ImageOps.exif_transpose(img)
p.init_images = [img] * p.batch_size
proc = modules.scripts.scripts_img2img.run(p, *args)
@@ -59,18 +59,30 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1
is_batch = mode == 2
if is_inpaint:
# Drawn mask
if mask_mode == 0:
- image = init_img_with_mask['image']
- mask = init_img_with_mask['mask']
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
- image = image.convert('RGB')
+ image = init_img_with_mask
+ is_mask_sketch = isinstance(image, dict)
+ is_mask_paint = not is_mask_sketch
+ if is_mask_sketch:
+ # Sketch: mask iff. not transparent
+ image, mask = image["image"], image["mask"]
+ pred = np.array(mask)[..., -1] > 0
+ else:
+ # Color-sketch: mask iff. painted over
+ orig = init_img_with_mask_orig or image
+ pred = np.any(np.array(image) != np.array(orig), axis=-1)
+ mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
+ if is_mask_paint:
+ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
+ blur = ImageFilter.GaussianBlur(mask_blur)
+ image = Image.composite(image.filter(blur), orig, mask.filter(blur))
+ image = image.convert("RGB")
# Uploaded mask
else:
image = init_img_inpaint
@@ -82,7 +94,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
# Use the EXIF orientation of photos taken by smartphones.
if image is not None:
- image = ImageOps.exif_transpose(image)
+ image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9769aa34..0068b81c 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,4 +1,3 @@
-import contextlib
import os
import sys
import traceback
@@ -11,10 +10,9 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths, lowvram
+from modules import devices, paths, lowvram, modelloader
blip_image_eval_size = 384
-blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
@@ -47,7 +45,14 @@ class InterrogateModels:
def load_blip_model(self):
import models.blip
- blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
+ files = modelloader.load_models(
+ model_path=os.path.join(paths.models_path, "BLIP"),
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+ ext_filter=[".pth"],
+ download_name='model_base_caption_capfilt_large.pth',
+ )
+
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval()
return blip_model
@@ -148,8 +153,7 @@ class InterrogateModels:
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
- precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
- with torch.no_grad(), precision_scope("cuda"):
+ with torch.no_grad(), devices.autocast():
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 7d2f0ade..e647f6fa 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -124,10 +124,9 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers():
- sd = shared.script_path
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
- modules_dir = os.path.join(sd, "modules")
+ modules_dir = os.path.join(shared.script_path, "modules")
for file in os.listdir(modules_dir):
if "_model.py" in file:
model_name = file.replace("_model.py", "")
@@ -136,22 +135,13 @@ def load_upscalers():
importlib.import_module(full_model)
except:
pass
+
datas = []
- c_o = vars(shared.cmd_opts)
+ commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
name = cls.__name__
- module_name = cls.__module__
- module = importlib.import_module(module_name)
- class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
- opt_string = None
- try:
- if cmd_name in c_o:
- opt_string = c_o[cmd_name]
- except:
- pass
- scaler = class_(opt_string)
- for child in scaler.scalers:
- datas.append(child)
+ scaler = cls(commandline_options.get(cmd_name, None))
+ datas += scaler.scalers
shared.sd_upscalers = datas
diff --git a/modules/processing.py b/modules/processing.py
index edceb532..3d2c4dc9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -414,7 +414,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
- negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[0] if p.all_negative_prompts[0] else ""
+ negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -530,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
- samples_ddim = samples_ddim.to(devices.dtype_vae)
- x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
+ x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+ x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
diff --git a/modules/safe.py b/modules/safe.py
index a9209e38..10460ad0 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler):
raise Exception(f"global '{module}/{name}' is forbidden")
-allowed_zip_names = ["archive/data.pkl", "archive/version"]
-allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
-
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
for name in names:
- if name in allowed_zip_names:
- continue
if allowed_zip_names_re.match(name):
continue
@@ -82,8 +80,14 @@ def check_pt(filename, extra_handler):
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
-
- with z.open('archive/data.pkl') as file:
+
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+ if len(data_pkl_filenames) == 0:
+ raise Exception(f"data.pkl not found in {filename}")
+ if len(data_pkl_filenames) > 1:
+ raise Exception(f"Multiple data.pkl found in {filename}")
+ with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
unpickler.load()
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index b824b5bf..95a17093 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -17,6 +17,7 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
@@ -182,11 +183,7 @@ def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != devices.device:
-
- if devices.has_mps():
- attr = attr.to(device="mps", dtype=torch.float32)
- else:
- attr = attr.to(devices.device)
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 5fefb227..4c123d3b 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -6,6 +6,7 @@ import tqdm
from PIL import Image
import inspect
import k_diffusion.sampling
+import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser, devices, processing, images
@@ -364,7 +365,23 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
- return torch.randn_like(x)
+ if x.device.type == 'mps':
+ return torch.randn_like(x, device=devices.cpu).to(x.device)
+ else:
+ return torch.randn_like(x)
+
+
+# MPS fix for randn in torchsde
+def torchsde_randn(size, dtype, device, seed):
+ if device.type == 'mps':
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ else:
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+torchsde._brownian.brownian_interval._randn = torchsde_randn
class KDiffusionSampler:
@@ -415,8 +432,7 @@ class KDiffusionSampler:
self.model_wrap.step = 0
self.eta = p.eta or opts.eta_ancestral
- if self.sampler_noises is not None:
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
extra_params_kwargs = {}
for param_name in self.extra_params:
diff --git a/modules/shared.py b/modules/shared.py
index c36ee211..dc45fcaa 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -50,9 +50,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
-parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
-parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
-parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
@@ -61,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -72,6 +69,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
+parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@@ -94,6 +92,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
script_loading.preload_extensions(extensions.extensions_dir, parser)
+script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
cmd_opts = parser.parse_args()
@@ -111,8 +110,8 @@ restricted_opts = {
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
+ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
@@ -325,9 +324,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
- "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
- "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
}))
@@ -371,7 +367,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
- 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 9859974a..68e1103c 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -276,8 +276,8 @@ def poi_average(pois, settings):
weight += poi.weight
x += poi.x * poi.weight
y += poi.y * poi.weight
- avg_x = round(x / weight)
- avg_y = round(y / weight)
+ avg_x = round(weight and x / weight)
+ avg_y = round(weight and y / weight)
return PointOfInterest(avg_x, avg_y)
@@ -338,4 +338,4 @@ class Settings:
self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
self.destop_view_image = False
- self.dnn_model_path = dnn_model_path \ No newline at end of file
+ self.dnn_model_path = dnn_model_path
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index e5725f33..2dc64c3c 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -82,7 +82,7 @@ class PersonalizedBase(Dataset):
torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
latent_sample = None
- with torch.autocast("cuda"):
+ with devices.autocast():
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
@@ -101,7 +101,7 @@ class PersonalizedBase(Dataset):
entry.cond_text = self.create_text(filename_text)
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
- with torch.autocast("cuda"):
+ with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4eb75cb5..e28c357a 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -269,6 +269,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
pin_memory = shared.opts.pin_memory
@@ -279,6 +280,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
@@ -316,7 +318,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if shared.state.interrupted:
break
- with torch.autocast("cuda"):
+ with devices.autocast():
# c = stack_conds(batch.cond).to(devices.device)
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)
@@ -450,6 +452,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename
diff --git a/modules/ui.py b/modules/ui.py
index 00809361..3acb9b48 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -28,7 +28,6 @@ import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model
import modules.hypernetworks.ui
-import modules.ldsr_model
import modules.scripts
import modules.shared as shared
import modules.styles
@@ -792,11 +791,22 @@ def create_ui():
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
with gr.TabItem('Inpaint', id='inpaint'):
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ init_img_with_mask_orig = gr.State(None)
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
+ def update_orig(image, state):
+ if image is not None:
+ same_size = state is not None and state.size == image.size
+ has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
+ edited = same_size and has_exact_match
+ return image if not edited or state is None else state
+
+ init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
+ show_mask_alpha = cmd_opts.gradio_inpaint_tool == "color-sketch"
+ mask_alpha = gr.Slider(label="Mask transparency", interactive=show_mask_alpha, visible=show_mask_alpha)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
with gr.Row():
@@ -884,12 +894,14 @@ def create_ui():
img2img_prompt_style2,
init_img,
init_img_with_mask,
+ init_img_with_mask_orig,
init_img_inpaint,
init_mask_inpaint,
mask_mode,
steps,
sampler_index,
mask_blur,
+ mask_alpha,
inpainting_fill,
restore_faces,
tiling,
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 030f011e..b487ac25 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -17,7 +17,7 @@ available_extensions = {"extensions": []}
def check_access():
- assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
+ assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
def apply_and_restart(disable_list, update_list):
@@ -78,6 +78,12 @@ def extension_table():
"""
for ext in extensions.extensions:
+ remote = ""
+ if ext.is_builtin:
+ remote = "built-in"
+ elif ext.remote:
+ remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
+
if ext.can_update:
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
else:
@@ -86,7 +92,7 @@ def extension_table():
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
- <td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
+ <td>{remote}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index 4d1e152d..5fd952e9 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -84,6 +84,6 @@ class Script(scripts.Script):
processed.infotexts.insert(0, processed.infotexts[0])
if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", prompt=original_prompt, seed=processed.seed, grid=True, p=p)
+ images.save_image(processed.images[0], p.outpath_grids, "prompt_matrix", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)
return processed
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 0f27deda..d402c281 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -383,6 +383,6 @@ class Script(scripts.Script):
)
if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+ images.save_image(processed.images[0], p.outpath_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
return processed
diff --git a/webui.py b/webui.py
index 16e7ec1a..78204d11 100644
--- a/webui.py
+++ b/webui.py
@@ -53,10 +53,11 @@ def initialize():
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
- modelloader.load_upscalers()
modules.scripts.load_scripts()
+ modelloader.load_upscalers()
+
modules.sd_vae.refresh_vae_list()
modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
@@ -177,6 +178,8 @@ def webui():
print('Reloading custom scripts')
modules.scripts.reload_scripts()
+ modelloader.load_upscalers()
+
print('Reloading modules: modules.ui')
importlib.reload(modules.ui)
print('Refreshing Model List')