aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/scripts/lora_script.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-21 16:15:53 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-01-21 16:15:53 +0300
commit855b9e3d1c5a1bd8c2d815d38a38bc7c410be5a8 (patch)
tree6c978dd9650c65575ea26d3e86480fc383923076 /extensions-builtin/Lora/scripts/lora_script.py
parentcbfb4632585415dc914aff8c44869d792fd64c24 (diff)
Lora support!
update readme to reflect some recent changes
Diffstat (limited to 'extensions-builtin/Lora/scripts/lora_script.py')
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
new file mode 100644
index 00000000..60b9eb64
--- /dev/null
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -0,0 +1,30 @@
+import torch
+
+import lora
+import extra_networks_lora
+import ui_extra_networks_lora
+from modules import script_callbacks, ui_extra_networks, extra_networks
+
+
+def unload():
+ torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+ torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+
+
+def before_ui():
+ ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
+ extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
+
+
+if not hasattr(torch.nn, 'Linear_forward_before_lora'):
+ torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
+
+if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
+ torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+
+torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+
+script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
+script_callbacks.on_script_unloaded(unload)
+script_callbacks.on_before_ui(before_ui)