aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin')
-rw-r--r--extensions-builtin/Lora/network.py35
-rw-r--r--extensions-builtin/Lora/network_full.py4
-rw-r--r--extensions-builtin/Lora/network_glora.py10
-rw-r--r--extensions-builtin/Lora/network_hada.py12
-rw-r--r--extensions-builtin/Lora/network_ia3.py2
-rw-r--r--extensions-builtin/Lora/network_lokr.py18
-rw-r--r--extensions-builtin/Lora/network_lora.py6
-rw-r--r--extensions-builtin/Lora/network_norm.py4
-rw-r--r--extensions-builtin/Lora/network_oft.py100
-rw-r--r--extensions-builtin/Lora/networks.py42
-rw-r--r--extensions-builtin/Lora/preload.py5
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py2
-rw-r--r--extensions-builtin/Lora/ui_edit_user_metadata.py9
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py10
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py94
-rw-r--r--extensions-builtin/ScuNET/scunet_model_arch.py268
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py163
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch.py867
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch_v2.py1017
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js8
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py4
-rw-r--r--extensions-builtin/extra-options-section/scripts/extra_options_section.py4
-rw-r--r--extensions-builtin/soft-inpainting/scripts/soft_inpainting.py747
23 files changed, 982 insertions, 2449 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 6021fd8d..b8fd9194 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -3,6 +3,9 @@ import os
from collections import namedtuple
import enum
+import torch.nn as nn
+import torch.nn.functional as F
+
from modules import sd_models, cache, errors, hashes, shared
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
+ self.ops = None
+ self.extra_kwargs = {}
+ if isinstance(self.sd_module, nn.Conv2d):
+ self.ops = F.conv2d
+ self.extra_kwargs = {
+ 'stride': self.sd_module.stride,
+ 'padding': self.sd_module.padding
+ }
+ elif isinstance(self.sd_module, nn.Linear):
+ self.ops = F.linear
+ elif isinstance(self.sd_module, nn.LayerNorm):
+ self.ops = F.layer_norm
+ self.extra_kwargs = {
+ 'normalized_shape': self.sd_module.normalized_shape,
+ 'eps': self.sd_module.eps
+ }
+ elif isinstance(self.sd_module, nn.GroupNorm):
+ self.ops = F.group_norm
+ self.extra_kwargs = {
+ 'num_groups': self.sd_module.num_groups,
+ 'eps': self.sd_module.eps
+ }
+
self.dim = None
self.bias = weights.w.get("bias")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -137,7 +163,7 @@ class NetworkModule:
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
- updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
@@ -155,5 +181,10 @@ class NetworkModule:
raise NotImplementedError()
def forward(self, x, y):
- raise NotImplementedError()
+ """A general forward implementation for all modules"""
+ if self.ops is None:
+ raise NotImplementedError()
+ else:
+ updown, ex_bias = self.calc_updown(self.sd_module.weight)
+ return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
index bf6930e9..f221c95f 100644
--- a/extensions-builtin/Lora/network_full.py
+++ b/extensions-builtin/Lora/network_full.py
@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.weight.shape
- updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.weight.to(orig_weight.device)
if self.ex_bias is not None:
- ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.ex_bias.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
index 492d4870..efe5c681 100644
--- a/extensions-builtin/Lora/network_glora.py
+++ b/extensions-builtin/Lora/network_glora.py
@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
self.w2b = weights.w["b2.weight"]
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
- updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+ updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py
index 5fcb0695..d95a0fd1 100644
--- a/extensions-builtin/Lora/network_hada.py
+++ b/extensions-builtin/Lora/network_hada.py
@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
self.t2 = weights.w.get("hada_t2")
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
- t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+ t1 = self.t1.to(orig_weight.device)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py
index 7edc4249..96faeaf3 100644
--- a/extensions-builtin/Lora/network_ia3.py
+++ b/extensions-builtin/Lora/network_ia3.py
@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
self.on_input = weights.w["on_input"].item()
def calc_updown(self, orig_weight):
- w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+ w = self.w.to(orig_weight.device)
output_shape = [w.size(0), orig_weight.size(1)]
if self.on_input:
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
index 340acdab..fcdaeafd 100644
--- a/extensions-builtin/Lora/network_lokr.py
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
def calc_updown(self, orig_weight):
if self.w1 is not None:
- w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1 = self.w1.to(orig_weight.device)
else:
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
w1 = w1a @ w1b
if self.w2 is not None:
- w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2 = self.w2.to(orig_weight.device)
elif self.t2 is None:
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = w2a @ w2b
else:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 26c0a72c..4cc40295 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
return module
def calc_updown(self, orig_weight):
- up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
- down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ up = self.up_model.weight.to(orig_weight.device)
+ down = self.down_model.weight.to(orig_weight.device)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
- mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ mid = self.mid_model.weight.to(orig_weight.device)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
index ce450158..d25afcbb 100644
--- a/extensions-builtin/Lora/network_norm.py
+++ b/extensions-builtin/Lora/network_norm.py
@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
- updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.w_norm.to(orig_weight.device)
if self.b_norm is not None:
- ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.b_norm.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index fa647020..7821a8a7 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -1,6 +1,5 @@
import torch
import network
-from lyco_helpers import factorization
from einops import rearrange
@@ -22,20 +21,28 @@ class NetworkModuleOFT(network.NetworkModule):
self.org_module: list[torch.Module] = [self.sd_module]
self.scale = 1.0
+ self.is_R = False
+ self.is_boft = False
- # kohya-ss
+ # kohya-ss/New LyCORIS OFT/BOFT
if "oft_blocks" in weights.w.keys():
- self.is_kohya = True
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
- self.alpha = weights.w["alpha"] # alpha is constraint
+ self.alpha = weights.w.get("alpha", None) # alpha is constraint
self.dim = self.oft_blocks.shape[0] # lora dim
- # LyCORIS
+ # Old LyCORIS OFT
elif "oft_diag" in weights.w.keys():
- self.is_kohya = False
+ self.is_R = True
self.oft_blocks = weights.w["oft_diag"]
# self.alpha is unused
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
+ # LyCORIS BOFT
+ if self.oft_blocks.dim() == 4:
+ self.is_boft = True
+ self.rescale = weights.w.get('rescale', None)
+ if self.rescale is not None:
+ self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
+
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
@@ -47,36 +54,65 @@ class NetworkModuleOFT(network.NetworkModule):
elif is_other_linear:
self.out_dim = self.sd_module.embed_dim
- if self.is_kohya:
- self.constraint = self.alpha * self.out_dim
- self.num_blocks = self.dim
- self.block_size = self.out_dim // self.dim
- else:
+ self.num_blocks = self.dim
+ self.block_size = self.out_dim // self.dim
+ self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
+ if self.is_R:
self.constraint = None
- self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
+ self.block_size = self.dim
+ self.num_blocks = self.out_dim // self.dim
+ elif self.is_boft:
+ self.boft_m = self.oft_blocks.shape[0]
+ self.num_blocks = self.oft_blocks.shape[1]
+ self.block_size = self.oft_blocks.shape[2]
+ self.boft_b = self.block_size
def calc_updown(self, orig_weight):
- oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
- eye = torch.eye(self.block_size, device=self.oft_blocks.device)
-
- if self.is_kohya:
- block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
- norm_Q = torch.norm(block_Q.flatten())
- new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
- block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+ oft_blocks = self.oft_blocks.to(orig_weight.device)
+ eye = torch.eye(self.block_size, device=oft_blocks.device)
+
+ if not self.is_R:
+ block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
+ if self.constraint != 0:
+ norm_Q = torch.norm(block_Q.flatten())
+ new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
+ block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
- R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
-
- # This errors out for MultiheadAttention, might need to be handled up-stream
- merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
- merged_weight = torch.einsum(
- 'k n m, k n ... -> k m ...',
- R,
- merged_weight
- )
- merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
-
- updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+ R = oft_blocks.to(orig_weight.device)
+
+ if not self.is_boft:
+ # This errors out for MultiheadAttention, might need to be handled up-stream
+ merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+ merged_weight = torch.einsum(
+ 'k n m, k n ... -> k m ...',
+ R,
+ merged_weight
+ )
+ merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+ else:
+ # TODO: determine correct value for scale
+ scale = 1.0
+ m = self.boft_m
+ b = self.boft_b
+ r_b = b // 2
+ inp = orig_weight
+ for i in range(m):
+ bi = R[i] # b_num, b_size, b_size
+ if i == 0:
+ # Apply multiplier/scale and rescale into first weight
+ bi = bi * scale + (1 - scale) * eye
+ inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
+ inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
+ inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
+ inp = rearrange(inp, "d b ... -> (d b) ...")
+ inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
+ merged_weight = inp
+
+ # Rescale mechanism
+ if self.rescale is not None:
+ merged_weight = self.rescale.to(merged_weight) * merged_weight
+
+ updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 629bf853..83ea2802 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -1,3 +1,4 @@
+import gradio as gr
import logging
import os
import re
@@ -259,11 +260,11 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
loaded_networks.clear()
- networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+ networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
if any(x is None for x in networks_on_disk):
list_available_networks()
- networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+ networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
failed_to_load_networks = []
@@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
emb_db.skipped_embeddings[name] = embedding
if failed_to_load_networks:
- sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
+ lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
+ sd_hijack.model_hijack.comments.append(lora_not_found_message)
+ if shared.opts.lora_not_found_warning_console:
+ print(f'\n{lora_not_found_message}\n')
+ if shared.opts.lora_not_found_gradio_warning:
+ gr.Warning(lora_not_found_message)
purge_networks_from_memory()
@@ -389,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if module is not None and hasattr(self, 'weight'):
try:
with torch.no_grad():
- updown, ex_bias = module.calc_updown(self.weight)
+ if getattr(self, 'fp16_weight', None) is None:
+ weight = self.weight
+ bias = self.bias
+ else:
+ weight = self.fp16_weight.clone().to(self.weight.device)
+ bias = getattr(self, 'fp16_bias', None)
+ if bias is not None:
+ bias = bias.clone().to(self.bias.device)
+ updown, ex_bias = module.calc_updown(weight)
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ if len(weight.shape) == 4 and weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight += updown
+ self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
- self.bias = torch.nn.Parameter(ex_bias)
+ self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
- self.bias += ex_bias
+ self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@@ -444,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_current_names = wanted_names
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
"""
Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation.
"""
if len(loaded_networks) == 0:
- return original_forward(module, input)
+ return original_forward(org_module, input)
input = devices.cond_cast_unet(input)
- network_restore_weights_from_backup(module)
- network_reset_cached_weight(module)
+ network_restore_weights_from_backup(org_module)
+ network_reset_cached_weight(org_module)
- y = original_forward(module, input)
+ y = original_forward(org_module, input)
- network_layer_name = getattr(module, 'network_layer_name', None)
+ network_layer_name = getattr(org_module, 'network_layer_name', None)
for lora in loaded_networks:
module = lora.modules.get(network_layer_name, None)
if module is None:
diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py
index 50961be3..52fab29b 100644
--- a/extensions-builtin/Lora/preload.py
+++ b/extensions-builtin/Lora/preload.py
@@ -1,7 +1,8 @@
import os
from modules import paths
+from modules.paths_internal import normalized_filepath
def preload(parser):
- parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
- parser.add_argument("--lyco-dir-backcompat", type=str, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
+ parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
+ parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index ef23968c..1518f7e5 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
+ "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
+ "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
}))
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
index c7011909..3160aecf 100644
--- a/extensions-builtin/Lora/ui_edit_user_metadata.py
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.slider_preferred_weight = None
self.edit_notes = None
- def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+ def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["sd version"] = sd_version
user_metadata["activation text"] = activation_text
user_metadata["preferred weight"] = preferred_weight
+ user_metadata["negative text"] = negative_text
user_metadata["notes"] = notes
self.write_user_metadata(name, user_metadata)
@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
user_metadata.get('activation text', ''),
float(user_metadata.get('preferred weight', 0.0)),
+ user_metadata.get('negative text', ''),
gr.update(visible=True if tags else False),
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
]
@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.taginfo = gr.HighlightedText(label="Training dataset tags")
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
-
+ self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
with gr.Row() as row_random_prompt:
with gr.Column(scale=8):
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.taginfo,
self.edit_activation_text,
self.slider_preferred_weight,
+ self.edit_negative_text,
row_random_prompt,
random_prompt,
]
@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
self.select_sd_version,
self.edit_activation_text,
self.slider_preferred_weight,
+ self.edit_negative_text,
self.edit_notes,
]
+
self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index df02c663..66d15dd0 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -24,13 +24,16 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
alias = lora_on_disk.get_alias()
+ search_terms = [self.search_terms_from_path(lora_on_disk.filename)]
+ if lora_on_disk.hash:
+ search_terms.append(lora_on_disk.hash)
item = {
"name": name,
"filename": lora_on_disk.filename,
"shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
- "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
+ "search_terms": search_terms,
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
@@ -45,6 +48,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)
+ negative_prompt = item["user_metadata"].get("negative text")
+ item["negative_prompt"] = quote_js("")
+ if negative_prompt:
+ item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
+
sd_version = item["user_metadata"].get("sd version")
if sd_version in network.SdVersion.__members__:
item["sd_version"] = sd_version
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 167d2f64..fe5e5a19 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,16 +1,9 @@
import sys
import PIL.Image
-import numpy as np
-import torch
-from tqdm import tqdm
import modules.upscaler
-from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet
-
-from modules.modelloader import load_file_from_url
-from modules.shared import opts
+from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,100 +35,37 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2)
self.scalers = scalers
- @staticmethod
- @torch.no_grad()
- def tiled_inference(img, model):
- # test the image tile by tile
- h, w = img.shape[2:]
- tile = opts.SCUNET_tile
- tile_overlap = opts.SCUNET_tile_overlap
- if tile == 0:
- return model(img)
-
- device = devices.get_device_for('scunet')
- assert tile % 8 == 0, "tile size should be a multiple of window_size"
- sf = 1
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
- for h_idx in h_idx_list:
-
- for w_idx in w_idx_list:
-
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
-
def do_upscale(self, img: PIL.Image.Image, selected_file):
-
devices.torch_gc()
-
try:
model = self.load_model(selected_file)
except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img
- device = devices.get_device_for('scunet')
- tile = opts.SCUNET_tile
- h, w = img.height, img.width
- np_img = np.array(img)
- np_img = np_img[:, :, ::-1] # RGB to BGR
- np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
- torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
-
- if tile > h or tile > w:
- _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
- _img[:, :, :h, :w] = torch_img # pad image
- torch_img = _img
-
- torch_output = self.tiled_inference(torch_img, model).squeeze(0)
- torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
- np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
- del torch_img, torch_output
+ img = upscaler_utils.upscale_2(
+ img,
+ model,
+ tile_size=shared.opts.SCUNET_tile,
+ tile_overlap=shared.opts.SCUNET_tile_overlap,
+ scale=1, # ScuNET is a denoising model, not an upscaler
+ desc='ScuNET',
+ )
devices.torch_gc()
-
- output = np_output.transpose((1, 2, 0)) # CHW to HWC
- output = output[:, :, ::-1] # BGR to RGB
- return PIL.Image.fromarray((output * 255).astype(np.uint8))
+ return img
def load_model(self, path: str):
device = devices.get_device_for('scunet')
if path.startswith("http"):
# TODO: this doesn't use `path` at all?
- filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+ filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for _, v in model.named_parameters():
- v.requires_grad = False
- model = model.to(device)
-
- return model
+ return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
def on_ui_settings():
import gradio as gr
- from modules import shared
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py
deleted file mode 100644
index b51a8806..00000000
--- a/extensions-builtin/ScuNET/scunet_model_arch.py
+++ /dev/null
@@ -1,268 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange
-from einops.layers.torch import Rearrange
-from timm.models.layers import trunc_normal_, DropPath
-
-
-class WMSA(nn.Module):
- """ Self-attention module in Swin Transformer
- """
-
- def __init__(self, input_dim, output_dim, head_dim, window_size, type):
- super(WMSA, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.head_dim = head_dim
- self.scale = self.head_dim ** -0.5
- self.n_heads = input_dim // head_dim
- self.window_size = window_size
- self.type = type
- self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
-
- self.relative_position_params = nn.Parameter(
- torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
-
- self.linear = nn.Linear(self.input_dim, self.output_dim)
-
- trunc_normal_(self.relative_position_params, std=.02)
- self.relative_position_params = torch.nn.Parameter(
- self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
- 2).transpose(
- 0, 1))
-
- def generate_mask(self, h, w, p, shift):
- """ generating the mask of SW-MSA
- Args:
- shift: shift parameters in CyclicShift.
- Returns:
- attn_mask: should be (1 1 w p p),
- """
- # supporting square.
- attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
- if self.type == 'W':
- return attn_mask
-
- s = p - shift
- attn_mask[-1, :, :s, :, s:, :] = True
- attn_mask[-1, :, s:, :, :s, :] = True
- attn_mask[:, -1, :, :s, :, s:] = True
- attn_mask[:, -1, :, s:, :, :s] = True
- attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
- return attn_mask
-
- def forward(self, x):
- """ Forward pass of Window Multi-head Self-attention module.
- Args:
- x: input tensor with shape of [b h w c];
- attn_mask: attention mask, fill -inf where the value is True;
- Returns:
- output: tensor shape [b h w c]
- """
- if self.type != 'W':
- x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
-
- x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
- h_windows = x.size(1)
- w_windows = x.size(2)
- # square validation
- # assert h_windows == w_windows
-
- x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
- qkv = self.embedding_layer(x)
- q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
- sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
- # Adding learnable relative embedding
- sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
- # Using Attn Mask to distinguish different subwindows.
- if self.type != 'W':
- attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
- sim = sim.masked_fill_(attn_mask, float("-inf"))
-
- probs = nn.functional.softmax(sim, dim=-1)
- output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
- output = rearrange(output, 'h b w p c -> b w p (h c)')
- output = self.linear(output)
- output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
-
- if self.type != 'W':
- output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
-
- return output
-
- def relative_embedding(self):
- cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
- relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
- # negative is allowed
- return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
-
-
-class Block(nn.Module):
- def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
- """ SwinTransformer Block
- """
- super(Block, self).__init__()
- self.input_dim = input_dim
- self.output_dim = output_dim
- assert type in ['W', 'SW']
- self.type = type
- if input_resolution <= window_size:
- self.type = 'W'
-
- self.ln1 = nn.LayerNorm(input_dim)
- self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.ln2 = nn.LayerNorm(input_dim)
- self.mlp = nn.Sequential(
- nn.Linear(input_dim, 4 * input_dim),
- nn.GELU(),
- nn.Linear(4 * input_dim, output_dim),
- )
-
- def forward(self, x):
- x = x + self.drop_path(self.msa(self.ln1(x)))
- x = x + self.drop_path(self.mlp(self.ln2(x)))
- return x
-
-
-class ConvTransBlock(nn.Module):
- def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
- """ SwinTransformer and Conv Block
- """
- super(ConvTransBlock, self).__init__()
- self.conv_dim = conv_dim
- self.trans_dim = trans_dim
- self.head_dim = head_dim
- self.window_size = window_size
- self.drop_path = drop_path
- self.type = type
- self.input_resolution = input_resolution
-
- assert self.type in ['W', 'SW']
- if self.input_resolution <= self.window_size:
- self.type = 'W'
-
- self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
- self.type, self.input_resolution)
- self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
- self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
-
- self.conv_block = nn.Sequential(
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
- nn.ReLU(True),
- nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
- )
-
- def forward(self, x):
- conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
- conv_x = self.conv_block(conv_x) + conv_x
- trans_x = Rearrange('b c h w -> b h w c')(trans_x)
- trans_x = self.trans_block(trans_x)
- trans_x = Rearrange('b h w c -> b c h w')(trans_x)
- res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
- x = x + res
-
- return x
-
-
-class SCUNet(nn.Module):
- # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
- def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
- super(SCUNet, self).__init__()
- if config is None:
- config = [2, 2, 2, 2, 2, 2, 2]
- self.config = config
- self.dim = dim
- self.head_dim = 32
- self.window_size = 8
-
- # drop path rate for each layer
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
-
- self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
-
- begin = 0
- self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution)
- for i in range(config[0])] + \
- [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
-
- begin += config[0]
- self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 2)
- for i in range(config[1])] + \
- [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
-
- begin += config[1]
- self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 4)
- for i in range(config[2])] + \
- [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
-
- begin += config[2]
- self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 8)
- for i in range(config[3])]
-
- begin += config[3]
- self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 4)
- for i in range(config[4])]
-
- begin += config[4]
- self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution // 2)
- for i in range(config[5])]
-
- begin += config[5]
- self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
- [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
- 'W' if not i % 2 else 'SW', input_resolution)
- for i in range(config[6])]
-
- self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
-
- self.m_head = nn.Sequential(*self.m_head)
- self.m_down1 = nn.Sequential(*self.m_down1)
- self.m_down2 = nn.Sequential(*self.m_down2)
- self.m_down3 = nn.Sequential(*self.m_down3)
- self.m_body = nn.Sequential(*self.m_body)
- self.m_up3 = nn.Sequential(*self.m_up3)
- self.m_up2 = nn.Sequential(*self.m_up2)
- self.m_up1 = nn.Sequential(*self.m_up1)
- self.m_tail = nn.Sequential(*self.m_tail)
- # self.apply(self._init_weights)
-
- def forward(self, x0):
-
- h, w = x0.size()[-2:]
- paddingBottom = int(np.ceil(h / 64) * 64 - h)
- paddingRight = int(np.ceil(w / 64) * 64 - w)
- x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
-
- x1 = self.m_head(x0)
- x2 = self.m_down1(x1)
- x3 = self.m_down2(x2)
- x4 = self.m_down3(x3)
- x = self.m_body(x4)
- x = self.m_up3(x + x4)
- x = self.m_up2(x + x3)
- x = self.m_up1(x + x2)
- x = self.m_tail(x + x1)
-
- x = x[..., :h, :w]
-
- return x
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index ae0d0e6a..16bf9b79 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,20 +1,15 @@
+import logging
import sys
-import platform
-import numpy as np
import torch
from PIL import Image
-from tqdm import tqdm
-from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts, state
-from swinir_model_arch import SwinIR
-from swinir_model_arch_v2 import Swin2SR
+from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
from modules.upscaler import Upscaler, UpscalerData
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
-device_swinir = devices.get_device_for('swinir')
+logger = logging.getLogger(__name__)
class UpscalerSwinIR(Upscaler):
@@ -37,26 +32,28 @@ class UpscalerSwinIR(Upscaler):
scalers.append(model_data)
self.scalers = scalers
- def do_upscale(self, img, model_file):
- use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
- and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
- current_config = (model_file, opts.SWIN_tile)
+ def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
+ current_config = (model_file, shared.opts.SWIN_tile)
- if use_compile and self._cached_model_config == current_config:
+ if self._cached_model_config == current_config:
model = self._cached_model
else:
- self._cached_model = None
try:
model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
- model = model.to(device_swinir, dtype=devices.dtype)
- if use_compile:
- model = torch.compile(model)
- self._cached_model = model
- self._cached_model_config = current_config
- img = upscale(img, model)
+ self._cached_model = model
+ self._cached_model_config = current_config
+
+ img = upscaler_utils.upscale_2(
+ img,
+ model,
+ tile_size=shared.opts.SWIN_tile,
+ tile_overlap=shared.opts.SWIN_tile_overlap,
+ scale=model.scale,
+ desc="SwinIR",
+ )
devices.torch_gc()
return img
@@ -69,115 +66,22 @@ class UpscalerSwinIR(Upscaler):
)
else:
filename = path
- if filename.endswith(".v2.pth"):
- model = Swin2SR(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
- )
- params = None
- else:
- model = SwinIR(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="3conv",
- )
- params = "params_ema"
- pretrained_model = torch.load(filename)
- if params is not None:
- model.load_state_dict(pretrained_model[params], strict=True)
- else:
- model.load_state_dict(pretrained_model, strict=True)
- return model
-
-
-def upscale(
- img,
- model,
- tile=None,
- tile_overlap=None,
- window_size=8,
- scale=4,
-):
- tile = tile or opts.SWIN_tile
- tile_overlap = tile_overlap or opts.SWIN_tile_overlap
-
-
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
- with torch.no_grad(), devices.autocast():
- _, _, h_old, w_old = img.size()
- h_pad = (h_old // window_size + 1) * window_size - h_old
- w_pad = (w_old // window_size + 1) * window_size - w_old
- img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
- img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = inference(img, model, tile, tile_overlap, window_size, scale)
- output = output[..., : h_old * scale, : w_old * scale]
- output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
- if output.ndim == 3:
- output = np.transpose(
- output[[2, 1, 0], :, :], (1, 2, 0)
- ) # CHW-RGB to HCW-BGR
- output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
- return Image.fromarray(output, "RGB")
-
-
-def inference(img, model, tile, tile_overlap, window_size, scale):
- # test the image tile by tile
- b, c, h, w = img.size()
- tile = min(tile, h, w)
- assert tile % window_size == 0, "tile size should be a multiple of window_size"
- sf = scale
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
- for h_idx in h_idx_list:
- if state.interrupted or state.skipped:
- break
-
- for w_idx in w_idx_list:
- if state.interrupted or state.skipped:
- break
-
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
+ model_descriptor = modelloader.load_spandrel_model(
+ filename,
+ device=self._get_device(),
+ prefer_half=(devices.dtype == torch.float16),
+ expected_architecture="SwinIR",
+ )
+ if getattr(shared.opts, 'SWIN_torch_compile', False):
+ try:
+ model_descriptor.model.compile()
+ except Exception:
+ logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
+ return model_descriptor
+
+ def _get_device(self):
+ return devices.get_device_for('swinir')
def on_ui_settings():
@@ -185,8 +89,7 @@ def on_ui_settings():
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
- if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
- shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+ shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
deleted file mode 100644
index 93b93274..00000000
--- a/extensions-builtin/SwinIR/swinir_model_arch.py
+++ /dev/null
@@ -1,867 +0,0 @@
-# -----------------------------------------------------------------------------------
-# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
-# Originally Written by Ze Liu, Modified by Jingyun Liang.
-# -----------------------------------------------------------------------------------
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
-
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
-
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
-
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
-
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- # assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- # FFN
- x = shortcut + self.drop_path(x)
- x = x + self.drop_path(self.mlp(self.norm2(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
-
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- flops = 0
- H, W = self.img_size
- if self.norm is not None:
- flops += H * W * self.embed_dim
- return flops
-
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-class SwinIR(nn.Module):
- r""" SwinIR
- A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(SwinIR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias, qk_scale=qk_scale,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- if self.upscale == 4:
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- if self.upscale == 4:
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
-
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for layer in self.layers:
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = SwinIR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
deleted file mode 100644
index dad22cca..00000000
--- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py
+++ /dev/null
@@ -1,1017 +0,0 @@
-# -----------------------------------------------------------------------------------
-# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
-# Written by Conde and Choi et al.
-# -----------------------------------------------------------------------------------
-
-import math
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=(0, 0)):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.pretrained_window_size = pretrained_window_size
- self.num_heads = num_heads
-
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
-
- # mlp to generate continuous relative position bias
- self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
- nn.ReLU(inplace=True),
- nn.Linear(512, num_heads, bias=False))
-
- # get relative_coords_table
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
- relative_coords_table = torch.stack(
- torch.meshgrid([relative_coords_h,
- relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
- else:
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
- torch.abs(relative_coords_table) + 1.0) / np.log2(8)
-
- self.register_buffer("relative_coords_table", relative_coords_table)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=False)
- if qkv_bias:
- self.q_bias = nn.Parameter(torch.zeros(dim))
- self.v_bias = nn.Parameter(torch.zeros(dim))
- else:
- self.q_bias = None
- self.v_bias = None
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv_bias = None
- if self.q_bias is not None:
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- # cosine attention
- attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
- logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
- attn = attn * logit_scale
-
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, ' \
- f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resulotion.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- pretrained_window_size (int): Window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
- pretrained_window_size=to_2tuple(pretrained_window_size))
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- #assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
- x = shortcut + self.drop_path(self.norm1(x))
-
- # FFN
- x = x + self.drop_path(self.norm2(self.mlp(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(2 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.reduction(x)
- x = self.norm(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- flops += H * W * self.dim // 2
- return flops
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- pretrained_window_size (int): Local window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- pretrained_window_size=0):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer,
- pretrained_window_size=pretrained_window_size)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
- def _init_respostnorm(self):
- for blk in self.blocks:
- nn.init.constant_(blk.norm1.bias, 0)
- nn.init.constant_(blk.norm1.weight, 0)
- nn.init.constant_(blk.norm2.bias, 0)
- nn.init.constant_(blk.norm2.weight, 0)
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- # assert H == self.img_size[0] and W == self.img_size[1],
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-class Upsample_hf(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-
-class Swin2SR(nn.Module):
- r""" Swin2SR
- A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
- window_size=7, mlp_ratio=4., qkv_bias=True,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(Swin2SR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
-
- if self.upsampler == 'pixelshuffle_hf':
- self.layers_hf = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers_hf.append(layer)
-
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffle_aux':
- self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_after_aux = nn.Sequential(
- nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffle_hf':
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.upsample_hf = Upsample_hf(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- self.conv_before_upsample_hf = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- assert self.upscale == 4, 'only support x4 now.'
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward_features_hf(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers_hf:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffle_aux':
- bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
- bicubic = self.conv_bicubic(bicubic)
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- aux = self.conv_aux(x) # b, 3, LR_H, LR_W
- x = self.conv_after_aux(aux)
- x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
- x = self.conv_last(x)
- aux = aux / self.img_range + self.mean
- elif self.upsampler == 'pixelshuffle_hf':
- # for classical SR with HF
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x_before = self.conv_before_upsample(x)
- x_out = self.conv_last(self.upsample(x_before))
-
- x_hf = self.conv_first_hf(x_before)
- x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
- x_hf = self.conv_before_upsample_hf(x_hf)
- x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
- x = x_out + x_hf
- x_hf = x_hf / self.img_range + self.mean
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
- if self.upsampler == "pixelshuffle_aux":
- return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
- elif self.upsampler == "pixelshuffle_hf":
- x_out = x_out / self.img_range + self.mean
- return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
- else:
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for layer in self.layers:
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = Swin2SR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
index 45c7600a..df60c1a1 100644
--- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
+++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
@@ -218,6 +218,8 @@ onUiLoaded(async() => {
canvas_hotkey_fullscreen: "KeyS",
canvas_hotkey_move: "KeyF",
canvas_hotkey_overlap: "KeyO",
+ canvas_hotkey_shrink_brush: "KeyQ",
+ canvas_hotkey_grow_brush: "KeyW",
canvas_disabled_functions: [],
canvas_show_tooltip: true,
canvas_auto_expand: true,
@@ -227,6 +229,8 @@ onUiLoaded(async() => {
const functionMap = {
"Zoom": "canvas_hotkey_zoom",
"Adjust brush size": "canvas_hotkey_adjust",
+ "Hotkey shrink brush": "canvas_hotkey_shrink_brush",
+ "Hotkey enlarge brush": "canvas_hotkey_grow_brush",
"Moving canvas": "canvas_hotkey_move",
"Fullscreen": "canvas_hotkey_fullscreen",
"Reset Zoom": "canvas_hotkey_reset",
@@ -686,7 +690,9 @@ onUiLoaded(async() => {
const hotkeyActions = {
[hotkeysConfig.canvas_hotkey_reset]: resetZoom,
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
- [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen
+ [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
+ [hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10),
+ [hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10)
};
const action = hotkeyActions[event.code];
diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
index 2d8d2d1c..89b7c31f 100644
--- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
+++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
@@ -4,6 +4,8 @@ from modules import shared
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
"canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
+ "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"),
+ "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
@@ -11,5 +13,5 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
- "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
+ "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
}))
diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
index ac2c3de4..4c10d9c7 100644
--- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py
+++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
@@ -1,7 +1,7 @@
import math
import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
+from modules import scripts, shared, ui_components, ui_settings, infotext_utils
from modules.ui_components import FormColumn
@@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script):
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
- mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
+ mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
diff --git a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
new file mode 100644
index 00000000..d9024344
--- /dev/null
+++ b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
@@ -0,0 +1,747 @@
+import numpy as np
+import gradio as gr
+import math
+from modules.ui_components import InputAccordion
+import modules.scripts as scripts
+
+
+class SoftInpaintingSettings:
+ def __init__(self,
+ mask_blend_power,
+ mask_blend_scale,
+ inpaint_detail_preservation,
+ composite_mask_influence,
+ composite_difference_threshold,
+ composite_difference_contrast):
+ self.mask_blend_power = mask_blend_power
+ self.mask_blend_scale = mask_blend_scale
+ self.inpaint_detail_preservation = inpaint_detail_preservation
+ self.composite_mask_influence = composite_mask_influence
+ self.composite_difference_threshold = composite_difference_threshold
+ self.composite_difference_contrast = composite_difference_contrast
+
+ def add_generation_params(self, dest):
+ dest[enabled_gen_param_label] = True
+ dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
+ dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
+ dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
+ dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
+ dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
+ dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast
+
+
+# ------------------- Methods -------------------
+
+def processing_uses_inpainting(p):
+ # TODO: Figure out a better way to determine if inpainting is being used by p
+ if getattr(p, "image_mask", None) is not None:
+ return True
+
+ if getattr(p, "mask", None) is not None:
+ return True
+
+ if getattr(p, "nmask", None) is not None:
+ return True
+
+ return False
+
+
+def latent_blend(settings, a, b, t):
+ """
+ Interpolates two latent image representations according to the parameter t,
+ where the interpolated vectors' magnitudes are also interpolated separately.
+ The "detail_preservation" factor biases the magnitude interpolation towards
+ the larger of the two magnitudes.
+ """
+ import torch
+
+ # NOTE: We use inplace operations wherever possible.
+
+ # [4][w][h] to [1][4][w][h]
+ t2 = t.unsqueeze(0)
+ # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
+ t3 = t[0].unsqueeze(0).unsqueeze(0)
+
+ one_minus_t2 = 1 - t2
+ one_minus_t3 = 1 - t3
+
+ # Linearly interpolate the image vectors.
+ a_scaled = a * one_minus_t2
+ b_scaled = b * t2
+ image_interp = a_scaled
+ image_interp.add_(b_scaled)
+ result_type = image_interp.dtype
+ del a_scaled, b_scaled, t2, one_minus_t2
+
+ # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+ # 64-bit operations are used here to allow large exponents.
+ current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
+
+ # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+ a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
+ settings.inpaint_detail_preservation) * one_minus_t3
+ b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
+ settings.inpaint_detail_preservation) * t3
+ desired_magnitude = a_magnitude
+ desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
+ del a_magnitude, b_magnitude, t3, one_minus_t3
+
+ # Change the linearly interpolated image vectors' magnitudes to the value we want.
+ # This is the last 64-bit operation.
+ image_interp_scaling_factor = desired_magnitude
+ image_interp_scaling_factor.div_(current_magnitude)
+ image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
+ image_interp_scaled = image_interp
+ image_interp_scaled.mul_(image_interp_scaling_factor)
+ del current_magnitude
+ del desired_magnitude
+ del image_interp
+ del image_interp_scaling_factor
+ del result_type
+
+ return image_interp_scaled
+
+
+def get_modified_nmask(settings, nmask, sigma):
+ """
+ Converts a negative mask representing the transparency of the original latent vectors being overlayed
+ to a mask that is scaled according to the denoising strength for this step.
+
+ Where:
+ 0 = fully opaque, infinite density, fully masked
+ 1 = fully transparent, zero density, fully unmasked
+
+ We bring this transparency to a power, as this allows one to simulate N number of blending operations
+ where N can be any positive real value. Using this one can control the balance of influence between
+ the denoiser and the original latents according to the sigma value.
+
+ NOTE: "mask" is not used
+ """
+ import torch
+ return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)
+
+
+def apply_adaptive_masks(
+ settings: SoftInpaintingSettings,
+ nmask,
+ latent_orig,
+ latent_processed,
+ overlay_images,
+ width, height,
+ paste_to):
+ import torch
+ import modules.processing as proc
+ import modules.images as images
+ from PIL import Image, ImageOps, ImageFilter
+
+ # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
+ latent_mask = nmask[0].float()
+ # convert the original mask into a form we use to scale distances for thresholding
+ mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
+ mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
+ + mask_scalar * settings.composite_mask_influence)
+ mask_scalar = mask_scalar / (1.00001 - mask_scalar)
+ mask_scalar = mask_scalar.cpu().numpy()
+
+ latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
+
+ kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
+
+ masks_for_overlay = []
+
+ for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
+ converted_mask = distance_map.float().cpu().numpy()
+ converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.9, percentile_max=1, min_width=1)
+ converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.25, percentile_max=0.75, min_width=1)
+
+ # The distance at which opacity of original decreases to 50%
+ half_weighted_distance = settings.composite_difference_threshold * mask_scalar
+ converted_mask = converted_mask / half_weighted_distance
+
+ converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
+ converted_mask = smootherstep(converted_mask)
+ converted_mask = 1 - converted_mask
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, width, height)
+ converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if paste_to is not None:
+ converted_mask = proc.uncrop(converted_mask,
+ (overlay_image.width, overlay_image.height),
+ paste_to)
+
+ masks_for_overlay.append(converted_mask)
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ overlay_images[i] = image_masked.convert('RGBA')
+
+ return masks_for_overlay
+
+
+def apply_masks(
+ settings,
+ nmask,
+ overlay_images,
+ width, height,
+ paste_to):
+ import torch
+ import modules.processing as proc
+ import modules.images as images
+ from PIL import Image, ImageOps, ImageFilter
+
+ converted_mask = nmask[0].float()
+ converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, width, height)
+ converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if paste_to is not None:
+ converted_mask = proc.uncrop(converted_mask,
+ (width, height),
+ paste_to)
+
+ masks_for_overlay = []
+
+ for i, overlay_image in enumerate(overlay_images):
+ masks_for_overlay[i] = converted_mask
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ overlay_images[i] = image_masked.convert('RGBA')
+
+ return masks_for_overlay
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+ """
+ Generalization convolution filter capable of applying
+ weighted mean, median, maximum, and minimum filters
+ parametrically using an arbitrary kernel.
+
+ Args:
+ img (nparray):
+ The image, a 2-D array of floats, to which the filter is being applied.
+ kernel (nparray):
+ The kernel, a 2-D array of floats.
+ kernel_center (nparray):
+ The kernel center coordinate, a 1-D array with two elements.
+ percentile_min (float):
+ The lower bound of the histogram window used by the filter,
+ from 0 to 1.
+ percentile_max (float):
+ The upper bound of the histogram window used by the filter,
+ from 0 to 1.
+ min_width (float):
+ The minimum size of the histogram window bounds, in weight units.
+ Must be greater than 0.
+
+ Returns:
+ (nparray): A filtered copy of the input image "img", a 2-D array of floats.
+ """
+
+ # Converts an index tuple into a vector.
+ def vec(x):
+ return np.array(x)
+
+ kernel_min = -kernel_center
+ kernel_max = vec(kernel.shape) - kernel_center
+
+ def weighted_histogram_filter_single(idx):
+ idx = vec(idx)
+ min_index = np.maximum(0, idx + kernel_min)
+ max_index = np.minimum(vec(img.shape), idx + kernel_max)
+ window_shape = max_index - min_index
+
+ class WeightedElement:
+ """
+ An element of the histogram, its weight
+ and bounds.
+ """
+
+ def __init__(self, value, weight):
+ self.value: float = value
+ self.weight: float = weight
+ self.window_min: float = 0.0
+ self.window_max: float = 1.0
+
+ # Collect the values in the image as WeightedElements,
+ # weighted by their corresponding kernel values.
+ values = []
+ for window_tup in np.ndindex(tuple(window_shape)):
+ window_index = vec(window_tup)
+ image_index = window_index + min_index
+ centered_kernel_index = image_index - idx
+ kernel_index = centered_kernel_index + kernel_center
+ element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
+ values.append(element)
+
+ def sort_key(x: WeightedElement):
+ return x.value
+
+ values.sort(key=sort_key)
+
+ # Calculate the height of the stack (sum)
+ # and each sample's range they occupy in the stack
+ sum = 0
+ for i in range(len(values)):
+ values[i].window_min = sum
+ sum += values[i].weight
+ values[i].window_max = sum
+
+ # Calculate what range of this stack ("window")
+ # we want to get the weighted average across.
+ window_min = sum * percentile_min
+ window_max = sum * percentile_max
+ window_width = window_max - window_min
+
+ # Ensure the window is within the stack and at least a certain size.
+ if window_width < min_width:
+ window_center = (window_min + window_max) / 2
+ window_min = window_center - min_width / 2
+ window_max = window_center + min_width / 2
+
+ if window_max > sum:
+ window_max = sum
+ window_min = sum - min_width
+
+ if window_min < 0:
+ window_min = 0
+ window_max = min_width
+
+ value = 0
+ value_weight = 0
+
+ # Get the weighted average of all the samples
+ # that overlap with the window, weighted
+ # by the size of their overlap.
+ for i in range(len(values)):
+ if window_min >= values[i].window_max:
+ continue
+ if window_max <= values[i].window_min:
+ break
+
+ s = max(window_min, values[i].window_min)
+ e = min(window_max, values[i].window_max)
+ w = e - s
+
+ value += values[i].value * w
+ value_weight += w
+
+ return value / value_weight if value_weight != 0 else 0
+
+ img_out = img.copy()
+
+ # Apply the kernel operation over each pixel.
+ for index in np.ndindex(img.shape):
+ img_out[index] = weighted_histogram_filter_single(index)
+
+ return img_out
+
+
+def smoothstep(x):
+ """
+ The smoothstep function, input should be clamped to 0-1 range.
+ Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
+ """
+ return x * x * (3 - 2 * x)
+
+
+def smootherstep(x):
+ """
+ The smootherstep function, input should be clamped to 0-1 range.
+ Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
+ """
+ return x * x * x * (x * (6 * x - 15) + 10)
+
+
+def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
+ """
+ Creates a Gaussian kernel with thresholded edges.
+
+ Args:
+ stddev_radius (float):
+ Standard deviation of the gaussian kernel, in pixels.
+ max_radius (int):
+ The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
+ The kernel is thresholded so that any values one pixel beyond this radius
+ is weighted at 0.
+
+ Returns:
+ (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2))
+ """
+
+ # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
+ def gaussian(sqr_mag):
+ return math.exp(-sqr_mag / (stddev_radius * stddev_radius))
+
+ # Helper function for converting a tuple to an array.
+ def vec(x):
+ return np.array(x)
+
+ """
+ Since a gaussian is unbounded, we need to limit ourselves
+ to a finite range.
+ We taper the ends off at the end of that range so they equal zero
+ while preserving the maximum value of 1 at the mean.
+ """
+ zero_radius = max_radius + 1.0
+ gauss_zero = gaussian(zero_radius * zero_radius)
+ gauss_kernel_scale = 1 / (1 - gauss_zero)
+
+ def gaussian_kernel_func(coordinate):
+ x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
+ x = gaussian(x)
+ x -= gauss_zero
+ x *= gauss_kernel_scale
+ x = max(0.0, x)
+ return x
+
+ size = max_radius * 2 + 1
+ kernel_center = max_radius
+ kernel = np.zeros((size, size))
+
+ for index in np.ndindex(kernel.shape):
+ kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)
+
+ return kernel, kernel_center
+
+
+# ------------------- Constants -------------------
+
+
+default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)
+
+enabled_ui_label = "Soft inpainting"
+enabled_gen_param_label = "Soft inpainting enabled"
+enabled_el_id = "soft_inpainting_enabled"
+
+ui_labels = SoftInpaintingSettings(
+ "Schedule bias",
+ "Preservation strength",
+ "Transition contrast boost",
+ "Mask influence",
+ "Difference threshold",
+ "Difference contrast")
+
+ui_info = SoftInpaintingSettings(
+ "Shifts when preservation of original content occurs during denoising.",
+ "How strongly partially masked content should be preserved.",
+ "Amplifies the contrast that may be lost in partially masked regions.",
+ "How strongly the original mask should bias the difference threshold.",
+ "How much an image region can change before the original pixels are not blended in anymore.",
+ "How sharp the transition should be between blended and not blended.")
+
+gen_param_labels = SoftInpaintingSettings(
+ "Soft inpainting schedule bias",
+ "Soft inpainting preservation strength",
+ "Soft inpainting transition contrast boost",
+ "Soft inpainting mask influence",
+ "Soft inpainting difference threshold",
+ "Soft inpainting difference contrast")
+
+el_ids = SoftInpaintingSettings(
+ "mask_blend_power",
+ "mask_blend_scale",
+ "inpaint_detail_preservation",
+ "composite_mask_influence",
+ "composite_difference_threshold",
+ "composite_difference_contrast")
+
+
+# ------------------- Script -------------------
+
+
+class Script(scripts.Script):
+ def __init__(self):
+ self.section = "inpaint"
+ self.masks_for_overlay = None
+ self.overlay_images = None
+
+ def title(self):
+ return "Soft Inpainting"
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible if is_img2img else False
+
+ def ui(self, is_img2img):
+ if not is_img2img:
+ return
+
+ with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
+ with gr.Group():
+ gr.Markdown(
+ """
+ Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
+ **High _Mask blur_** values are recommended!
+ """)
+
+ power = \
+ gr.Slider(label=ui_labels.mask_blend_power,
+ info=ui_info.mask_blend_power,
+ minimum=0,
+ maximum=8,
+ step=0.1,
+ value=default.mask_blend_power,
+ elem_id=el_ids.mask_blend_power)
+ scale = \
+ gr.Slider(label=ui_labels.mask_blend_scale,
+ info=ui_info.mask_blend_scale,
+ minimum=0,
+ maximum=8,
+ step=0.05,
+ value=default.mask_blend_scale,
+ elem_id=el_ids.mask_blend_scale)
+ detail = \
+ gr.Slider(label=ui_labels.inpaint_detail_preservation,
+ info=ui_info.inpaint_detail_preservation,
+ minimum=1,
+ maximum=32,
+ step=0.5,
+ value=default.inpaint_detail_preservation,
+ elem_id=el_ids.inpaint_detail_preservation)
+
+ gr.Markdown(
+ """
+ ### Pixel Composite Settings
+ """)
+
+ mask_inf = \
+ gr.Slider(label=ui_labels.composite_mask_influence,
+ info=ui_info.composite_mask_influence,
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ value=default.composite_mask_influence,
+ elem_id=el_ids.composite_mask_influence)
+
+ dif_thresh = \
+ gr.Slider(label=ui_labels.composite_difference_threshold,
+ info=ui_info.composite_difference_threshold,
+ minimum=0,
+ maximum=8,
+ step=0.25,
+ value=default.composite_difference_threshold,
+ elem_id=el_ids.composite_difference_threshold)
+
+ dif_contr = \
+ gr.Slider(label=ui_labels.composite_difference_contrast,
+ info=ui_info.composite_difference_contrast,
+ minimum=0,
+ maximum=8,
+ step=0.25,
+ value=default.composite_difference_contrast,
+ elem_id=el_ids.composite_difference_contrast)
+
+ with gr.Accordion("Help", open=False):
+ gr.Markdown(
+ f"""
+ ### {ui_labels.mask_blend_power}
+
+ The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
+ This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
+ This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.
+
+ - **Below 1**: Stronger preservation near the end (with low sigma)
+ - **1**: Balanced (proportional to sigma)
+ - **Above 1**: Stronger preservation in the beginning (with high sigma)
+ """)
+ gr.Markdown(
+ f"""
+ ### {ui_labels.mask_blend_scale}
+
+ Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
+ This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.
+
+ - **Low values**: Favors generated content.
+ - **High values**: Favors original content.
+ """)
+ gr.Markdown(
+ f"""
+ ### {ui_labels.inpaint_detail_preservation}
+
+ This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
+ With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.
+ This can prevent the loss of contrast that occurs with linear interpolation.
+
+ - **Low values**: Softer blending, details may fade.
+ - **High values**: Stronger contrast, may over-saturate colors.
+ """)
+
+ gr.Markdown(
+ """
+ ## Pixel Composite Settings
+
+ Masks are generated based on how much a part of the image changed after denoising.
+ These masks are used to blend the original and final images together.
+ If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_mask_influence}
+
+ This parameter controls how much the mask should bias this sensitivity to difference.
+
+ - **0**: Ignore the mask, only consider differences in image content.
+ - **1**: Follow the mask closely despite image content changes.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_difference_threshold}
+
+ This value represents the difference at which the original pixels will have less than 50% opacity.
+
+ - **Low values**: Two images patches must be almost the same in order to retain original pixels.
+ - **High values**: Two images patches can be very different and still retain original pixels.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_difference_contrast}
+
+ This value represents the contrast between the opacity of the original and inpainted content.
+
+ - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting.
+ - **High values**: Ghosting will be less common, but transitions may be very sudden.
+ """)
+
+ self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
+ (power, gen_param_labels.mask_blend_power),
+ (scale, gen_param_labels.mask_blend_scale),
+ (detail, gen_param_labels.inpaint_detail_preservation),
+ (mask_inf, gen_param_labels.composite_mask_influence),
+ (dif_thresh, gen_param_labels.composite_difference_threshold),
+ (dif_contr, gen_param_labels.composite_difference_contrast)]
+
+ self.paste_field_names = []
+ for _, field_name in self.infotext_fields:
+ self.paste_field_names.append(field_name)
+
+ return [soft_inpainting_enabled,
+ power,
+ scale,
+ detail,
+ mask_inf,
+ dif_thresh,
+ dif_contr]
+
+ def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
+ if not enabled:
+ return
+
+ if not processing_uses_inpainting(p):
+ return
+
+ # Shut off the rounding it normally does.
+ p.mask_round = False
+
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
+
+ # p.extra_generation_params["Mask rounding"] = False
+ settings.add_generation_params(p.extra_generation_params)
+
+ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,
+ dif_thresh, dif_contr):
+ if not enabled:
+ return
+
+ if not processing_uses_inpainting(p):
+ return
+
+ if mba.is_final_blend:
+ mba.blended_latent = mba.current_latent
+ return
+
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
+
+ # todo: Why is sigma 2D? Both values are the same.
+ mba.blended_latent = latent_blend(settings,
+ mba.init_latent,
+ mba.current_latent,
+ get_modified_nmask(settings, mba.nmask, mba.sigma[0]))
+
+ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,
+ dif_thresh, dif_contr):
+ if not enabled:
+ return
+
+ if not processing_uses_inpainting(p):
+ return
+
+ nmask = getattr(p, "nmask", None)
+ if nmask is None:
+ return
+
+ from modules import images
+ from modules.shared import opts
+
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
+
+ # since the original code puts holes in the existing overlay images,
+ # we have to rebuild them.
+ self.overlay_images = []
+ for img in p.init_images:
+
+ image = images.flatten(img, opts.img2img_background_color)
+
+ if p.paste_to is None and p.resize_mode != 3:
+ image = images.resize_image(p.resize_mode, image, p.width, p.height)
+
+ self.overlay_images.append(image.convert('RGBA'))
+
+ if len(p.init_images) == 1:
+ self.overlay_images = self.overlay_images * p.batch_size
+
+ if getattr(ps.samples, 'already_decoded', False):
+ self.masks_for_overlay = apply_masks(settings=settings,
+ nmask=nmask,
+ overlay_images=self.overlay_images,
+ width=p.width,
+ height=p.height,
+ paste_to=p.paste_to)
+ else:
+ self.masks_for_overlay = apply_adaptive_masks(settings=settings,
+ nmask=nmask,
+ latent_orig=p.init_latent,
+ latent_processed=ps.samples,
+ overlay_images=self.overlay_images,
+ width=p.width,
+ height=p.height,
+ paste_to=p.paste_to)
+
+ def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,
+ detail_preservation, mask_inf, dif_thresh, dif_contr):
+ if not enabled:
+ return
+
+ if not processing_uses_inpainting(p):
+ return
+
+ if self.masks_for_overlay is None:
+ return
+
+ if self.overlay_images is None:
+ return
+
+ ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]
+ ppmo.overlay_image = self.overlay_images[ppmo.index]