author     InvincibleDude <81354513+InvincibleDude@users.noreply.github.com>  2023-01-24 15:44:09 +0300
committer  GitHub <noreply@github.com>  2023-01-24 15:44:09 +0300
commit     44c0e6b993d00bb2f441f0fde409bcb79136f034 (patch)
tree       e27a45d1a3ceb8aab884631c7a806c5fe2c8386d /extensions-builtin
parent     3bc8ee998db5f461b8011a72e6f167012ccb8bc1 (diff)
parent     602a1864b05075ca4283986e6f5c7d5bce864e11 (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'extensions-builtin')
-rw-r--r--  extensions-builtin/Lora/lora.py                    28
-rw-r--r--  extensions-builtin/Lora/preload.py                  6
-rw-r--r--  extensions-builtin/Lora/ui_extra_networks_lora.py   5
-rw-r--r--  extensions-builtin/SwinIR/scripts/swinir_model.py   8
4 files changed, 32 insertions(+), 15 deletions(-)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 6d860224..137e58f7 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -57,6 +57,7 @@ class LoraUpDownModule:
     def __init__(self):
         self.up = None
         self.down = None
+        self.alpha = None


 def assign_lora_names_to_compvis_modules(sd_model):
@@ -92,6 +93,15 @@ def load_lora(name, filename):
             keys_failed_to_match.append(key_diffusers)
             continue

+        lora_module = lora.modules.get(key, None)
+        if lora_module is None:
+            lora_module = LoraUpDownModule()
+            lora.modules[key] = lora_module
+
+        if lora_key == "alpha":
+            lora_module.alpha = weight.item()
+            continue
+
         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
@@ -104,17 +114,12 @@ def load_lora(name, filename):
         module.to(device=devices.device, dtype=devices.dtype)

-        lora_module = lora.modules.get(key, None)
-        if lora_module is None:
-            lora_module = LoraUpDownModule()
-            lora.modules[key] = lora_module
-
         if lora_key == "lora_up.weight":
             lora_module.up = module
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
+            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'

     if len(keys_failed_to_match) > 0:
         print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -161,7 +166,7 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier
+            res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

     return res
@@ -177,12 +182,12 @@ def lora_Conv2d_forward(self, input):

 def list_available_loras():
     available_loras.clear()
-    os.makedirs(lora_dir, exist_ok=True)
+    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

     candidates = \
-        glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \
-        glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \
-        glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True)
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
+        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

     for filename in sorted(candidates):
         if os.path.isdir(filename):
@@ -193,7 +198,6 @@ def list_available_loras():
         available_loras[name] = LoraOnDisk(name, filename)


-lora_dir = os.path.join(shared.models_path, "Lora")
 available_loras = {}
 loaded_loras = []
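The alpha handling above follows the kohya-ss checkpoint convention: a module may carry an "alpha" scalar, and its low-rank update is rescaled by alpha / rank, with the rank recovered from up.weight.shape[1]. A minimal, self-contained sketch of that forward pass (class and variable names here are illustrative, not webui API):

# Minimal sketch of the alpha-scaled LoRA forward; names are illustrative.
import torch
import torch.nn as nn


class LoraUpDown(nn.Module):
    def __init__(self, features, rank, alpha=None):
        super().__init__()
        self.down = nn.Linear(features, rank, bias=False)
        self.up = nn.Linear(rank, features, bias=False)
        self.alpha = alpha  # value read from the checkpoint's "alpha" key, if present

    def forward(self, x, base_out, multiplier=1.0):
        # rank == self.up.weight.shape[1]; alpha / rank preserves the
        # training-time scaling of the update at inference.
        scale = (self.alpha / self.up.weight.shape[1]) if self.alpha else 1.0
        return base_out + self.up(self.down(x)) * multiplier * scale


x = torch.randn(2, 16)
base = nn.Linear(16, 16)
lora = LoraUpDown(16, rank=4, alpha=8.0)
out = lora(x, base(x), multiplier=0.7)  # delta scaled by 0.7 * (8 / 4)
print(out.shape)  # torch.Size([2, 16])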
diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py
new file mode 100644
index 00000000..863dc5c0
--- /dev/null
+++ b/extensions-builtin/Lora/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+    parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
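The new preload.py hooks into the webui's extension preload step, which hands the shared argparse parser to each extension's preload() before the command line is parsed, so --lora-dir exists as a flag from startup. A standalone sketch of the mechanism (run_preloads is a hypothetical stand-in for the webui's own loader, and the default path is hard-coded for the example):

# Standalone sketch of the preload mechanism; run_preloads is hypothetical.
import argparse


def preload(parser):
    # Same shape as the new preload.py above, with a hard-coded default.
    parser.add_argument("--lora-dir", type=str, default="models/Lora",
                        help="Path to directory with Lora networks.")


def run_preloads(parser, preload_fns):
    # The loader calls each extension's preload() with the shared parser,
    # so extension flags are registered before parse_args() runs.
    for fn in preload_fns:
        fn(parser)


parser = argparse.ArgumentParser()
run_preloads(parser, [preload])
cmd_opts = parser.parse_args(["--lora-dir", "/data/loras"])
print(cmd_opts.lora_dir)  # /data/loras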
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 65397890..54a80d36 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -1,3 +1,4 @@
+import json
 import os
 import lora
@@ -26,10 +27,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
                 "name": name,
                 "filename": path,
                 "preview": preview,
-                "prompt": f"<lora:{name}:1.0>",
+                "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                 "local_preview": path + ".png",
             }

     def allowed_directories_for_previews(self):
-        return [lora.lora_dir]
+        return [shared.cmd_opts.lora_dir]
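Note that the new "prompt" value is no longer a finished prompt string: json.dumps() quotes the fixed pieces as JavaScript string literals, and opts.extra_networks_default_multiplier is concatenated in the browser when the card is clicked. A quick demonstration of what the Python side actually emits (the network name is illustrative):

# Shows the JavaScript expression produced for the "prompt" field.
import json

name = "myLora"  # illustrative network name
prompt = json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">")
print(prompt)
# "<lora:myLora:" + opts.extra_networks_default_multiplier + ">"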
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 9a74b253..e8783bca 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm

 from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
 from swinir_model_arch import SwinIR as net
 from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
+            if state.interrupted or state.skipped:
+                break
+
             for w_idx in w_idx_list:
+                if state.interrupted or state.skipped:
+                    break
+
                 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                 out_patch = model(in_patch)
                 out_patch_mask = torch.ones_like(out_patch)
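The interrupt flag is checked at both loop levels because break only exits the innermost loop: without the outer check, an interrupt mid-row would still process every remaining row of tiles. A generic sketch of the pattern (State here is a stand-in for modules.shared.state, and run_tiles is illustrative):

# Generic sketch of the interrupt-aware nested tiling loop.
class State:
    interrupted = False
    skipped = False


state = State()


def run_tiles(h_idx_list, w_idx_list, process):
    for h_idx in h_idx_list:
        # Outer check: stop starting new rows once the user interrupts.
        if state.interrupted or state.skipped:
            break
        for w_idx in w_idx_list:
            # Inner check: abandon the current row promptly as well.
            if state.interrupted or state.skipped:
                break
            process(h_idx, w_idx)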