aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/networks.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r--extensions-builtin/Lora/networks.py443
1 files changed, 443 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
new file mode 100644
index 00000000..5b0ddfb6
--- /dev/null
+++ b/extensions-builtin/Lora/networks.py
@@ -0,0 +1,443 @@
+import os
+import re
+
+import network
+import network_lora
+import network_hada
+
+import torch
+from typing import Union
+
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+
+module_types = [
+ network_lora.ModuleTypeLora(),
+ network_hada.ModuleTypeHada(),
+]
+
+
+re_digits = re.compile(r"\d+")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+ "attentions": {},
+ "resnets": {
+ "conv1": "in_layers_2",
+ "conv2": "out_layers_3",
+ "time_emb_proj": "emb_layers_1",
+ "conv_shortcut": "skip_connection",
+ }
+}
+
+
+def convert_diffusers_name_to_compvis(key, is_sd2):
+ def match(match_list, regex_text):
+ regex = re_compiled.get(regex_text)
+ if regex is None:
+ regex = re.compile(regex_text)
+ re_compiled[regex_text] = regex
+
+ r = re.match(regex, key)
+ if not r:
+ return False
+
+ match_list.clear()
+ match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+ return True
+
+ m = []
+
+ if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+ if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+ return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
+
+ if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+ if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+ return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
+
+ if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+ return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
+
+ if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
+ if is_sd2:
+ if 'mlp_fc1' in m[1]:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+ elif 'mlp_fc2' in m[1]:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+ else:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
+ return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+ if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
+ if 'mlp_fc1' in m[1]:
+ return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+ elif 'mlp_fc2' in m[1]:
+ return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+ else:
+ return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
+ return key
+
+
+def assign_network_names_to_compvis_modules(sd_model):
+ network_layer_mapping = {}
+
+ if shared.sd_model.is_sdxl:
+ for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
+ if not hasattr(embedder, 'wrapped'):
+ continue
+
+ for name, module in embedder.wrapped.named_modules():
+ network_name = f'{i}_{name.replace(".", "_")}'
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ else:
+ for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+ network_name = name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+
+ for name, module in shared.sd_model.model.named_modules():
+ network_name = name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+
+ sd_model.network_layer_mapping = network_layer_mapping
+
+
+def load_network(name, network_on_disk):
+ net = network.Network(name, network_on_disk)
+ net.mtime = os.path.getmtime(network_on_disk.filename)
+
+ sd = sd_models.read_state_dict(network_on_disk.filename)
+
+ # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
+ if not hasattr(shared.sd_model, 'network_layer_mapping'):
+ assign_network_names_to_compvis_modules(shared.sd_model)
+
+ keys_failed_to_match = {}
+ is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+
+ matched_networks = {}
+
+ for key_network, weight in sd.items():
+ key_network_without_network_parts, network_part = key_network.split(".", 1)
+
+ key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+ if sd_module is None:
+ m = re_x_proj.match(key)
+ if m:
+ sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
+
+ # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
+ if sd_module is None and "lora_unet" in key_network_without_network_parts:
+ key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
+ key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+ if sd_module is None:
+ keys_failed_to_match[key_network] = key
+ continue
+
+ if key not in matched_networks:
+ matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
+
+ matched_networks[key].w[network_part] = weight
+
+ for key, weights in matched_networks.items():
+ net_module = None
+ for nettype in module_types:
+ net_module = nettype.create_module(net, weights)
+ if net_module is not None:
+ break
+
+ if net_module is None:
+ raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")
+
+ net.modules[key] = net_module
+
+ if keys_failed_to_match:
+ print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
+
+ return net
+
+
+def load_networks(names, multipliers=None):
+ already_loaded = {}
+
+ for net in loaded_networks:
+ if net.name in names:
+ already_loaded[net.name] = net
+
+ loaded_networks.clear()
+
+ networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+ if any(x is None for x in networks_on_disk):
+ list_available_networks()
+
+ networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+
+ failed_to_load_networks = []
+
+ for i, name in enumerate(names):
+ net = already_loaded.get(name, None)
+
+ network_on_disk = networks_on_disk[i]
+
+ if network_on_disk is not None:
+ if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
+ try:
+ net = load_network(name, network_on_disk)
+ except Exception as e:
+ errors.display(e, f"loading network {network_on_disk.filename}")
+ continue
+
+ net.mentioned_name = name
+
+ network_on_disk.read_hash()
+
+ if net is None:
+ failed_to_load_networks.append(name)
+ print(f"Couldn't find network with name {name}")
+ continue
+
+ net.multiplier = multipliers[i] if multipliers else 1.0
+ loaded_networks.append(net)
+
+ if failed_to_load_networks:
+ sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
+
+
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+ weights_backup = getattr(self, "network_weights_backup", None)
+
+ if weights_backup is None:
+ return
+
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+
+
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+ """
+ Applies the currently selected set of networks to the weights of torch layer self.
+ If weights already have this particular set of networks applied, does nothing.
+ If not, restores orginal weights from backup and alters weights according to networks.
+ """
+
+ network_layer_name = getattr(self, 'network_layer_name', None)
+ if network_layer_name is None:
+ return
+
+ current_names = getattr(self, "network_current_names", ())
+ wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks)
+
+ weights_backup = getattr(self, "network_weights_backup", None)
+ if weights_backup is None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+ else:
+ weights_backup = self.weight.to(devices.cpu, copy=True)
+
+ self.network_weights_backup = weights_backup
+
+ if current_names != wanted_names:
+ network_restore_weights_from_backup(self)
+
+ for net in loaded_networks:
+ module = net.modules.get(network_layer_name, None)
+ if module is not None and hasattr(self, 'weight'):
+ with torch.no_grad():
+ updown = module.calc_updown(self.weight)
+
+ if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+
+ self.weight += updown
+
+ module_q = net.modules.get(network_layer_name + "_q_proj", None)
+ module_k = net.modules.get(network_layer_name + "_k_proj", None)
+ module_v = net.modules.get(network_layer_name + "_v_proj", None)
+ module_out = net.modules.get(network_layer_name + "_out_proj", None)
+
+ if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
+ with torch.no_grad():
+ updown_q = module_q.calc_updown(self.in_proj_weight)
+ updown_k = module_k.calc_updown(self.in_proj_weight)
+ updown_v = module_v.calc_updown(self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += module_out.calc_updown(self.out_proj.weight)
+ continue
+
+ if module is None:
+ continue
+
+ print(f'failed to calculate network weights for layer {network_layer_name}')
+
+ self.network_current_names = wanted_names
+
+
+def network_forward(module, input, original_forward):
+ """
+ Old way of applying Lora by executing operations during layer's forward.
+ Stacking many loras this way results in big performance degradation.
+ """
+
+ if len(loaded_networks) == 0:
+ return original_forward(module, input)
+
+ input = devices.cond_cast_unet(input)
+
+ network_restore_weights_from_backup(module)
+ network_reset_cached_weight(module)
+
+ y = original_forward(module, input)
+
+ network_layer_name = getattr(module, 'network_layer_name', None)
+ for lora in loaded_networks:
+ module = lora.modules.get(network_layer_name, None)
+ if module is None:
+ continue
+
+ y = module.forward(y, input)
+
+ return y
+
+
+def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
+ self.network_current_names = ()
+ self.network_weights_backup = None
+
+
+def network_Linear_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, torch.nn.Linear_forward_before_network)
+
+ network_apply_weights(self)
+
+ return torch.nn.Linear_forward_before_network(self, input)
+
+
+def network_Linear_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_Conv2d_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+
+ network_apply_weights(self)
+
+ return torch.nn.Conv2d_forward_before_network(self, input)
+
+
+def network_Conv2d_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_MultiheadAttention_forward(self, *args, **kwargs):
+ network_apply_weights(self)
+
+ return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+
+
+def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def list_available_networks():
+ available_networks.clear()
+ available_network_aliases.clear()
+ forbidden_network_aliases.clear()
+ available_network_hash_lookup.clear()
+ forbidden_network_aliases.update({"none": 1, "Addams": 1})
+
+ os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
+
+ candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
+ for filename in candidates:
+ if os.path.isdir(filename):
+ continue
+
+ name = os.path.splitext(os.path.basename(filename))[0]
+ try:
+ entry = network.NetworkOnDisk(name, filename)
+ except OSError: # should catch FileNotFoundError and PermissionError etc.
+ errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
+ continue
+
+ available_networks[name] = entry
+
+ if entry.alias in available_network_aliases:
+ forbidden_network_aliases[entry.alias.lower()] = 1
+
+ available_network_aliases[name] = entry
+ available_network_aliases[entry.alias] = entry
+
+
+re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+
+
+def infotext_pasted(infotext, params):
+ if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+ return # if the other extension is active, it will handle those fields, no need to do anything
+
+ added = []
+
+ for k in params:
+ if not k.startswith("AddNet Model "):
+ continue
+
+ num = k[13:]
+
+ if params.get("AddNet Module " + num) != "LoRA":
+ continue
+
+ name = params.get("AddNet Model " + num)
+ if name is None:
+ continue
+
+ m = re_network_name.match(name)
+ if m:
+ name = m.group(1)
+
+ multiplier = params.get("AddNet Weight A " + num, "1.0")
+
+ added.append(f"<lora:{name}:{multiplier}>")
+
+ if added:
+ params["Prompt"] += "\n" + "".join(added)
+
+
+available_networks = {}
+available_network_aliases = {}
+loaded_networks = []
+available_network_hash_lookup = {}
+forbidden_network_aliases = {}
+
+list_available_networks()