Diffstat (limited to 'extensions-builtin/Lora/network.py')
-rw-r--r--  extensions-builtin/Lora/network.py  35
1 file changed, 33 insertions(+), 2 deletions(-)
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 6021fd8d..b8fd9194 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -3,6 +3,9 @@ import os
 from collections import namedtuple
 import enum
 
+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
 
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -137,7 +163,7 @@ class NetworkModule:
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
             updown = updown.reshape(output_shape)
 
         if len(output_shape) == 4:
@@ -155,5 +181,10 @@ class NetworkModule:
         raise NotImplementedError()
 
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
 
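
For context, the new forward(x, y) adds the network's contribution on top of the wrapped layer's original output y, applying the delta weight from calc_updown with the functional op selected above for the layer type. A minimal standalone sketch of that dispatch pattern (the DeltaForward class and the zero delta standing in for calc_updown's result are hypothetical, not part of the webui codebase):

# Sketch of the dispatch pattern introduced by this commit (hypothetical names):
# a precomputed delta weight is applied with the functional op matching the
# wrapped layer, and the result is added to the layer's original output.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeltaForward:
    def __init__(self, module, delta):
        self.module = module
        self.delta = delta          # stand-in for what calc_updown() would return
        self.ops = None
        self.extra_kwargs = {}
        if isinstance(module, nn.Conv2d):
            self.ops = F.conv2d
            self.extra_kwargs = {'stride': module.stride, 'padding': module.padding}
        elif isinstance(module, nn.Linear):
            self.ops = F.linear

    def forward(self, x, y):
        # y is the wrapped module's original output for x
        if self.ops is None:
            raise NotImplementedError()
        return y + self.ops(x, weight=self.delta, bias=None, **self.extra_kwargs)


linear = nn.Linear(4, 2)
delta = torch.zeros_like(linear.weight)   # zero delta -> output unchanged
net = DeltaForward(linear, delta)
x = torch.randn(3, 4)
out = net.forward(x, linear(x))           # equals linear(x) while delta is zero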