From bd4da4474bef5c9c1f690c62b971704ee73d2860 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sun, 13 Aug 2023 02:27:39 +0800
Subject: Add extra norm module into built-in lora ext

refer to LyCORIS 1.9.0.dev6
add new option and module for training norm layer
(Which is reported to be good for style)
---
 extensions-builtin/Lora/network_norm.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)
 create mode 100644 extensions-builtin/Lora/network_norm.py

diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
new file mode 100644
index 00000000..dab8b684
--- /dev/null
+++ b/extensions-builtin/Lora/network_norm.py
@@ -0,0 +1,29 @@
+import network
+
+
+class ModuleTypeNorm(network.ModuleType):
+    def create_module(self, net: network.Network, weights: network.NetworkWeights):
+        if all(x in weights.w for x in ["w_norm", "b_norm"]):
+            return NetworkModuleNorm(net, weights)
+
+        return None
+
+
+class NetworkModuleNorm(network.NetworkModule):
+    def __init__(self, net: network.Network, weights: network.NetworkWeights):
+        super().__init__(net, weights)
+        print("NetworkModuleNorm")
+
+        self.w_norm = weights.w.get("w_norm")
+        self.b_norm = weights.w.get("b_norm")
+
+    def calc_updown(self, orig_weight):
+        output_shape = self.w_norm.shape
+        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+
+        if self.b_norm is not None:
+            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        else:
+            ex_bias = None
+
+        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
--
cgit v1.2.1

From 5881dcb8873b3f87b9c6545e9cb8d1d77023f4fe Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sun, 13 Aug 2023 02:36:02 +0800
Subject: remove debug print

---
 extensions-builtin/Lora/network_norm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
index dab8b684..ce450158 100644
--- a/extensions-builtin/Lora/network_norm.py
+++ b/extensions-builtin/Lora/network_norm.py
@@ -12,7 +12,6 @@ class ModuleTypeNorm(network.ModuleType):
 class NetworkModuleNorm(network.NetworkModule):
     def __init__(self, net: network.Network, weights: network.NetworkWeights):
         super().__init__(net, weights)
-        print("NetworkModuleNorm")
 
         self.w_norm = weights.w.get("w_norm")
         self.b_norm = weights.w.get("b_norm")
--
cgit v1.2.1
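
Note on how these weights get applied: the patch itself only builds the (updown, ex_bias) pair via finalize_updown; the merge into the model happens elsewhere in the extension. Below is a minimal standalone sketch of that merge step, assuming the usual LoRA merge-at-load pattern where the delta is added onto the frozen parameter. The function merge_norm_delta and the layer width 320 are hypothetical illustrations, not part of this patch or the webui codebase.

# Hypothetical sketch (not from this patch): fold a norm module's
# w_norm/b_norm deltas into a frozen LayerNorm's scale and shift.
from typing import Optional

import torch
import torch.nn as nn


def merge_norm_delta(layer: nn.LayerNorm, w_norm: torch.Tensor, b_norm: Optional[torch.Tensor]) -> None:
    # Move deltas to the layer's device/dtype, mirroring the .to(...) calls
    # in calc_updown above, then add them onto the frozen parameters.
    with torch.no_grad():
        layer.weight += w_norm.to(layer.weight.device, dtype=layer.weight.dtype)
        if b_norm is not None and layer.bias is not None:
            layer.bias += b_norm.to(layer.bias.device, dtype=layer.bias.dtype)


# Usage: 320 is an arbitrary example width; real deltas would come from
# the "w_norm"/"b_norm" keys this module reads out of weights.w.
ln = nn.LayerNorm(320)
merge_norm_delta(ln, w_norm=torch.full((320,), 0.1), b_norm=torch.zeros(320))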