aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py10
-rw-r--r--extensions-builtin/Lora/lora_logger.py33
-rw-r--r--extensions-builtin/Lora/lora_patches.py31
-rw-r--r--extensions-builtin/Lora/lyco_helpers.py47
-rw-r--r--extensions-builtin/Lora/network.py8
-rw-r--r--extensions-builtin/Lora/network_full.py7
-rw-r--r--extensions-builtin/Lora/network_glora.py33
-rw-r--r--extensions-builtin/Lora/network_norm.py28
-rw-r--r--extensions-builtin/Lora/network_oft.py97
-rw-r--r--extensions-builtin/Lora/networks.py224
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py41
-rw-r--r--extensions-builtin/Lora/ui_edit_user_metadata.py1
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py10
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js194
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/style.css3
-rw-r--r--extensions-builtin/extra-options-section/scripts/extra_options_section.py12
-rw-r--r--extensions-builtin/hypertile/hypertile.py348
-rw-r--r--extensions-builtin/hypertile/scripts/hypertile_script.py73
-rw-r--r--extensions-builtin/mobile/javascript/mobile.js8
19 files changed, 1103 insertions, 105 deletions
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index ba2945c6..005ff32c 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
+ self.errors = {}
+ """mapping of network names to the number of errors the network had during operation"""
+
def activate(self, p, params_list):
additional = shared.opts.sd_lora
+ self.errors.clear()
+
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
def deactivate(self, p):
- pass
+ if self.errors:
+ p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
+
+ self.errors.clear()
diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py
new file mode 100644
index 00000000..d51de297
--- /dev/null
+++ b/extensions-builtin/Lora/lora_logger.py
@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+ COLORS = {
+ "DEBUG": "\033[0;36m", # CYAN
+ "INFO": "\033[0;32m", # GREEN
+ "WARNING": "\033[0;33m", # YELLOW
+ "ERROR": "\033[0;31m", # RED
+ "CRITICAL": "\033[0;37;41m", # WHITE ON RED
+ "RESET": "\033[0m", # RESET COLOR
+ }
+
+ def format(self, record):
+ colored_record = copy.copy(record)
+ levelname = colored_record.levelname
+ seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+ colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+ return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(
+ ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+ )
+ logger.addHandler(handler)
diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py
new file mode 100644
index 00000000..b394d8e9
--- /dev/null
+++ b/extensions-builtin/Lora/lora_patches.py
@@ -0,0 +1,31 @@
+import torch
+
+import networks
+from modules import patches
+
+
+class LoraPatches:
+ def __init__(self):
+ self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
+ self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+ self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
+ self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+ self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
+ self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+ self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
+ self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+ self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
+ self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+
+ def undo(self):
+ self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
+ self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
+ self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
+ self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
+ self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
+ self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
+ self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
+ self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
+ self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
+ self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
+
diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
index 279b34bc..1679a0ce 100644
--- a/extensions-builtin/Lora/lyco_helpers.py
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+ '''
+ return a tuple of two value of input dimension decomposed by the number closest to factor
+ second value is higher or equal than first value.
+
+ In LoRA with Kroneckor Product, first value is a value for weight scale.
+ secon value is a value for weight.
+
+ Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
+
+ examples)
+ factor
+ -1 2 4 8 16 ...
+ 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
+ 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
+ 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
+ 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
+ 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
+ 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
+ '''
+
+ if factor > 0 and (dimension % factor) == 0:
+ m = factor
+ n = dimension // factor
+ if m > n:
+ n, m = m, n
+ return m, n
+ if factor < 0:
+ factor = dimension
+ m, n = 1, dimension
+ length = m + n
+ while m<n:
+ new_m = m + 1
+ while dimension%new_m != 0:
+ new_m += 1
+ new_n = dimension // new_m
+ if new_m + new_n > length or new_m>factor:
+ break
+ else:
+ m, n = new_m, new_n
+ if m > n:
+ n, m = m, n
+ return m, n
+
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 0a18d69e..6021fd8d 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -93,6 +93,7 @@ class Network: # LoraModule
self.unet_multiplier = 1.0
self.dyn_dim = None
self.modules = {}
+ self.bundle_embeddings = {}
self.mtime = None
self.mentioned_name = None
@@ -133,7 +134,7 @@ class NetworkModule:
return 1.0
- def finalize_updown(self, updown, orig_weight, output_shape):
+ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -145,7 +146,10 @@ class NetworkModule:
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
- return updown * self.calc_scale() * self.multiplier()
+ if ex_bias is not None:
+ ex_bias = ex_bias * self.multiplier()
+
+ return updown * self.calc_scale() * self.multiplier(), ex_bias
def calc_updown(self, target):
raise NotImplementedError()
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
index 109b4c2c..bf6930e9 100644
--- a/extensions-builtin/Lora/network_full.py
+++ b/extensions-builtin/Lora/network_full.py
@@ -14,9 +14,14 @@ class NetworkModuleFull(network.NetworkModule):
super().__init__(net, weights)
self.weight = weights.w.get("diff")
+ self.ex_bias = weights.w.get("diff_b")
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ if self.ex_bias is not None:
+ ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ else:
+ ex_bias = None
- return self.finalize_updown(updown, orig_weight, output_shape)
+ return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
new file mode 100644
index 00000000..492d4870
--- /dev/null
+++ b/extensions-builtin/Lora/network_glora.py
@@ -0,0 +1,33 @@
+
+import network
+
+class ModuleTypeGLora(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
+ return NetworkModuleGLora(net, weights)
+
+ return None
+
+# adapted from https://github.com/KohakuBlueleaf/LyCORIS
+class NetworkModuleGLora(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+
+ self.w1a = weights.w["a1.weight"]
+ self.w1b = weights.w["b1.weight"]
+ self.w2a = weights.w["a2.weight"]
+ self.w2b = weights.w["b2.weight"]
+
+ def calc_updown(self, orig_weight):
+ w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ output_shape = [w1a.size(0), w1b.size(1)]
+ updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
new file mode 100644
index 00000000..ce450158
--- /dev/null
+++ b/extensions-builtin/Lora/network_norm.py
@@ -0,0 +1,28 @@
+import network
+
+
+class ModuleTypeNorm(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["w_norm", "b_norm"]):
+ return NetworkModuleNorm(net, weights)
+
+ return None
+
+
+class NetworkModuleNorm(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ self.w_norm = weights.w.get("w_norm")
+ self.b_norm = weights.w.get("b_norm")
+
+ def calc_updown(self, orig_weight):
+ output_shape = self.w_norm.shape
+ updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ if self.b_norm is not None:
+ ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ else:
+ ex_bias = None
+
+ return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
new file mode 100644
index 00000000..05c37811
--- /dev/null
+++ b/extensions-builtin/Lora/network_oft.py
@@ -0,0 +1,97 @@
+import torch
+import network
+from lyco_helpers import factorization
+from einops import rearrange
+
+
+class ModuleTypeOFT(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
+ return NetworkModuleOFT(net, weights)
+
+ return None
+
+# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
+class NetworkModuleOFT(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+
+ super().__init__(net, weights)
+
+ self.lin_module = None
+ self.org_module: list[torch.Module] = [self.sd_module]
+
+ # kohya-ss
+ if "oft_blocks" in weights.w.keys():
+ self.is_kohya = True
+ self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
+ self.alpha = weights.w["alpha"] # alpha is constraint
+ self.dim = self.oft_blocks.shape[0] # lora dim
+ # LyCORIS
+ elif "oft_diag" in weights.w.keys():
+ self.is_kohya = False
+ self.oft_blocks = weights.w["oft_diag"]
+ # self.alpha is unused
+ self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
+
+ is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
+ is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+ is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
+
+ if is_linear:
+ self.out_dim = self.sd_module.out_features
+ elif is_conv:
+ self.out_dim = self.sd_module.out_channels
+ elif is_other_linear:
+ self.out_dim = self.sd_module.embed_dim
+
+ if self.is_kohya:
+ self.constraint = self.alpha * self.out_dim
+ self.num_blocks = self.dim
+ self.block_size = self.out_dim // self.dim
+ else:
+ self.constraint = None
+ self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
+
+ def calc_updown_kb(self, orig_weight, multiplier):
+ oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+ oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
+
+ R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+ R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
+
+ # This errors out for MultiheadAttention, might need to be handled up-stream
+ merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+ merged_weight = torch.einsum(
+ 'k n m, k n ... -> k m ...',
+ R,
+ merged_weight
+ )
+ merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+
+ updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+ output_shape = orig_weight.shape
+ return self.finalize_updown(updown, orig_weight, output_shape)
+
+ def calc_updown(self, orig_weight):
+ # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
+ multiplier = self.multiplier()
+ return self.calc_updown_kb(orig_weight, multiplier)
+
+ # override to remove the multiplier/scale factor; it's already multiplied in get_weight
+ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
+ if self.bias is not None:
+ updown = updown.reshape(self.bias.shape)
+ updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = updown.reshape(output_shape)
+
+ if len(output_shape) == 4:
+ updown = updown.reshape(output_shape)
+
+ if orig_weight.size().numel() == updown.size().numel():
+ updown = updown.reshape(orig_weight.shape)
+
+ if ex_bias is not None:
+ ex_bias = ex_bias * self.multiplier()
+
+ return updown, ex_bias
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index bc722e90..7f814706 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -1,17 +1,25 @@
+import logging
import os
import re
+import lora_patches
import network
import network_lora
+import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
+import network_norm
+import network_oft
import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+import modules.textual_inversion.textual_inversion as textual_inversion
+
+from lora_logger import logger
module_types = [
network_lora.ModuleTypeLora(),
@@ -19,6 +27,9 @@ module_types = [
network_ia3.ModuleTypeIa3(),
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
+ network_norm.ModuleTypeNorm(),
+ network_glora.ModuleTypeGLora(),
+ network_oft.ModuleTypeOFT(),
]
@@ -31,6 +42,8 @@ suffix_conversion = {
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
+ "norm1": "in_layers_0",
+ "norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
@@ -143,9 +156,19 @@ def load_network(name, network_on_disk):
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
matched_networks = {}
+ bundle_embeddings = {}
for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1)
+ if key_network_without_network_parts == "bundle_emb":
+ emb_name, vec_name = network_part.split(".", 1)
+ emb_dict = bundle_embeddings.get(emb_name, {})
+ if vec_name.split('.')[0] == 'string_to_param':
+ _, k2 = vec_name.split('.', 1)
+ emb_dict['string_to_param'] = {k2: weight}
+ else:
+ emb_dict[vec_name] = weight
+ bundle_embeddings[emb_name] = emb_dict
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
@@ -168,6 +191,17 @@ def load_network(name, network_on_disk):
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ # kohya_ss OFT module
+ elif sd_module is None and "oft_unet" in key_network_without_network_parts:
+ key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+ # KohakuBlueLeaf OFT module
+ if sd_module is None and "oft_diag" in key:
+ key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+ key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
if sd_module is None:
keys_failed_to_match[key_network] = key
continue
@@ -189,8 +223,16 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module
+ embeddings = {}
+ for emb_name, data in bundle_embeddings.items():
+ embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
+ embedding.loaded = None
+ embeddings[emb_name] = embedding
+
+ net.bundle_embeddings = embeddings
+
if keys_failed_to_match:
- print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
+ logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
return net
@@ -203,13 +245,16 @@ def purge_networks_from_memory():
devices.torch_gc()
-
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+ emb_db = sd_hijack.model_hijack.embedding_db
already_loaded = {}
for net in loaded_networks:
if net.name in names:
already_loaded[net.name] = net
+ for emb_name, embedding in net.bundle_embeddings.items():
+ if embedding.loaded:
+ emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
loaded_networks.clear()
@@ -244,7 +289,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
if net is None:
failed_to_load_networks.append(name)
- print(f"Couldn't find network with name {name}")
+ logging.info(f"Couldn't find network with name {name}")
continue
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@@ -252,26 +297,54 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
loaded_networks.append(net)
+ for emb_name, embedding in net.bundle_embeddings.items():
+ if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+ logger.warning(
+ f'Skip bundle embedding: "{emb_name}"'
+ ' as it was already loaded from embeddings folder'
+ )
+ continue
+
+ embedding.loaded = False
+ if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+ embedding.loaded = True
+ emb_db.register_embedding(embedding, shared.sd_model)
+ else:
+ emb_db.skipped_embeddings[name] = embedding
+
if failed_to_load_networks:
- sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
+ sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory()
-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
+ bias_backup = getattr(self, "network_bias_backup", None)
- if weights_backup is None:
+ if weights_backup is None and bias_backup is None:
return
- if isinstance(self, torch.nn.MultiheadAttention):
- self.in_proj_weight.copy_(weights_backup[0])
- self.out_proj.weight.copy_(weights_backup[1])
+ if weights_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+
+ if bias_backup is not None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias.copy_(bias_backup)
+ else:
+ self.bias.copy_(bias_backup)
else:
- self.weight.copy_(weights_backup)
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias = None
+ else:
+ self.bias = None
-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
@@ -286,7 +359,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
weights_backup = getattr(self, "network_weights_backup", None)
- if weights_backup is None:
+ if weights_backup is None and wanted_names != ():
+ if current_names != ():
+ raise RuntimeError("no backup weights found and current weights are not unchanged")
+
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else:
@@ -294,21 +370,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup
+ bias_backup = getattr(self, "network_bias_backup", None)
+ if bias_backup is None:
+ if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+ bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+ elif getattr(self, 'bias', None) is not None:
+ bias_backup = self.bias.to(devices.cpu, copy=True)
+ else:
+ bias_backup = None
+ self.network_bias_backup = bias_backup
+
if current_names != wanted_names:
network_restore_weights_from_backup(self)
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
- with torch.no_grad():
- updown = module.calc_updown(self.weight)
-
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
- # inpainting model. zero pad updown to make channel[1] 4 to 9
- updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+ try:
+ with torch.no_grad():
+ updown, ex_bias = module.calc_updown(self.weight)
+
+ if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+
+ self.weight += updown
+ if ex_bias is not None and hasattr(self, 'bias'):
+ if self.bias is None:
+ self.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.bias += ex_bias
+ except RuntimeError as e:
+ logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
- self.weight += updown
- continue
+ continue
module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None)
@@ -316,21 +412,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
module_out = net.modules.get(network_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
- with torch.no_grad():
- updown_q = module_q.calc_updown(self.in_proj_weight)
- updown_k = module_k.calc_updown(self.in_proj_weight)
- updown_v = module_v.calc_updown(self.in_proj_weight)
- updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
- updown_out = module_out.calc_updown(self.out_proj.weight)
-
- self.in_proj_weight += updown_qkv
- self.out_proj.weight += updown_out
- continue
+ try:
+ with torch.no_grad():
+ updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+ updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+ updown_v, _ = module_v.calc_updown(self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+ updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
+
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += updown_out
+ if ex_bias is not None:
+ if self.out_proj.bias is None:
+ self.out_proj.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.out_proj.bias += ex_bias
+
+ except RuntimeError as e:
+ logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+ continue
if module is None:
continue
- print(f'failed to calculate network weights for layer {network_layer_name}')
+ logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
+ extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
self.network_current_names = wanted_names
@@ -357,7 +465,7 @@ def network_forward(module, input, original_forward):
if module is None:
continue
- y = module.forward(y, input)
+ y = module.forward(input, y)
return y
@@ -365,48 +473,79 @@ def network_forward(module, input, original_forward):
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
self.network_current_names = ()
self.network_weights_backup = None
+ self.network_bias_backup = None
def network_Linear_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.Linear_forward_before_network)
+ return network_forward(self, input, originals.Linear_forward)
network_apply_weights(self)
- return torch.nn.Linear_forward_before_network(self, input)
+ return originals.Linear_forward(self, input)
def network_Linear_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.Linear_load_state_dict(self, *args, **kwargs)
def network_Conv2d_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+ return network_forward(self, input, originals.Conv2d_forward)
network_apply_weights(self)
- return torch.nn.Conv2d_forward_before_network(self, input)
+ return originals.Conv2d_forward(self, input)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.Conv2d_load_state_dict(self, *args, **kwargs)
+
+
+def network_GroupNorm_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, originals.GroupNorm_forward)
+
+ network_apply_weights(self)
+
+ return originals.GroupNorm_forward(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, originals.LayerNorm_forward)
+
+ network_apply_weights(self)
+
+ return originals.LayerNorm_forward(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
- return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+ return originals.MultiheadAttention_forward(self, *args, **kwargs)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
def list_available_networks():
@@ -474,9 +613,14 @@ def infotext_pasted(infotext, params):
params["Prompt"] += "\n" + "".join(added)
+originals: lora_patches.LoraPatches = None
+
+extra_network_lora = None
+
available_networks = {}
available_network_aliases = {}
loaded_networks = []
+loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 6ab8b6e7..ef23968c 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,57 +1,30 @@
import re
-import torch
import gradio as gr
from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
+import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+
def unload():
- torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
- torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
- torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
- torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
- torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
- torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+ networks.originals.undo()
def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
- extra_network = extra_networks_lora.ExtraNetworkLora()
- extra_networks.register_extra_network(extra_network)
- extra_networks.register_extra_network_alias(extra_network, "lyco")
-
-
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
- torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
- torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
- torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
- torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
- torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
+ networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
+ extra_networks.register_extra_network(networks.extra_network_lora)
+ extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
- torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
index 390d9dde..c7011909 100644
--- a/extensions-builtin/Lora/ui_edit_user_metadata.py
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -70,6 +70,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
metadata = item.get("metadata") or {}
keys = {
+ 'ss_output_name': "Output name:",
'ss_sd_model_name': "Model:",
'ss_clip_skip': "Clip skip:",
'ss_network_module': "Kohya module:",
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 3629e5c0..df02c663 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True):
lora_on_disk = networks.available_networks.get(name)
+ if lora_on_disk is None:
+ return
path, ext = os.path.splitext(lora_on_disk.filename)
@@ -25,9 +27,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
item = {
"name": name,
"filename": lora_on_disk.filename,
+ "shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
- "search_term": self.search_terms_from_path(lora_on_disk.filename),
+ "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
@@ -65,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
return item
def list_items(self):
- for index, name in enumerate(networks.available_networks):
+ # instantiate a list to protect against concurrent modification
+ names = list(networks.available_networks)
+ for index, name in enumerate(names):
item = self.create_item(name, index)
-
if item is not None:
yield item
diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
index e7616b98..45c7600a 100644
--- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
+++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
@@ -12,8 +12,22 @@ onUiLoaded(async() => {
"Sketch": elementIDs.sketch
};
+
// Helper functions
// Get active tab
+
+ /**
+ * Waits for an element to be present in the DOM.
+ */
+ const waitForElement = (id) => new Promise(resolve => {
+ const checkForElement = () => {
+ const element = document.querySelector(id);
+ if (element) return resolve(element);
+ setTimeout(checkForElement, 100);
+ };
+ checkForElement();
+ });
+
function getActiveTab(elements, all = false) {
const tabs = elements.img2imgTabs.querySelectorAll("button");
@@ -34,7 +48,7 @@ onUiLoaded(async() => {
// Wait until opts loaded
async function waitForOpts() {
- for (;;) {
+ for (; ;) {
if (window.opts && Object.keys(window.opts).length) {
return window.opts;
}
@@ -255,7 +269,7 @@ onUiLoaded(async() => {
input?.addEventListener("input", () => restoreImgRedMask(elements));
}
- function applyZoomAndPan(elemId) {
+ function applyZoomAndPan(elemId, isExtension = true) {
const targetElement = gradioApp().querySelector(elemId);
if (!targetElement) {
@@ -367,6 +381,12 @@ onUiLoaded(async() => {
panY: 0
};
+ if (isExtension) {
+ targetElement.style.overflow = "hidden";
+ }
+
+ targetElement.isZoomed = false;
+
fixCanvas();
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
@@ -377,8 +397,27 @@ onUiLoaded(async() => {
toggleOverlap("off");
fullScreenMode = false;
+ const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']");
+ if (closeBtn) {
+ closeBtn.addEventListener("click", resetZoom);
+ }
+
+ if (canvas && isExtension) {
+ const parentElement = targetElement.closest('[id^="component-"]');
+ if (
+ canvas &&
+ parseFloat(canvas.style.width) > parentElement.offsetWidth &&
+ parseFloat(targetElement.style.width) > parentElement.offsetWidth
+ ) {
+ fitToElement();
+ return;
+ }
+
+ }
+
if (
canvas &&
+ !isExtension &&
parseFloat(canvas.style.width) > 865 &&
parseFloat(targetElement.style.width) > 865
) {
@@ -387,9 +426,6 @@ onUiLoaded(async() => {
}
targetElement.style.width = "";
- if (canvas) {
- targetElement.style.height = canvas.style.height;
- }
}
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
@@ -445,7 +481,7 @@ onUiLoaded(async() => {
// Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
function updateZoom(newZoomLevel, mouseX, mouseY) {
- newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
+ newZoomLevel = Math.max(0.1, Math.min(newZoomLevel, 15));
elemData[elemId].panX +=
mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;
@@ -456,6 +492,10 @@ onUiLoaded(async() => {
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
toggleOverlap("on");
+ if (isExtension) {
+ targetElement.style.overflow = "visible";
+ }
+
return newZoomLevel;
}
@@ -478,10 +518,12 @@ onUiLoaded(async() => {
fullScreenMode = false;
elemData[elemId].zoomLevel = updateZoom(
elemData[elemId].zoomLevel +
- (operation === "+" ? delta : -delta),
+ (operation === "+" ? delta : -delta),
zoomPosX - targetElement.getBoundingClientRect().left,
zoomPosY - targetElement.getBoundingClientRect().top
);
+
+ targetElement.isZoomed = true;
}
}
@@ -495,10 +537,19 @@ onUiLoaded(async() => {
//Reset Zoom
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
+ let parentElement;
+
+ if (isExtension) {
+ parentElement = targetElement.closest('[id^="component-"]');
+ } else {
+ parentElement = targetElement.parentElement;
+ }
+
+
// Get element and screen dimensions
const elementWidth = targetElement.offsetWidth;
const elementHeight = targetElement.offsetHeight;
- const parentElement = targetElement.parentElement;
+
const screenWidth = parentElement.clientWidth;
const screenHeight = parentElement.clientHeight;
@@ -551,8 +602,12 @@ onUiLoaded(async() => {
if (!canvas) return;
- if (canvas.offsetWidth > 862) {
- targetElement.style.width = canvas.offsetWidth + "px";
+ if (canvas.offsetWidth > 862 || isExtension) {
+ targetElement.style.width = (canvas.offsetWidth + 2) + "px";
+ }
+
+ if (isExtension) {
+ targetElement.style.overflow = "visible";
}
if (fullScreenMode) {
@@ -657,17 +712,18 @@ onUiLoaded(async() => {
// Simulation of the function to put a long image into the screen.
// We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
// We hide the image and show it to the user when it is ready.
- function autoExpand(e) {
- const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
- const isMainTab = activeElement === elementIDs.inpaint || activeElement === elementIDs.inpaintSketch || activeElement === elementIDs.sketch;
- if (canvas && isMainTab) {
- if (hasHorizontalScrollbar(targetElement)) {
+ targetElement.isExpanded = false;
+ function autoExpand() {
+ const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
+ if (canvas) {
+ if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
targetElement.style.visibility = "hidden";
setTimeout(() => {
fitToScreen();
resetZoom();
targetElement.style.visibility = "visible";
+ targetElement.isExpanded = true;
}, 10);
}
}
@@ -675,9 +731,24 @@ onUiLoaded(async() => {
targetElement.addEventListener("mousemove", getMousePosition);
+ //observers
+ // Creating an observer with a callback function to handle DOM changes
+ const observer = new MutationObserver((mutationsList, observer) => {
+ for (let mutation of mutationsList) {
+ // If the style attribute of the canvas has changed, by observation it happens only when the picture changes
+ if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&
+ mutation.target.tagName.toLowerCase() === 'canvas') {
+ targetElement.isExpanded = false;
+ setTimeout(resetZoom, 10);
+ }
+ }
+ });
+
// Apply auto expand if enabled
if (hotkeysConfig.canvas_auto_expand) {
targetElement.addEventListener("mousemove", autoExpand);
+ // Set up an observer to track attribute changes
+ observer.observe(targetElement, {attributes: true, childList: true, subtree: true});
}
// Handle events only inside the targetElement
@@ -784,6 +855,11 @@ onUiLoaded(async() => {
if (isMoving && elemId === activeElement) {
updatePanPosition(e.movementX, e.movementY);
targetElement.style.pointerEvents = "none";
+
+ if (isExtension) {
+ targetElement.style.overflow = "visible";
+ }
+
} else {
targetElement.style.pointerEvents = "auto";
}
@@ -794,13 +870,93 @@ onUiLoaded(async() => {
isMoving = false;
};
+ // Checks for extension
+ function checkForOutBox() {
+ const parentElement = targetElement.closest('[id^="component-"]');
+ if (parentElement.offsetWidth < targetElement.offsetWidth && !targetElement.isExpanded) {
+ resetZoom();
+ targetElement.isExpanded = true;
+ }
+
+ if (parentElement.offsetWidth < targetElement.offsetWidth && elemData[elemId].zoomLevel == 1) {
+ resetZoom();
+ }
+
+ if (parentElement.offsetWidth < targetElement.offsetWidth && targetElement.offsetWidth * elemData[elemId].zoomLevel > parentElement.offsetWidth && elemData[elemId].zoomLevel < 1 && !targetElement.isZoomed) {
+ resetZoom();
+ }
+ }
+
+ if (isExtension) {
+ targetElement.addEventListener("mousemove", checkForOutBox);
+ }
+
+
+ window.addEventListener('resize', (e) => {
+ resetZoom();
+
+ if (isExtension) {
+ targetElement.isExpanded = false;
+ targetElement.isZoomed = false;
+ }
+ });
+
gradioApp().addEventListener("mousemove", handleMoveByKey);
+
+
}
- applyZoomAndPan(elementIDs.sketch);
- applyZoomAndPan(elementIDs.inpaint);
- applyZoomAndPan(elementIDs.inpaintSketch);
+ applyZoomAndPan(elementIDs.sketch, false);
+ applyZoomAndPan(elementIDs.inpaint, false);
+ applyZoomAndPan(elementIDs.inpaintSketch, false);
// Make the function global so that other extensions can take advantage of this solution
- window.applyZoomAndPan = applyZoomAndPan;
+ const applyZoomAndPanIntegration = async(id, elementIDs) => {
+ const mainEl = document.querySelector(id);
+ if (id.toLocaleLowerCase() === "none") {
+ for (const elementID of elementIDs) {
+ const el = await waitForElement(elementID);
+ if (!el) break;
+ applyZoomAndPan(elementID);
+ }
+ return;
+ }
+
+ if (!mainEl) return;
+ mainEl.addEventListener("click", async() => {
+ for (const elementID of elementIDs) {
+ const el = await waitForElement(elementID);
+ if (!el) break;
+ applyZoomAndPan(elementID);
+ }
+ }, {once: true});
+ };
+
+ window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image")
+
+ window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension
+
+ /*
+ The function `applyZoomAndPanIntegration` takes two arguments:
+
+ 1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click.
+ If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event.
+
+ 2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument.
+ If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event.
+
+ Example usage:
+ applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
+ In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet".
+ */
+
+ // More examples
+ // Add integration with ControlNet txt2img One TAB
+ // applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
+
+ // Add integration with ControlNet txt2img Tabs
+ // applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`));
+
+ // Add integration with Inpaint Anything
+ // applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]);
});
diff --git a/extensions-builtin/canvas-zoom-and-pan/style.css b/extensions-builtin/canvas-zoom-and-pan/style.css
index 6bcc9570..5d8054e6 100644
--- a/extensions-builtin/canvas-zoom-and-pan/style.css
+++ b/extensions-builtin/canvas-zoom-and-pan/style.css
@@ -61,3 +61,6 @@
to {opacity: 1;}
}
+.styler {
+ overflow:inherit !important;
+} \ No newline at end of file
diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
index 588b64d2..983f87ff 100644
--- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py
+++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
@@ -22,22 +22,23 @@ class ExtraOptionsSection(scripts.Script):
self.comps = []
self.setting_names = []
self.infotext_fields = []
+ extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
- with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group():
+ with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
- row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols)
+ row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
for row in range(row_count):
with gr.Row():
for col in range(shared.opts.extra_options_cols):
index = row * shared.opts.extra_options_cols + col
- if index >= len(shared.opts.extra_options):
+ if index >= len(extra_options):
break
- setting_name = shared.opts.extra_options[index]
+ setting_name = extra_options[index]
with FormColumn():
comp = ui_settings.create_setting_component(setting_name)
@@ -64,7 +65,8 @@ class ExtraOptionsSection(scripts.Script):
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
- "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(),
+ "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
+ "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
}))
diff --git a/extensions-builtin/hypertile/hypertile.py b/extensions-builtin/hypertile/hypertile.py
new file mode 100644
index 00000000..a40c1311
--- /dev/null
+++ b/extensions-builtin/hypertile/hypertile.py
@@ -0,0 +1,348 @@
+"""
+Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
+Warn: The patch works well only if the input image has a width and height that are multiples of 128
+Original author: @tfernd Github: https://github.com/tfernd/HyperTile
+"""
+
+from __future__ import annotations
+
+import functools
+from dataclasses import dataclass
+from typing import Callable
+from typing_extensions import Literal
+
+import logging
+from functools import wraps, cache
+from contextlib import contextmanager
+
+import math
+import torch.nn as nn
+import random
+
+from einops import rearrange
+
+
+@dataclass
+class HypertileParams:
+ depth = 0
+ layer_name = ""
+ tile_size: int = 0
+ swap_size: int = 0
+ aspect_ratio: float = 1.0
+ forward = None
+ enabled = False
+
+
+
+# TODO add SD-XL layers
+DEPTH_LAYERS = {
+ 0: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.1.1.transformer_blocks.0.attn1",
+ "input_blocks.2.1.transformer_blocks.0.attn1",
+ "output_blocks.9.1.transformer_blocks.0.attn1",
+ "output_blocks.10.1.transformer_blocks.0.attn1",
+ "output_blocks.11.1.transformer_blocks.0.attn1",
+ # SD 1.5 VAE
+ "decoder.mid_block.attentions.0",
+ "decoder.mid.attn_1",
+ ],
+ 1: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.0.attn1",
+ "input_blocks.5.1.transformer_blocks.0.attn1",
+ "output_blocks.6.1.transformer_blocks.0.attn1",
+ "output_blocks.7.1.transformer_blocks.0.attn1",
+ "output_blocks.8.1.transformer_blocks.0.attn1",
+ ],
+ 2: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.7.1.transformer_blocks.0.attn1",
+ "input_blocks.8.1.transformer_blocks.0.attn1",
+ "output_blocks.3.1.transformer_blocks.0.attn1",
+ "output_blocks.4.1.transformer_blocks.0.attn1",
+ "output_blocks.5.1.transformer_blocks.0.attn1",
+ ],
+ 3: [
+ # SD 1.5 U-Net (diffusers)
+ "mid_block.attentions.0.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "middle_block.1.transformer_blocks.0.attn1",
+ ],
+}
+# XL layers, thanks for GitHub@gel-crabs for the help
+DEPTH_LAYERS_XL = {
+ 0: [
+ # SD 1.5 U-Net (diffusers)
+ "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
+ "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
+ "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.0.attn1",
+ "input_blocks.5.1.transformer_blocks.0.attn1",
+ "output_blocks.3.1.transformer_blocks.0.attn1",
+ "output_blocks.4.1.transformer_blocks.0.attn1",
+ "output_blocks.5.1.transformer_blocks.0.attn1",
+ # SD 1.5 VAE
+ "decoder.mid_block.attentions.0",
+ "decoder.mid.attn_1",
+ ],
+ 1: [
+ # SD 1.5 U-Net (diffusers)
+ #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
+ #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
+ #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "input_blocks.4.1.transformer_blocks.1.attn1",
+ "input_blocks.5.1.transformer_blocks.1.attn1",
+ "output_blocks.3.1.transformer_blocks.1.attn1",
+ "output_blocks.4.1.transformer_blocks.1.attn1",
+ "output_blocks.5.1.transformer_blocks.1.attn1",
+ "input_blocks.7.1.transformer_blocks.0.attn1",
+ "input_blocks.8.1.transformer_blocks.0.attn1",
+ "output_blocks.0.1.transformer_blocks.0.attn1",
+ "output_blocks.1.1.transformer_blocks.0.attn1",
+ "output_blocks.2.1.transformer_blocks.0.attn1",
+ "input_blocks.7.1.transformer_blocks.1.attn1",
+ "input_blocks.8.1.transformer_blocks.1.attn1",
+ "output_blocks.0.1.transformer_blocks.1.attn1",
+ "output_blocks.1.1.transformer_blocks.1.attn1",
+ "output_blocks.2.1.transformer_blocks.1.attn1",
+ "input_blocks.7.1.transformer_blocks.2.attn1",
+ "input_blocks.8.1.transformer_blocks.2.attn1",
+ "output_blocks.0.1.transformer_blocks.2.attn1",
+ "output_blocks.1.1.transformer_blocks.2.attn1",
+ "output_blocks.2.1.transformer_blocks.2.attn1",
+ "input_blocks.7.1.transformer_blocks.3.attn1",
+ "input_blocks.8.1.transformer_blocks.3.attn1",
+ "output_blocks.0.1.transformer_blocks.3.attn1",
+ "output_blocks.1.1.transformer_blocks.3.attn1",
+ "output_blocks.2.1.transformer_blocks.3.attn1",
+ "input_blocks.7.1.transformer_blocks.4.attn1",
+ "input_blocks.8.1.transformer_blocks.4.attn1",
+ "output_blocks.0.1.transformer_blocks.4.attn1",
+ "output_blocks.1.1.transformer_blocks.4.attn1",
+ "output_blocks.2.1.transformer_blocks.4.attn1",
+ "input_blocks.7.1.transformer_blocks.5.attn1",
+ "input_blocks.8.1.transformer_blocks.5.attn1",
+ "output_blocks.0.1.transformer_blocks.5.attn1",
+ "output_blocks.1.1.transformer_blocks.5.attn1",
+ "output_blocks.2.1.transformer_blocks.5.attn1",
+ "input_blocks.7.1.transformer_blocks.6.attn1",
+ "input_blocks.8.1.transformer_blocks.6.attn1",
+ "output_blocks.0.1.transformer_blocks.6.attn1",
+ "output_blocks.1.1.transformer_blocks.6.attn1",
+ "output_blocks.2.1.transformer_blocks.6.attn1",
+ "input_blocks.7.1.transformer_blocks.7.attn1",
+ "input_blocks.8.1.transformer_blocks.7.attn1",
+ "output_blocks.0.1.transformer_blocks.7.attn1",
+ "output_blocks.1.1.transformer_blocks.7.attn1",
+ "output_blocks.2.1.transformer_blocks.7.attn1",
+ "input_blocks.7.1.transformer_blocks.8.attn1",
+ "input_blocks.8.1.transformer_blocks.8.attn1",
+ "output_blocks.0.1.transformer_blocks.8.attn1",
+ "output_blocks.1.1.transformer_blocks.8.attn1",
+ "output_blocks.2.1.transformer_blocks.8.attn1",
+ "input_blocks.7.1.transformer_blocks.9.attn1",
+ "input_blocks.8.1.transformer_blocks.9.attn1",
+ "output_blocks.0.1.transformer_blocks.9.attn1",
+ "output_blocks.1.1.transformer_blocks.9.attn1",
+ "output_blocks.2.1.transformer_blocks.9.attn1",
+ ],
+ 2: [
+ # SD 1.5 U-Net (diffusers)
+ "mid_block.attentions.0.transformer_blocks.0.attn1",
+ # SD 1.5 U-Net (ldm)
+ "middle_block.1.transformer_blocks.0.attn1",
+ "middle_block.1.transformer_blocks.1.attn1",
+ "middle_block.1.transformer_blocks.2.attn1",
+ "middle_block.1.transformer_blocks.3.attn1",
+ "middle_block.1.transformer_blocks.4.attn1",
+ "middle_block.1.transformer_blocks.5.attn1",
+ "middle_block.1.transformer_blocks.6.attn1",
+ "middle_block.1.transformer_blocks.7.attn1",
+ "middle_block.1.transformer_blocks.8.attn1",
+ "middle_block.1.transformer_blocks.9.attn1",
+ ],
+ 3 : [] # TODO - separate layers for SD-XL
+}
+
+
+RNG_INSTANCE = random.Random()
+
+
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
+ """
+ Returns a random divisor of value that
+ x * min_value <= value
+ if max_options is 1, the behavior is deterministic
+ """
+ min_value = min(min_value, value)
+
+ # All big divisors of value (inclusive)
+ divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
+
+ ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
+
+ idx = RNG_INSTANCE.randint(0, len(ns) - 1)
+
+ return ns[idx]
+
+
+def set_hypertile_seed(seed: int) -> None:
+ RNG_INSTANCE.seed(seed)
+
+
+@functools.cache
+def largest_tile_size_available(width: int, height: int) -> int:
+ """
+ Calculates the largest tile size available for a given width and height
+ Tile size is always a power of 2
+ """
+ gcd = math.gcd(width, height)
+ largest_tile_size_available = 1
+ while gcd % (largest_tile_size_available * 2) == 0:
+ largest_tile_size_available *= 2
+ return largest_tile_size_available
+
+
+def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
+ """
+ Finds h and w such that h*w = hw and h/w = aspect_ratio
+ We check all possible divisors of hw and return the closest to the aspect ratio
+ """
+ divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
+ pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
+ ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
+ closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
+ closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
+ return closest_pair
+
+
+@cache
+def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
+ """
+ Finds h and w such that h*w = hw and h/w = aspect_ratio
+ """
+ h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
+ # find h and w such that h*w = hw and h/w = aspect_ratio
+ if h * w != hw:
+ w_candidate = hw / h
+ # check if w is an integer
+ if not w_candidate.is_integer():
+ h_candidate = hw / w
+ # check if h is an integer
+ if not h_candidate.is_integer():
+ return iterative_closest_divisors(hw, aspect_ratio)
+ else:
+ h = int(h_candidate)
+ else:
+ w = int(w_candidate)
+ return h, w
+
+
+def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
+
+ @wraps(params.forward)
+ def wrapper(*args, **kwargs):
+ if not params.enabled:
+ return params.forward(*args, **kwargs)
+
+ latent_tile_size = max(128, params.tile_size) // 8
+ x = args[0]
+
+ # VAE
+ if x.ndim == 4:
+ b, c, h, w = x.shape
+
+ nh = random_divisor(h, latent_tile_size, params.swap_size)
+ nw = random_divisor(w, latent_tile_size, params.swap_size)
+
+ if nh * nw > 1:
+ x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
+
+ out = params.forward(x, *args[1:], **kwargs)
+
+ if nh * nw > 1:
+ out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
+
+ # U-Net
+ else:
+ hw: int = x.size(1)
+ h, w = find_hw_candidates(hw, params.aspect_ratio)
+ assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
+
+ factor = 2 ** params.depth if scale_depth else 1
+ nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
+ nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
+
+ if nh * nw > 1:
+ x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
+
+ out = params.forward(x, *args[1:], **kwargs)
+
+ if nh * nw > 1:
+ out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
+ out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
+
+ return out
+
+ return wrapper
+
+
+def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
+ hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
+ if hypertile_layers is None:
+ if not enable:
+ return
+
+ hypertile_layers = {}
+ layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
+
+ for depth in range(4):
+ for layer_name, module in model.named_modules():
+ if any(layer_name.endswith(try_name) for try_name in layers[depth]):
+ params = HypertileParams()
+ module.__webui_hypertile_params = params
+ params.forward = module.forward
+ params.depth = depth
+ params.layer_name = layer_name
+ module.forward = self_attn_forward(params)
+
+ hypertile_layers[layer_name] = 1
+
+ model.__webui_hypertile_layers = hypertile_layers
+
+ aspect_ratio = width / height
+ tile_size = min(largest_tile_size_available(width, height), tile_size_max)
+
+ for layer_name, module in model.named_modules():
+ if layer_name in hypertile_layers:
+ params = module.__webui_hypertile_params
+
+ params.tile_size = tile_size
+ params.swap_size = swap_size
+ params.aspect_ratio = aspect_ratio
+ params.enabled = enable and params.depth <= max_depth
diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py
new file mode 100644
index 00000000..3cc29cd1
--- /dev/null
+++ b/extensions-builtin/hypertile/scripts/hypertile_script.py
@@ -0,0 +1,73 @@
+import hypertile
+from modules import scripts, script_callbacks, shared
+
+
+class ScriptHypertile(scripts.Script):
+ name = "Hypertile"
+
+ def title(self):
+ return self.name
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def process(self, p, *args):
+ hypertile.set_hypertile_seed(p.all_seeds[0])
+
+ configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
+
+ def before_hr(self, p, *args):
+ configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)
+
+
+def configure_hypertile(width, height, enable_unet=True):
+ hypertile.hypertile_hook_model(
+ shared.sd_model.first_stage_model,
+ width,
+ height,
+ swap_size=shared.opts.hypertile_swap_size_vae,
+ max_depth=shared.opts.hypertile_max_depth_vae,
+ tile_size_max=shared.opts.hypertile_max_tile_vae,
+ enable=shared.opts.hypertile_enable_vae,
+ )
+
+ hypertile.hypertile_hook_model(
+ shared.sd_model.model,
+ width,
+ height,
+ swap_size=shared.opts.hypertile_swap_size_unet,
+ max_depth=shared.opts.hypertile_max_depth_unet,
+ tile_size_max=shared.opts.hypertile_max_tile_unet,
+ enable=enable_unet,
+ is_sdxl=shared.sd_model.is_sdxl
+ )
+
+
+def on_ui_settings():
+ import gradio as gr
+
+ options = {
+ "hypertile_explanation": shared.OptionHTML("""
+ <a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
+ resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the
+ benefit.
+ """),
+
+ "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
+ "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
+ "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
+ "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
+
+ "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
+ "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
+ "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
+ "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
+ }
+
+ for name, opt in options.items():
+ opt.section = ('hypertile', "Hypertile")
+ shared.opts.add_option(name, opt)
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js
index 12cae4b7..bff1aced 100644
--- a/extensions-builtin/mobile/javascript/mobile.js
+++ b/extensions-builtin/mobile/javascript/mobile.js
@@ -12,6 +12,8 @@ function isMobile() {
}
function reportWindowSize() {
+ if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
+
var currentlyMobile = isMobile();
if (currentlyMobile == isSetupForMobile) return;
isSetupForMobile = currentlyMobile;
@@ -20,7 +22,13 @@ function reportWindowSize() {
var button = gradioApp().getElementById(tab + '_generate_box');
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
target.insertBefore(button, target.firstElementChild);
+
+ gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile);
}
}
window.addEventListener("resize", reportWindowSize);
+
+onUiLoaded(function() {
+ reportWindowSize();
+});