aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-15 19:23:27 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-15 19:23:40 +0300
commitf01682ee01e81e8ef84fd6fffe8f7aa17233285d (patch)
tree80f62099e6af5f77c7df8c092c37c71ed24750d9
parent7327be97aa9beeae881bf4649a56792bd284efd5 (diff)
store patches for Lora in a specialized module
-rw-r--r--extensions-builtin/Lora/lora_patches.py31
-rw-r--r--extensions-builtin/Lora/networks.py32
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py52
-rw-r--r--modules/patches.py64
4 files changed, 118 insertions, 61 deletions
diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py
new file mode 100644
index 00000000..b394d8e9
--- /dev/null
+++ b/extensions-builtin/Lora/lora_patches.py
@@ -0,0 +1,31 @@
+import torch
+
+import networks
+from modules import patches
+
+
+class LoraPatches:
+ def __init__(self):
+ self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
+ self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+ self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
+ self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+ self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
+ self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+ self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
+ self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+ self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
+ self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+
+ def undo(self):
+ self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
+ self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
+ self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
+ self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
+ self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
+ self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
+ self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
+ self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
+ self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
+ self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
+
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 22fdff4a..9fca36b6 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -2,6 +2,7 @@ import logging
import os
import re
+import lora_patches
import network
import network_lora
import network_hada
@@ -418,74 +419,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
def network_Linear_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.Linear_forward_before_network)
+ return network_forward(self, input, originals.Linear_forward)
network_apply_weights(self)
- return torch.nn.Linear_forward_before_network(self, input)
+ return originals.Linear_forward(self, input)
def network_Linear_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.Linear_load_state_dict(self, *args, **kwargs)
def network_Conv2d_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+ return network_forward(self, input, originals.Conv2d_forward)
network_apply_weights(self)
- return torch.nn.Conv2d_forward_before_network(self, input)
+ return originals.Conv2d_forward(self, input)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.Conv2d_load_state_dict(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+ return network_forward(self, input, originals.GroupNorm_forward)
network_apply_weights(self)
- return torch.nn.GroupNorm_forward_before_network(self, input)
+ return originals.GroupNorm_forward(self, input)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+ return network_forward(self, input, originals.LayerNorm_forward)
network_apply_weights(self)
- return torch.nn.LayerNorm_forward_before_network(self, input)
+ return originals.LayerNorm_forward(self, input)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
- return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+ return originals.MultiheadAttention_forward(self, *args, **kwargs)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
def list_available_networks():
@@ -552,6 +553,9 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)
+
+originals: lora_patches.LoraPatches = None
+
extra_network_lora = None
available_networks = {}
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 4c6e774a..546fb55e 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -7,17 +7,14 @@ from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
+import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared, patches
+
def unload():
- torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
- torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
- torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
- torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
- torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
- torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+ networks.originals.undo()
def before_ui():
@@ -28,46 +25,7 @@ def before_ui():
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
- torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
- torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
- torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
- torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
- torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
-
-if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
- torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
- torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
-
-if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
- torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
- torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
-
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
- torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
-torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
-torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
-torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
diff --git a/modules/patches.py b/modules/patches.py
new file mode 100644
index 00000000..348235e7
--- /dev/null
+++ b/modules/patches.py
@@ -0,0 +1,64 @@
+from collections import defaultdict
+
+
+def patch(key, obj, field, replacement):
+ """Replaces a function in a module or a class.
+
+ Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
+ If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+ replacement: the new function
+
+ Returns:
+ the original function
+ """
+
+ patch_key = (obj, field)
+ if patch_key in originals[key]:
+ raise RuntimeError(f"patch for {field} is already applied")
+
+ original_func = getattr(obj, field)
+ originals[key][patch_key] = original_func
+
+ setattr(obj, field, replacement)
+
+ return original_func
+
+
+def undo(key, obj, field):
+ """Undoes the peplacement by the patch().
+
+ If the function is not replaced, raises an exception.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+
+ Returns:
+ Always None
+ """
+
+ patch_key = (obj, field)
+
+ if patch_key not in originals[key]:
+ raise RuntimeError(f"there is no patch for {field} to undo")
+
+ original_func = originals[key].pop(patch_key)
+ setattr(obj, field, original_func)
+
+ return None
+
+
+def original(key, obj, field):
+ """Returns the original function for the patch created by the patch() function"""
+ patch_key = (obj, field)
+
+ return originals[key].get(patch_key, None)
+
+
+originals = defaultdict(dict)