aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-12-14 01:43:24 +0800
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-12-14 01:43:24 +0800
commit265bc26c21264d63956e8f30f1ce31dec917fc76 (patch)
treed0372003336bff507a9548e03873f1ddef3ae242 /extensions-builtin
parent735c9e8059384d4f640e5582413c30871f83eac5 (diff)
Use self.scale instead of custom finalize
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/network_oft.py20
1 files changed, 2 insertions, 18 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 44465f7a..e3ae61a2 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule):
self.lin_module = None
self.org_module: list[torch.Module] = [self.sd_module]
+ self.scale = 1.0
+
# kohya-ss
if "oft_blocks" in weights.w.keys():
self.is_kohya = True
@@ -78,21 +80,3 @@ class NetworkModuleOFT(network.NetworkModule):
print(torch.norm(updown))
output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape)
-
- def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
- if self.bias is not None:
- updown = updown.reshape(self.bias.shape)
- updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
- updown = updown.reshape(output_shape)
-
- if len(output_shape) == 4:
- updown = updown.reshape(output_shape)
-
- if orig_weight.size().numel() == updown.size().numel():
- updown = updown.reshape(orig_weight.shape)
-
- if ex_bias is not None:
- ex_bias = ex_bias * self.multiplier()
-
- # Ignore calc_scale, which is not used in OFT.
- return updown * self.multiplier(), ex_bias