aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/LDSR/ldsr_model_arch.py13
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py7
-rw-r--r--extensions-builtin/LDSR/sd_hijack_autoencoder.py28
-rw-r--r--extensions-builtin/LDSR/sd_hijack_ddpm_v1.py66
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py18
-rw-r--r--extensions-builtin/Lora/lora.py80
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py38
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py5
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py19
-rw-r--r--extensions-builtin/ScuNET/scunet_model_arch.py11
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py9
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch.py6
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch_v2.py58
-rw-r--r--extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js52
14 files changed, 256 insertions, 154 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index bc11cc6e..7f450086 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -88,7 +88,7 @@ class LDSR:
x_t = None
logs = None
- for n in range(n_runs):
+ for _ in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
@@ -110,7 +110,6 @@ class LDSR:
diffusion_steps = int(steps)
eta = 1.0
- down_sample_method = 'Lanczos'
gc.collect()
if torch.cuda.is_available:
@@ -131,11 +130,11 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
-
+
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
-
+
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"]
@@ -158,7 +157,7 @@ class LDSR:
def get_cond(selected_path):
- example = dict()
+ example = {}
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
- log = dict()
+ log = {}
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
- except:
+ except Exception:
pass
log["sample"] = x_sample
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index da19cff1..c4da79f3 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks
-import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+import sd_hijack_autoencoder # noqa: F401
+import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler):
@@ -44,9 +45,9 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
- model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True)
+ model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
- yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True)
+ yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
try:
return LDSR(model, yaml)
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
index 8e03c7f8..81c5101b 100644
--- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -1,16 +1,21 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
-
+import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
+
+from torch.optim.lr_scheduler import LambdaLR
+
+from ldm.modules.ema import LitEma
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
+from packaging import version
class VQModel(pl.LightningModule):
def __init__(self,
@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed,
embed_dim,
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list()):
+ def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
- for ik in ignore_keys:
+ for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+ self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+ if x.shape[1] > 3:
+ xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
@@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
@@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
dec = self.decoder(quant)
return dec
-setattr(ldm.models.autoencoder, "VQModel", VQModel)
-setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
+ldm.models.autoencoder.VQModel = VQModel
+ldm.models.autoencoder.VQModelInterface = VQModelInterface
diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
index 5c0488e5..631a08ef 100644
--- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
+++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
- for ik in ignore_keys:
+ for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
log["inputs"] = x
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
+ except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
- self.bbox_tokenizer = None
+ self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
+ if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
@@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm
- assert not return_ids
+ assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
return img, intermediates
@torch.no_grad()
@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
if return_intermediates:
return img, intermediates
@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
use_ddim = ddim_steps is not None
- log = dict()
+ log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
if plot_diffusion_rows:
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
if inpaint:
# make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+ super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+ logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
logs['bbox_image'] = cond_img
return logs
-setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
-setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
-setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
-setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
+ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
+ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
+ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
+ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index ccb249ac..b5fea4d2 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -23,5 +23,23 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
lora.load_loras(names, multipliers)
+ if shared.opts.lora_add_hashes_to_infotext:
+ lora_hashes = []
+ for item in lora.loaded_loras:
+ shorthash = item.lora_on_disk.shorthash
+ if not shorthash:
+ continue
+
+ alias = item.mentioned_name
+ if not alias:
+ continue
+
+ alias = alias.replace(":", "").replace(",", "")
+
+ lora_hashes.append(f"{alias}: {shorthash}")
+
+ if lora_hashes:
+ p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+
def deactivate(self, p):
pass
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index b5d0c98f..eec14712 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -1,10 +1,9 @@
-import glob
import os
import re
import torch
from typing import Union
-from modules import shared, devices, sd_models, errors, scripts
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -77,9 +76,9 @@ class LoraOnDisk:
self.name = name
self.filename = filename
self.metadata = {}
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
- _, ext = os.path.splitext(filename)
- if ext.lower() == ".safetensors":
+ if self.is_safetensors:
try:
self.metadata = sd_models.read_metadata_from_safetensors(filename)
except Exception as e:
@@ -95,14 +94,43 @@ class LoraOnDisk:
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
+ self.hash = None
+ self.shorthash = None
+ self.set_hash(
+ self.metadata.get('sshs_model_hash') or
+ hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
+ ''
+ )
+
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
+ if self.shorthash:
+ available_lora_hash_lookup[self.shorthash] = self
+
+ def read_hash(self):
+ if not self.hash:
+ self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
+
+ def get_alias(self):
+ if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
+ return self.name
+ else:
+ return self.alias
+
class LoraModule:
- def __init__(self, name):
+ def __init__(self, name, lora_on_disk: LoraOnDisk):
self.name = name
+ self.lora_on_disk = lora_on_disk
self.multiplier = 1.0
self.modules = {}
self.mtime = None
+ self.mentioned_name = None
+ """the text that was used to add lora to prompt - can be either name or an alias"""
+
class LoraUpDownModule:
def __init__(self):
@@ -127,11 +155,11 @@ def assign_lora_names_to_compvis_modules(sd_model):
sd_model.lora_layer_mapping = lora_layer_mapping
-def load_lora(name, filename):
- lora = LoraModule(name)
- lora.mtime = os.path.getmtime(filename)
+def load_lora(name, lora_on_disk):
+ lora = LoraModule(name, lora_on_disk)
+ lora.mtime = os.path.getmtime(lora_on_disk.filename)
- sd = sd_models.read_state_dict(filename)
+ sd = sd_models.read_state_dict(lora_on_disk.filename)
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
@@ -177,7 +205,7 @@ def load_lora(name, filename):
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+ raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
@@ -189,10 +217,10 @@ def load_lora(name, filename):
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+ raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
- print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
+ print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
return lora
@@ -207,30 +235,41 @@ def load_loras(names, multipliers=None):
loaded_loras.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
- if any([x is None for x in loras_on_disk]):
+ if any(x is None for x in loras_on_disk):
list_available_loras()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
+ failed_to_load_loras = []
+
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
lora_on_disk = loras_on_disk[i]
+
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
try:
- lora = load_lora(name, lora_on_disk.filename)
+ lora = load_lora(name, lora_on_disk)
except Exception as e:
errors.display(e, f"loading Lora {lora_on_disk.filename}")
continue
+ lora.mentioned_name = name
+
+ lora_on_disk.read_hash()
+
if lora is None:
+ failed_to_load_loras.append(name)
print(f"Couldn't find Lora with name {name}")
continue
lora.multiplier = multipliers[i] if multipliers else 1.0
loaded_loras.append(lora)
+ if len(failed_to_load_loras) > 0:
+ sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
+
def lora_calc_updown(lora, module, target):
with torch.no_grad():
@@ -314,7 +353,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
print(f'failed to calculate lora weights for layer {lora_layer_name}')
- setattr(self, "lora_current_names", wanted_names)
+ self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
@@ -348,8 +387,8 @@ def lora_forward(module, input, original_forward):
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+ self.lora_current_names = ()
+ self.lora_weights_backup = None
def lora_Linear_forward(self, input):
@@ -398,7 +437,8 @@ def list_available_loras():
available_loras.clear()
available_lora_aliases.clear()
forbidden_lora_aliases.clear()
- forbidden_lora_aliases.update({"none": 1})
+ available_lora_hash_lookup.clear()
+ forbidden_lora_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
@@ -428,7 +468,7 @@ def infotext_pasted(infotext, params):
added = []
- for k, v in params.items():
+ for k in params:
if not k.startswith("AddNet Model "):
continue
@@ -452,8 +492,10 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)
+
available_loras = {}
available_lora_aliases = {}
+available_lora_hash_lookup = {}
forbidden_lora_aliases = {}
loaded_loras = []
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 060bda05..e650f469 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,3 +1,5 @@
+import re
+
import torch
import gradio as gr
from fastapi import FastAPI
@@ -53,8 +55,9 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
- "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
+ "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+ "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
}))
@@ -77,6 +80,37 @@ def api_loras(_: gr.Blocks, app: FastAPI):
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
+ @app.post("/sdapi/v1/refresh-loras")
+ async def refresh_loras():
+ return lora.list_available_loras()
+
script_callbacks.on_app_started(api_loras)
+re_lora = re.compile("<lora:([^:]+):")
+
+
+def infotext_pasted(infotext, d):
+ hashes = d.get("Lora hashes")
+ if not hashes:
+ return
+
+ hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
+ hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
+
+ def lora_replacement(m):
+ alias = m.group(1)
+ shorthash = hashes.get(alias)
+ if shorthash is None:
+ return m.group(0)
+
+ lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
+ if lora_on_disk is None:
+ return m.group(0)
+
+ return f'<lora:{lora_on_disk.get_alias()}:'
+
+ d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
+
+
+script_callbacks.on_infotext_pasted(infotext_pasted)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 2050e3fa..259e99ac 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -16,10 +16,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename)
- if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
- alias = name
- else:
- alias = lora_on_disk.alias
+ alias = lora_on_disk.get_alias()
yield {
"name": name,
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index c7fd5739..45d9297b 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -10,10 +10,9 @@ from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import devices, modelloader
+from modules import devices, modelloader, script_callbacks
from scunet_model_arch import SCUNet as net
from modules.shared import opts
-from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -122,8 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str):
device = devices.get_device_for('scunet')
if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
- progress=True)
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
@@ -133,8 +131,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
- for k, v in model.named_parameters():
+ for _, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model
+
+
+def on_ui_settings():
+ import gradio as gr
+ from modules import shared
+
+ shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
+ shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py
index 43ca8d36..b51a8806 100644
--- a/extensions-builtin/ScuNET/scunet_model_arch.py
+++ b/extensions-builtin/ScuNET/scunet_model_arch.py
@@ -61,7 +61,9 @@ class WMSA(nn.Module):
Returns:
output: tensor shape [b h w c]
"""
- if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+ if self.type != 'W':
+ x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
@@ -85,8 +87,9 @@ class WMSA(nn.Module):
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
- if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
- dims=(1, 2))
+ if self.type != 'W':
+ output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
+
return output
def relative_embedding(self):
@@ -262,4 +265,4 @@ class SCUNet(nn.Module):
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0) \ No newline at end of file
+ nn.init.constant_(m.weight, 1.0)
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index e8783bca..1c7bf325 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,4 +1,3 @@
-import contextlib
import os
import numpy as np
@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts, state
+from modules.shared import opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -45,14 +44,14 @@ class UpscalerSwinIR(Upscaler):
img = upscale(img, model)
try:
torch.cuda.empty_cache()
- except:
+ except Exception:
pass
return img
def load_model(self, path, scale=4):
if "http" in path:
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
+ filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
else:
filename = path
if filename is None or not os.path.exists(filename):
@@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for w_idx in w_idx_list:
if state.interrupted or state.skipped:
break
-
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
index 863f42db..93b93274 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch.py
@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@@ -805,7 +805,7 @@ class SwinIR(nn.Module):
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
-
+
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
index 0e28ae6e..dad22cca 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
+ pretrained_window_size=(0, 0)):
super().__init__()
self.dim = dim
@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
-
+
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
- return attn_mask
+ return attn_mask
def forward(self, x, x_size):
H, W = x_size
@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
+
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
@@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
- return flops
+ return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
@@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
-
+
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
- return flops
+ return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
@@ -531,7 +531,7 @@ class RSTB(nn.Module):
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
@@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
-
+
class Upsample_hf(nn.Sequential):
"""Upsample module.
@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
+ super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
-
-
+
+
class Swin2SR(nn.Module):
r""" Swin2SR
@@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True,
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
)
self.layers.append(layer)
-
+
if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers):
@@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
)
self.layers_hf.append(layer)
-
+
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
@@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
+ nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
@@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
@@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
x = self.patch_unembed(x, x_size)
return x
-
+
def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
@@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
- return x
+ return x
def forward(self, x):
H, W = x.shape[2:]
@@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before))
-
+
x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf)
@@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
-
+
x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
+
elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
+
else:
return x[:, :, :H*self.upscale, :W*self.upscale]
@@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
@@ -1014,4 +1014,4 @@ if __name__ == '__main__':
x = torch.randn((1, 3, height, width))
x = model(x)
- print(x.shape) \ No newline at end of file
+ print(x.shape)
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index 5c7a836a..114cf94c 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -4,39 +4,39 @@
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(textArea, counterElt) {
- var counts = {};
- (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
- counts[bracket] = (counts[bracket] || 0) + 1;
- });
- var errors = [];
+ var counts = {};
+ (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
+ counts[bracket] = (counts[bracket] || 0) + 1;
+ });
+ var errors = [];
- function checkPair(open, close, kind) {
- if (counts[open] !== counts[close]) {
- errors.push(
- `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
- );
+ function checkPair(open, close, kind) {
+ if (counts[open] !== counts[close]) {
+ errors.push(
+ `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
+ );
+ }
}
- }
- checkPair('(', ')', 'round brackets');
- checkPair('[', ']', 'square brackets');
- checkPair('{', '}', 'curly brackets');
- counterElt.title = errors.join('\n');
- counterElt.classList.toggle('error', errors.length !== 0);
+ checkPair('(', ')', 'round brackets');
+ checkPair('[', ']', 'square brackets');
+ checkPair('{', '}', 'curly brackets');
+ counterElt.title = errors.join('\n');
+ counterElt.classList.toggle('error', errors.length !== 0);
}
function setupBracketChecking(id_prompt, id_counter) {
- var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
- var counter = gradioApp().getElementById(id_counter)
+ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+ var counter = gradioApp().getElementById(id_counter);
- if (textarea && counter) {
- textarea.addEventListener("input", () => checkBrackets(textarea, counter));
- }
+ if (textarea && counter) {
+ textarea.addEventListener("input", () => checkBrackets(textarea, counter));
+ }
}
-onUiLoaded(function () {
- setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
- setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
- setupBracketChecking('img2img_prompt', 'img2img_token_counter');
- setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
+onUiLoaded(function() {
+ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
+ setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
+ setupBracketChecking('img2img_prompt', 'img2img_token_counter');
+ setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
});