aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
authorv0xie <28695009+v0xie@users.noreply.github.com>2023-11-04 14:56:47 -0700
committerv0xie <28695009+v0xie@users.noreply.github.com>2023-11-04 14:56:47 -0700
commitbbf00a96afb2215f13cc72a7908225ae300c423d (patch)
tree3a03875d3c55d10528e7891d6db515671b7a8c2f /extensions-builtin
parent329c8bacce706811776e1c1c6a0d39b46886a268 (diff)
refactor: remove unused function
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/network_oft.py47
1 files changed, 0 insertions, 47 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index e4aa082b..93402bb2 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -2,7 +2,6 @@ import torch
import network
from lyco_helpers import factorization
from einops import rearrange
-from modules import devices
class ModuleTypeOFT(network.ModuleType):
@@ -54,58 +53,12 @@ class NetworkModuleOFT(network.NetworkModule):
raise ValueError("sd_module must be Linear or Conv")
if self.is_kohya:
- #self.num_blocks = self.dim
- #self.block_size = self.out_dim // self.num_blocks
- #self.block_size = self.dim
- #self.num_blocks = self.out_dim // self.block_size
self.constraint = self.alpha * self.out_dim
self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
else:
self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
- if is_other_linear:
- self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True)
-
-
- def create_module(self, weights, key, none_ok=False):
- weight = weights.get(key)
-
- if weight is None and none_ok:
- return None
-
- is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
- is_conv = type(self.sd_module) in [torch.nn.Conv2d]
-
- if is_linear:
- weight = weight.reshape(weight.shape[0], -1)
- module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif is_conv and key == "lora_down.weight" or key == "dyn_up":
- if len(weight.shape) == 2:
- weight = weight.reshape(weight.shape[0], -1, 1, 1)
-
- if weight.shape[2] != 1 or weight.shape[3] != 1:
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
- else:
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
- elif is_conv and key == "lora_mid.weight":
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
- elif is_conv and key == "lora_up.weight" or key == "dyn_down":
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
- else:
- raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
-
- with torch.no_grad():
- if weight.shape != module.weight.shape:
- weight = weight.reshape(module.weight.shape)
- module.weight.copy_(weight)
-
- module.to(device=devices.cpu, dtype=devices.dtype)
- module.weight.requires_grad_(False)
-
- return module
-
-
def merge_weight(self, R_weight, org_weight):
R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4: