From 80b26d2a69617b75d2d01c1e6b7d11445815ed4d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 25 Mar 2023 23:06:33 +0300 Subject: apply Lora by altering layer's weights instead of adding more calculations in forward() --- extensions-builtin/Lora/scripts/lora_script.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora/scripts') diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 2e860160..dc329e81 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -9,7 +9,9 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora def before_ui(): @@ -20,11 +22,19 @@ def before_ui(): if not hasattr(torch.nn, 'Linear_forward_before_lora'): torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): + torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict + if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward +if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): + torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict + torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict torch.nn.Conv2d.forward = lora.lora_Conv2d_forward +torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) @@ -33,6 +43,4 @@ script_callbacks.on_before_ui(before_ui) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), - "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), - })) -- cgit v1.2.1 From 650ddc9dd3c1d126221682be8270f7fba1b5b6ce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 26 Mar 2023 10:44:20 +0300 Subject: Lora support for SD2 --- extensions-builtin/Lora/scripts/lora_script.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'extensions-builtin/Lora/scripts') diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index dc329e81..0adab225 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -12,6 +12,8 @@ def unload(): torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora + torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora + torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora def before_ui(): @@ -31,10 +33,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'): torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict +if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'): + torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward + +if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'): + torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict + torch.nn.Linear.forward = lora.lora_Linear_forward torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict torch.nn.Conv2d.forward = lora.lora_Conv2d_forward torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict +torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward +torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) -- cgit v1.2.1