aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-18 08:05:28 +0300
committerGitHub <noreply@github.com>2023-07-18 08:05:28 +0300
commit20c41364ccba1319e68e6b4a58f53f110c5d4828 (patch)
tree3c54c366b3641ea426764926df649ab0f326ebdf /extensions-builtin
parenta99d5708e6d603e8f7cfd1b8c6595f8026219ba0 (diff)
parent3d31caf4a53c4bb4469b72790b459eba7b251da9 (diff)
Merge pull request #11843 from KohakuBlueleaf/fix-lyco-support
Fix wrong key name in lokr module
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/network_lokr.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
index 920062e2..340acdab 100644
--- a/extensions-builtin/Lora/network_lokr.py
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -6,8 +6,8 @@ import network
class ModuleTypeLokr(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
- has_1 = "lokr_w1" in weights.w or ("lokr_w1a" in weights.w and "lokr_w1b" in weights.w)
- has_2 = "lokr_w2" in weights.w or ("lokr_w2a" in weights.w and "lokr_w2b" in weights.w)
+ has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
+ has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
if has_1 and has_2:
return NetworkModuleLokr(net, weights)
@@ -28,11 +28,11 @@ class NetworkModuleLokr(network.NetworkModule):
self.w1 = weights.w.get("lokr_w1")
self.w1a = weights.w.get("lokr_w1_a")
self.w1b = weights.w.get("lokr_w1_b")
- self.dim = self.w1b.shape[0] if self.w1b else self.dim
+ self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
self.w2 = weights.w.get("lokr_w2")
self.w2a = weights.w.get("lokr_w2_a")
self.w2b = weights.w.get("lokr_w2_b")
- self.dim = self.w2b.shape[0] if self.w2b else self.dim
+ self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
self.t2 = weights.w.get("lokr_t2")
def calc_updown(self, orig_weight):