aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/LDSR/ldsr_model_arch.py3
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py20
-rw-r--r--extensions-builtin/Lora/lora.py207
-rw-r--r--extensions-builtin/Lora/preload.py6
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py35
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py36
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py8
-rw-r--r--extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js45
8 files changed, 335 insertions, 25 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 0ad49f4e..bc11cc6e 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -1,7 +1,6 @@
import os
import gc
import time
-import warnings
import numpy as np
import torch
@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack
-warnings.filterwarnings("ignore", category=UserWarning)
-
cached_ldsr_model: torch.nn.Module = None
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
new file mode 100644
index 00000000..8f2e753e
--- /dev/null
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -0,0 +1,20 @@
+from modules import extra_networks
+import lora
+
+class ExtraNetworkLora(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('lora')
+
+ def activate(self, p, params_list):
+ names = []
+ multipliers = []
+ for params in params_list:
+ assert len(params.items) > 0
+
+ names.append(params.items[0])
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+ lora.load_loras(names, multipliers)
+
+ def deactivate(self, p):
+ pass
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
new file mode 100644
index 00000000..cb8f1d36
--- /dev/null
+++ b/extensions-builtin/Lora/lora.py
@@ -0,0 +1,207 @@
+import glob
+import os
+import re
+import torch
+
+from modules import shared, devices, sd_models
+
+re_digits = re.compile(r"\d+")
+re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
+re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
+re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
+re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+
+
+def convert_diffusers_name_to_compvis(key):
+ def match(match_list, regex):
+ r = re.match(regex, key)
+ if not r:
+ return False
+
+ match_list.clear()
+ match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+ return True
+
+ m = []
+
+ if match(m, re_unet_down_blocks):
+ return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+
+ if match(m, re_unet_mid_blocks):
+ return f"diffusion_model_middle_block_1_{m[1]}"
+
+ if match(m, re_unet_up_blocks):
+ return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+
+ if match(m, re_text_block):
+ return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+ return key
+
+
+class LoraOnDisk:
+ def __init__(self, name, filename):
+ self.name = name
+ self.filename = filename
+
+
+class LoraModule:
+ def __init__(self, name):
+ self.name = name
+ self.multiplier = 1.0
+ self.modules = {}
+ self.mtime = None
+
+
+class LoraUpDownModule:
+ def __init__(self):
+ self.up = None
+ self.down = None
+ self.alpha = None
+
+
+def assign_lora_names_to_compvis_modules(sd_model):
+ lora_layer_mapping = {}
+
+ for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+ lora_name = name.replace(".", "_")
+ lora_layer_mapping[lora_name] = module
+ module.lora_layer_name = lora_name
+
+ for name, module in shared.sd_model.model.named_modules():
+ lora_name = name.replace(".", "_")
+ lora_layer_mapping[lora_name] = module
+ module.lora_layer_name = lora_name
+
+ sd_model.lora_layer_mapping = lora_layer_mapping
+
+
+def load_lora(name, filename):
+ lora = LoraModule(name)
+ lora.mtime = os.path.getmtime(filename)
+
+ sd = sd_models.read_state_dict(filename)
+
+ keys_failed_to_match = []
+
+ for key_diffusers, weight in sd.items():
+ fullkey = convert_diffusers_name_to_compvis(key_diffusers)
+ key, lora_key = fullkey.split(".", 1)
+
+ sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+ if sd_module is None:
+ keys_failed_to_match.append(key_diffusers)
+ continue
+
+ lora_module = lora.modules.get(key, None)
+ if lora_module is None:
+ lora_module = LoraUpDownModule()
+ lora.modules[key] = lora_module
+
+ if lora_key == "alpha":
+ lora_module.alpha = weight.item()
+ continue
+
+ if type(sd_module) == torch.nn.Linear:
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ elif type(sd_module) == torch.nn.Conv2d:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ else:
+ assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+
+ with torch.no_grad():
+ module.weight.copy_(weight)
+
+ module.to(device=devices.device, dtype=devices.dtype)
+
+ if lora_key == "lora_up.weight":
+ lora_module.up = module
+ elif lora_key == "lora_down.weight":
+ lora_module.down = module
+ else:
+ assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+
+ if len(keys_failed_to_match) > 0:
+ print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
+
+ return lora
+
+
+def load_loras(names, multipliers=None):
+ already_loaded = {}
+
+ for lora in loaded_loras:
+ if lora.name in names:
+ already_loaded[lora.name] = lora
+
+ loaded_loras.clear()
+
+ loras_on_disk = [available_loras.get(name, None) for name in names]
+ if any([x is None for x in loras_on_disk]):
+ list_available_loras()
+
+ loras_on_disk = [available_loras.get(name, None) for name in names]
+
+ for i, name in enumerate(names):
+ lora = already_loaded.get(name, None)
+
+ lora_on_disk = loras_on_disk[i]
+ if lora_on_disk is not None:
+ if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
+ lora = load_lora(name, lora_on_disk.filename)
+
+ if lora is None:
+ print(f"Couldn't find Lora with name {name}")
+ continue
+
+ lora.multiplier = multipliers[i] if multipliers else 1.0
+ loaded_loras.append(lora)
+
+
+def lora_forward(module, input, res):
+ if len(loaded_loras) == 0:
+ return res
+
+ lora_layer_name = getattr(module, 'lora_layer_name', None)
+ for lora in loaded_loras:
+ module = lora.modules.get(lora_layer_name, None)
+ if module is not None:
+ if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+ res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ else:
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+ return res
+
+
+def lora_Linear_forward(self, input):
+ return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+
+
+def lora_Conv2d_forward(self, input):
+ return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+
+
+def list_available_loras():
+ available_loras.clear()
+
+ os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
+
+ candidates = \
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
+
+ for filename in sorted(candidates):
+ if os.path.isdir(filename):
+ continue
+
+ name = os.path.splitext(os.path.basename(filename))[0]
+
+ available_loras[name] = LoraOnDisk(name, filename)
+
+
+available_loras = {}
+loaded_loras = []
+
+list_available_loras()
diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py
new file mode 100644
index 00000000..863dc5c0
--- /dev/null
+++ b/extensions-builtin/Lora/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
new file mode 100644
index 00000000..544b228d
--- /dev/null
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -0,0 +1,35 @@
+import torch
+
+import lora
+import extra_networks_lora
+import ui_extra_networks_lora
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+
+
+def unload():
+ torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+ torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+
+
+def before_ui():
+ ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
+ extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
+
+
+if not hasattr(torch.nn, 'Linear_forward_before_lora'):
+ torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
+
+if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
+ torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+
+torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+
+script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
+script_callbacks.on_script_unloaded(unload)
+script_callbacks.on_before_ui(before_ui)
+
+
+shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
+ "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+}))
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
new file mode 100644
index 00000000..54a80d36
--- /dev/null
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -0,0 +1,36 @@
+import json
+import os
+import lora
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Lora')
+
+ def refresh(self):
+ lora.list_available_loras()
+
+ def list_items(self):
+ for name, lora_on_disk in lora.available_loras.items():
+ path, ext = os.path.splitext(lora_on_disk.filename)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ break
+
+ yield {
+ "name": name,
+ "filename": path,
+ "preview": preview,
+ "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.lora_dir]
+
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 9a74b253..e8783bca 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
+ if state.interrupted or state.skipped:
+ break
+
for w_idx in w_idx_list:
+ if state.interrupted or state.skipped:
+ break
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index eccfb0f9..4a85c8eb 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -4,16 +4,10 @@
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
-function checkBrackets(evt) {
- textArea = evt.target;
- tabName = evt.target.parentElement.parentElement.id.split("_")[0];
- counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter');
-
- promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : '';
-
- errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n';
- errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n';
- errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n';
+function checkBrackets(evt, textArea, counterElt) {
+ errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
+ errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
+ errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
openBracketRegExp = /\(/g;
closeBracketRegExp = /\)/g;
@@ -86,22 +80,31 @@ function checkBrackets(evt) {
}
if(counterElt.title != '') {
- counterElt.style = 'color: #FF5555;';
+ counterElt.classList.add('error');
} else {
- counterElt.style = '';
+ counterElt.classList.remove('error');
}
}
+function setupBracketChecking(id_prompt, id_counter){
+ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+ var counter = gradioApp().getElementById(id_counter)
+ textarea.addEventListener("input", function(evt){
+ checkBrackets(evt, textarea, counter)
+ });
+}
+
var shadowRootLoaded = setInterval(function() {
- var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
- if(shadowTextArea.length < 1) {
- return false;
- }
+ var shadowRoot = document.querySelector('gradio-app').shadowRoot;
+ if(! shadowRoot) return false;
+
+ var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
+ if(shadowTextArea.length < 1) return false;
- clearInterval(shadowRootLoaded);
+ clearInterval(shadowRootLoaded);
- document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets;
- document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets;
- document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets;
- document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets;
+ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
+ setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
+ setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
+ setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
}, 1000);