aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/Lora/scripts')
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py41
1 files changed, 7 insertions, 34 deletions
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 6ab8b6e7..ef23968c 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,57 +1,30 @@
import re
-import torch
import gradio as gr
from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
+import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+
def unload():
- torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
- torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
- torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
- torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
- torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
- torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+ networks.originals.undo()
def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
- extra_network = extra_networks_lora.ExtraNetworkLora()
- extra_networks.register_extra_network(extra_network)
- extra_networks.register_extra_network_alias(extra_network, "lyco")
-
-
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
- torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
- torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
- torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
- torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
- torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
+ networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
+ extra_networks.register_extra_network(networks.extra_network_lora)
+ extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
- torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)