From 500d9a32c7b1f877c8f44159a9a10c594b545a80 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 23:11:37 +0300 Subject: add --lora-dir commandline option --- extensions-builtin/Lora/lora.py | 9 ++++----- extensions-builtin/Lora/preload.py | 6 ++++++ extensions-builtin/Lora/ui_extra_networks_lora.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) create mode 100644 extensions-builtin/Lora/preload.py (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 6d860224..da1797dc 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -177,12 +177,12 @@ def lora_Conv2d_forward(self, input): def list_available_loras(): available_loras.clear() - os.makedirs(lora_dir, exist_ok=True) + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) candidates = \ - glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \ - glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \ - glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True) + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) for filename in sorted(candidates): if os.path.isdir(filename): @@ -193,7 +193,6 @@ def list_available_loras(): available_loras[name] = LoraOnDisk(name, filename) -lora_dir = os.path.join(shared.models_path, "Lora") available_loras = {} loaded_loras = [] diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py new file mode 100644 index 00000000..863dc5c0 --- /dev/null +++ b/extensions-builtin/Lora/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 65397890..4406f8a0 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -31,5 +31,5 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): } def allowed_directories_for_previews(self): - return [lora.lora_dir] + return [shared.cmd_opts.lora_dir] -- cgit v1.2.1 From fe7a623e6b7e04bab2cfc96e8fd6cf49b48daee1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 00:02:41 +0300 Subject: add a slider for default value of added extra networks --- extensions-builtin/Lora/ui_extra_networks_lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 4406f8a0..54a80d36 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,3 +1,4 @@ +import json import os import lora @@ -26,7 +27,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, - "prompt": f"", + "prompt": json.dumps(f""), "local_preview": path + ".png", } -- cgit v1.2.1 From e407d1af897a7896d8c81e32dc86e7eb753ce207 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 18:12:51 +0300 Subject: add support for loras trained on kohya's scripts 0.4.0 (alphas) --- extensions-builtin/Lora/lora.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index da1797dc..220e64ff 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -92,6 +92,15 @@ def load_lora(name, filename): keys_failed_to_match.append(key_diffusers) continue + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "alpha": + lora_module.alpha = weight.item() + continue + if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: @@ -104,17 +113,12 @@ def load_lora(name, filename): module.to(device=devices.device, dtype=devices.dtype) - lora_module = lora.modules.get(key, None) - if lora_module is None: - lora_module = LoraUpDownModule() - lora.modules[key] = lora_module - if lora_key == "lora_up.weight": lora_module.up = module elif lora_key == "lora_down.weight": lora_module.down = module else: - assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight' + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' if len(keys_failed_to_match) > 0: print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") @@ -161,7 +165,7 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier + res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1] return res -- cgit v1.2.1 From c6f20f72629f3c417f10db2289d131441c6832f5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 18:52:55 +0300 Subject: make loras before 0.4.0 ALSO work --- extensions-builtin/Lora/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 220e64ff..137e58f7 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -57,6 +57,7 @@ class LoraUpDownModule: def __init__(self): self.up = None self.down = None + self.alpha = None def assign_lora_names_to_compvis_modules(sd_model): @@ -165,7 +166,7 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1] + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) return res -- cgit v1.2.1 From f99352582084890b9167c1bf8699865bea0cef5f Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Mon, 23 Jan 2023 21:50:59 -0500 Subject: Make SwinIR interruptible --- extensions-builtin/SwinIR/scripts/swinir_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 9a74b253..3479760a 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import cmd_opts, opts +from modules.shared import cmd_opts, opts, state from swinir_model_arch import SwinIR as net from swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale): with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: + if state.interrupted: + break + for w_idx in w_idx_list: + if state.interrupted: + break + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) -- cgit v1.2.1 From 3c47b050367ee220dcfed7be7883878417735614 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Mon, 23 Jan 2023 22:00:27 -0500 Subject: Also make SwinIR skippable --- extensions-builtin/SwinIR/scripts/swinir_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 3479760a..e8783bca 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -145,11 +145,11 @@ def inference(img, model, tile, tile_overlap, window_size, scale): with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: - if state.interrupted: + if state.interrupted or state.skipped: break for w_idx in w_idx_list: - if state.interrupted: + if state.interrupted or state.skipped: break in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] -- cgit v1.2.1