author     Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2024-01-05 16:32:19 +0800
committer  Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2024-01-05 16:32:19 +0800
commit     18ca987c92f52690daec43a6c67363c341bb6008 (patch)
tree       09967883224072cf4a6fe866d19b385541d24d88 /extensions-builtin
parent     a06dab8d7aaeca1900acd565df7667087e8f064c (diff)
Add general forward method for all modules.
Diffstat (limited to 'extensions-builtin')
-rw-r--r--  extensions-builtin/Lora/network.py   34
-rw-r--r--  extensions-builtin/Lora/networks.py  12
2 files changed, 39 insertions, 7 deletions
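
The change pairs each supported layer type with its torch.nn.functional counterpart plus the keyword arguments needed to replay it, so one shared forward can serve every module class. Below is a small standalone sketch of that dispatch idea (not part of the commit; pick_op is an illustrative name), checking that replaying a layer through the chosen functional op reproduces the module's own output:

import torch
import torch.nn as nn
import torch.nn.functional as F

def pick_op(layer):
    # Mirrors the isinstance chain the patch adds to NetworkModule.__init__.
    if isinstance(layer, nn.Conv2d):
        return F.conv2d, {'stride': layer.stride, 'padding': layer.padding}
    if isinstance(layer, nn.Linear):
        return F.linear, {}
    if isinstance(layer, nn.LayerNorm):
        return F.layer_norm, {'normalized_shape': layer.normalized_shape, 'eps': layer.eps}
    if isinstance(layer, nn.GroupNorm):
        return F.group_norm, {'num_groups': layer.num_groups, 'eps': layer.eps}
    return None, {}

layer = nn.Linear(8, 4)
op, extra_kwargs = pick_op(layer)
x = torch.randn(2, 8)
# Running the layer's own weights through the functional op matches the module's forward pass.
assert torch.allclose(op(x, weight=layer.weight, bias=layer.bias, **extra_kwargs), layer(x))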
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index a62e5eff..f9b571b5 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -3,6 +3,10 @@ import os
 from collections import namedtuple
 import enum
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +119,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
 
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -155,5 +182,10 @@ class NetworkModule:
         raise NotImplementedError()
 
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
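
With the ops/extra_kwargs mapping above in place, a concrete network type only needs to implement calc_updown(); the inherited forward runs the input through the computed delta weights and adds the result to the original layer output y. A rough sketch of what that amounts to for a Linear layer (lora_down, lora_up and scale are illustrative stand-ins for what a specific network implementation would produce):

import torch
import torch.nn.functional as F

x = torch.randn(2, 16)            # layer input
y = torch.randn(2, 8)             # output of the original, unpatched layer
lora_down = torch.randn(4, 16)    # rank-4 down projection
lora_up = torch.randn(8, 4)       # up projection back to the output width
scale = 0.5                       # e.g. alpha / dim

updown = scale * (lora_up @ lora_down)           # roughly what calc_updown would return here
out = y + F.linear(x, weight=updown, bias=None)  # the generic forward: y + ops(x, weight=updown, ...)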
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 72ebd624..32e10b62 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     self.network_current_names = wanted_names
 
 
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
 
     Stacking many loras this way results in big performance degradation.
     """
     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)
 
     input = devices.cond_cast_unet(input)
 
-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)
 
-    y = original_forward(module, input)
+    y = original_forward(org_module, input)
 
-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None:
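
The networks.py side is a rename: network_forward's first parameter becomes org_module because the loop body rebinds the name module (module = lora.modules.get(...)), which would otherwise shadow the layer that was passed in for any use after that point. A minimal, generic illustration of that shadowing (not the webui code):

def forward_with_shadowing(module, networks):
    for net in networks:
        module = net              # rebinding the parameter name drops the original layer
    return module                 # now the last network, not the layer that was passed in

assert forward_with_shadowing("conv_layer", ["net_a", "net_b"]) == "net_b"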