aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/networks.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 0170dbfb..d22ed843 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if module is not None and hasattr(self, 'weight'):
try:
with torch.no_grad():
- updown, ex_bias = module.calc_updown(self.weight)
+ if getattr(self, 'fp16_weight', None) is None:
+ weight = self.weight
+ bias = self.bias
+ else:
+ weight = self.fp16_weight.clone().to(self.weight.device)
+ bias = getattr(self, 'fp16_bias', None)
+ if bias is not None:
+ bias = bias.clone().to(self.bias.device)
+ updown, ex_bias = module.calc_updown(weight)
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ if len(weight.shape) == 4 and weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
+ self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
- self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype))
+ self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1