aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/networks.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r--extensions-builtin/Lora/networks.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 60d8dec4..8ea4ea60 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -381,12 +381,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight += updown
+ self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
- self.bias = torch.nn.Parameter(ex_bias)
+ self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
- self.bias += ex_bias
+ self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1