aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-05-01 14:27:53 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-05-01 14:27:53 +0300
commitfe8a10d428bcc6be9cc8efb9772eca9e40f98dc8 (patch)
treed1e0ff50e327c3c59230b39907284c20ffbf0fe3 /extensions-builtin
parent22bcc7be428c94e9408f589966c2040187245d81 (diff)
parent6fbd85dd0c0dffc06560bff91f4c4b65e441ca5f (diff)
Merge branch 'release_candidate'
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py20
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py2
-rw-r--r--extensions-builtin/Lora/lora.py6
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py2
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py83
-rw-r--r--extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js121
6 files changed, 118 insertions, 116 deletions
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index b8cff29b..da19cff1 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -25,22 +25,28 @@ class UpscalerLDSR(Upscaler):
yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt")
- safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
+
+ local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
+ local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None)
+ local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None)
+ local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None)
+
if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760:
print("Removing invalid LDSR YAML file.")
os.remove(yaml_path)
+
if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path)
- if os.path.exists(safetensors_model_path):
- model = safetensors_model_path
+
+ if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
+ model = local_safetensors_path
else:
- model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
- file_name="model.ckpt", progress=True)
- yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
- file_name="project.yaml", progress=True)
+ model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True)
+
+ yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True)
try:
return LDSR(model, yaml)
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 6be6ef73..45f899fc 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -8,7 +8,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
- if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index d3eb0d3b..6f246921 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -211,7 +211,11 @@ def load_loras(names, multipliers=None):
lora_on_disk = loras_on_disk[i]
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
- lora = load_lora(name, lora_on_disk.filename)
+ try:
+ lora = load_lora(name, lora_on_disk.filename)
+ except Exception as e:
+ errors.display(e, f"loading Lora {lora_on_disk.filename}")
+ continue
if lora is None:
print(f"Couldn't find Lora with name {name}")
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 0adab225..3fc38ab9 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -52,5 +52,5 @@ script_callbacks.on_before_ui(before_ui)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
}))
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index e0fbf3a3..c7fd5739 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -5,11 +5,15 @@ import traceback
import PIL.Image
import numpy as np
import torch
+from tqdm import tqdm
+
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
from scunet_model_arch import SCUNet as net
+from modules.shared import opts
+from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,28 +46,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2)
self.scalers = scalers
- def do_upscale(self, img: PIL.Image, selected_file):
+ @staticmethod
+ @torch.no_grad()
+ def tiled_inference(img, model):
+ # test the image tile by tile
+ h, w = img.shape[2:]
+ tile = opts.SCUNET_tile
+ tile_overlap = opts.SCUNET_tile_overlap
+ if tile == 0:
+ return model(img)
+
+ device = devices.get_device_for('scunet')
+ assert tile % 8 == 0, "tile size should be a multiple of window_size"
+ sf = 1
+
+ stride = tile - tile_overlap
+ h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
+ w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
+ E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
+ W = torch.zeros_like(E, dtype=devices.dtype, device=device)
+
+ with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
+ for h_idx in h_idx_list:
+
+ for w_idx in w_idx_list:
+
+ in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+
+ out_patch = model(in_patch)
+ out_patch_mask = torch.ones_like(out_patch)
+
+ E[
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+ ].add_(out_patch)
+ W[
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+ ].add_(out_patch_mask)
+ pbar.update(1)
+ output = E.div_(W)
+
+ return output
+
+ def do_upscale(self, img: PIL.Image.Image, selected_file):
+
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
+ print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
-
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
+ tile = opts.SCUNET_tile
+ h, w = img.height, img.width
+ np_img = np.array(img)
+ np_img = np_img[:, :, ::-1] # RGB to BGR
+ np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
+ torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
+
+ if tile > h or tile > w:
+ _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
+ _img[:, :, :h, :w] = torch_img # pad image
+ torch_img = _img
+
+ torch_output = self.tiled_inference(torch_img, model).squeeze(0)
+ torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
+ np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
+ del torch_img, torch_output
torch.cuda.empty_cache()
- return PIL.Image.fromarray(output, 'RGB')
+
+ output = np_output.transpose((1, 2, 0)) # CHW to HWC
+ output = output[:, :, ::-1] # BGR to RGB
+ return PIL.Image.fromarray((output * 255).astype(np.uint8))
def load_model(self, path: str):
device = devices.get_device_for('scunet')
@@ -84,4 +138,3 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = model.to(device)
return model
-
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index f0918e26..5c7a836a 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -1,103 +1,42 @@
// Stable Diffusion WebUI - Bracket checker
-// Version 1.0
-// By Hingashi no Florin/Bwin4L
+// By Hingashi no Florin/Bwin4L & @akx
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
-function checkBrackets(evt, textArea, counterElt) {
- errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
- errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
- errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
-
- openBracketRegExp = /\(/g;
- closeBracketRegExp = /\)/g;
-
- openSquareBracketRegExp = /\[/g;
- closeSquareBracketRegExp = /\]/g;
-
- openCurlyBracketRegExp = /\{/g;
- closeCurlyBracketRegExp = /\}/g;
-
- totalOpenBracketMatches = 0;
- totalCloseBracketMatches = 0;
- totalOpenSquareBracketMatches = 0;
- totalCloseSquareBracketMatches = 0;
- totalOpenCurlyBracketMatches = 0;
- totalCloseCurlyBracketMatches = 0;
-
- openBracketMatches = textArea.value.match(openBracketRegExp);
- if(openBracketMatches) {
- totalOpenBracketMatches = openBracketMatches.length;
- }
-
- closeBracketMatches = textArea.value.match(closeBracketRegExp);
- if(closeBracketMatches) {
- totalCloseBracketMatches = closeBracketMatches.length;
- }
-
- openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
- if(openSquareBracketMatches) {
- totalOpenSquareBracketMatches = openSquareBracketMatches.length;
- }
-
- closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
- if(closeSquareBracketMatches) {
- totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
- }
-
- openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
- if(openCurlyBracketMatches) {
- totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
- }
-
- closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
- if(closeCurlyBracketMatches) {
- totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
- }
-
- if(totalOpenBracketMatches != totalCloseBracketMatches) {
- if(!counterElt.title.includes(errorStringParen)) {
- counterElt.title += errorStringParen;
- }
- } else {
- counterElt.title = counterElt.title.replace(errorStringParen, '');
- }
-
- if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
- if(!counterElt.title.includes(errorStringSquare)) {
- counterElt.title += errorStringSquare;
- }
- } else {
- counterElt.title = counterElt.title.replace(errorStringSquare, '');
- }
-
- if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
- if(!counterElt.title.includes(errorStringCurly)) {
- counterElt.title += errorStringCurly;
+function checkBrackets(textArea, counterElt) {
+ var counts = {};
+ (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
+ counts[bracket] = (counts[bracket] || 0) + 1;
+ });
+ var errors = [];
+
+ function checkPair(open, close, kind) {
+ if (counts[open] !== counts[close]) {
+ errors.push(
+ `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
+ );
}
- } else {
- counterElt.title = counterElt.title.replace(errorStringCurly, '');
}
- if(counterElt.title != '') {
- counterElt.classList.add('error');
- } else {
- counterElt.classList.remove('error');
- }
+ checkPair('(', ')', 'round brackets');
+ checkPair('[', ']', 'square brackets');
+ checkPair('{', '}', 'curly brackets');
+ counterElt.title = errors.join('\n');
+ counterElt.classList.toggle('error', errors.length !== 0);
}
-function setupBracketChecking(id_prompt, id_counter){
- var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
- var counter = gradioApp().getElementById(id_counter)
+function setupBracketChecking(id_prompt, id_counter) {
+ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+ var counter = gradioApp().getElementById(id_counter)
- textarea.addEventListener("input", function(evt){
- checkBrackets(evt, textarea, counter)
- });
+ if (textarea && counter) {
+ textarea.addEventListener("input", () => checkBrackets(textarea, counter));
+ }
}
-onUiLoaded(function(){
- setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
- setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
- setupBracketChecking('img2img_prompt', 'img2img_token_counter')
- setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
-}) \ No newline at end of file
+onUiLoaded(function () {
+ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
+ setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
+ setupBracketChecking('img2img_prompt', 'img2img_token_counter');
+ setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
+});