aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/extra_networks_lora.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-05 08:01:38 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-05 08:01:38 +0300
commitef1698fd6dbd6387341a1eeeded068ff1476ee50 (patch)
treeddaa0cf76e8cf95b93f63909a026ae3d5eab460a /extensions-builtin/Lora/extra_networks_lora.py
parent0fae47e97445df4e7de4d85538a80917fc2a2457 (diff)
parentc613416af375092f55b9bc8649c949e95d250c44 (diff)
Merge branch 'dev' into extra-networks-always-visible
Diffstat (limited to 'extensions-builtin/Lora/extra_networks_lora.py')
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py38
1 files changed, 26 insertions, 12 deletions
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 66ee9c85..ba2945c6 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -1,5 +1,5 @@
from modules import extra_networks, shared
-import lora
+import networks
class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -9,24 +9,38 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
- if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
+ if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
- multipliers = []
+ te_multipliers = []
+ unet_multipliers = []
+ dyn_dims = []
for params in params_list:
assert params.items
- names.append(params.items[0])
- multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+ names.append(params.positional[0])
- lora.load_loras(names, multipliers)
+ te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+ te_multiplier = float(params.named.get("te", te_multiplier))
+
+ unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
+ unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+ dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+ dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+ te_multipliers.append(te_multiplier)
+ unet_multipliers.append(unet_multiplier)
+ dyn_dims.append(dyn_dim)
+
+ networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
if shared.opts.lora_add_hashes_to_infotext:
- lora_hashes = []
- for item in lora.loaded_loras:
- shorthash = item.lora_on_disk.shorthash
+ network_hashes = []
+ for item in networks.loaded_networks:
+ shorthash = item.network_on_disk.shorthash
if not shorthash:
continue
@@ -36,10 +50,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
alias = alias.replace(":", "").replace(",", "")
- lora_hashes.append(f"{alias}: {shorthash}")
+ network_hashes.append(f"{alias}: {shorthash}")
- if lora_hashes:
- p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+ if network_hashes:
+ p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
def deactivate(self, p):
pass