aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/network_lyco.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-16 23:13:55 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-07-16 23:13:55 +0300
commitb75b004fe62826455f1aa77e849e7da13902cb17 (patch)
tree23aa9debf80fff6ef7fe9778e56df6a135065310 /extensions-builtin/Lora/network_lyco.py
parent7d26c479eebec03c2abb28f7b5226791688a7cea (diff)
lora extension rework to include other types of networks
Diffstat (limited to 'extensions-builtin/Lora/network_lyco.py')
-rw-r--r--extensions-builtin/Lora/network_lyco.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/network_lyco.py b/extensions-builtin/Lora/network_lyco.py
new file mode 100644
index 00000000..18a822fa
--- /dev/null
+++ b/extensions-builtin/Lora/network_lyco.py
@@ -0,0 +1,39 @@
+import torch
+
+import lyco_helpers
+import network
+from modules import devices
+
+
class NetworkModuleLyco(network.NetworkModule):
    """Shared base for LyCORIS-style network modules.

    Pulls the common per-module metadata (target weight shape, optional
    bias tensor, alpha/scale factors) out of the loaded weights and
    provides the common post-processing step for a computed weight delta.
    """

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        # Remember the shape of the weight we will eventually patch,
        # when the wrapped sd module actually carries one.
        sd_module = self.sd_module
        if hasattr(sd_module, 'weight'):
            self.shape = sd_module.weight.shape

        # Rank of the decomposition; left unset here — subclasses that
        # know their decomposition are expected to fill it in.
        self.dim = None
        self.bias = weights.w.get("bias")

        # Optional scalar factors stored alongside the weights.
        self.alpha = None
        self.scale = None
        if "alpha" in weights.w:
            self.alpha = weights.w["alpha"].item()
        if "scale" in weights.w:
            self.scale = weights.w["scale"].item()

    def finalize_updown(self, updown, orig_weight, output_shape):
        """Fold in the optional bias, coerce the delta to the target weight's
        shape, and apply the scale and the network-wide multiplier."""
        if self.bias is not None:
            # Add the bias in the bias's own shape, then restore the
            # expected output shape. The += is deliberate (in-place).
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        # Conv weights are 4D; make sure the delta matches.
        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        # When element counts line up, mirror the original weight's shape exactly.
        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        # Precedence: explicit scale, then alpha/dim ratio, then no scaling.
        if self.scale is not None:
            scale = self.scale
        elif self.dim is not None and self.alpha is not None:
            scale = self.alpha / self.dim
        else:
            scale = 1.0

        return updown * scale * self.network.multiplier