aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py8
-rw-r--r--extensions-builtin/Lora/lora.py5
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py10
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py3
4 files changed, 22 insertions, 4 deletions
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 8f2e753e..6be6ef73 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -1,4 +1,4 @@
-from modules import extra_networks
+from modules import extra_networks, shared
import lora
class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -6,6 +6,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
super().__init__('lora')
def activate(self, p, params_list):
+ additional = shared.opts.sd_lora
+
+ if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+ p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
+
names = []
multipliers = []
for params in params_list:
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 137e58f7..cb8f1d36 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -166,7 +166,10 @@ def lora_forward(module, input, res):
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+ res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ else:
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 60b9eb64..2e860160 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,9 +1,10 @@
import torch
+import gradio as gr
import lora
import extra_networks_lora
import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
@@ -28,3 +29,10 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
+
+
+shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+ "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+
+}))
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 54a80d36..22cabcb0 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -20,13 +20,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
preview = None
for file in previews:
if os.path.isfile(file):
- preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ preview = self.link_preview(file)
break
yield {
"name": name,
"filename": path,
"preview": preview,
+ "search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": path + ".png",
}