aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJC-Array <44535867+JC-Array@users.noreply.github.com>2022-10-10 18:11:02 -0500
committerGitHub <noreply@github.com>2022-10-10 18:11:02 -0500
commitd66bc86159d415005f0745fdb5724bcd95576352 (patch)
tree2544c33a8f443f226c9cf4bea7df7e3a30369812
parent76ef3d75f61253516c024553335d9083d9660a8a (diff)
parent47f5e216da2af4b1faf232a620572f8b357855d5 (diff)
Merge pull request #2 from JC-Array/master
resolve merge conflicts
-rw-r--r--.github/PULL_REQUEST_TEMPLATE/pull_request_template.md28
-rw-r--r--javascript/contextMenus.js4
-rw-r--r--javascript/hints.js2
-rw-r--r--launch.py2
-rw-r--r--modules/devices.py6
-rw-r--r--modules/processing.py20
-rw-r--r--modules/safe.py6
-rw-r--r--modules/sd_hijack.py76
-rw-r--r--modules/sd_hijack_optimizations.py2
-rw-r--r--modules/sd_models.py7
-rw-r--r--modules/sd_samplers.py4
-rw-r--r--modules/shared.py4
-rw-r--r--modules/swinir_model.py35
-rw-r--r--modules/swinir_model_arch_v2.py1017
-rw-r--r--modules/textual_inversion/dataset.py3
-rw-r--r--modules/textual_inversion/preprocess.py19
-rw-r--r--modules/textual_inversion/textual_inversion.py14
-rw-r--r--modules/ui.py18
-rw-r--r--requirements.txt1
-rw-r--r--requirements_versions.txt1
-rw-r--r--script.js16
-rw-r--r--scripts/xy_grid.py5
-rw-r--r--style.css11
23 files changed, 1233 insertions, 68 deletions
diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
new file mode 100644
index 00000000..86009613
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
@@ -0,0 +1,28 @@
+# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
+
+If you have a large change, pay special attention to this paragraph:
+
+> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
+
+Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
+
+**Describe what this pull request is trying to achieve.**
+
+A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
+
+**Additional notes and description of your changes**
+
+More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
+
+**Environment this was tested in**
+
+List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
+ - OS: [e.g. Windows, Linux]
+ - Browser [e.g. chrome, safari]
+ - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
+
+**Screenshots or videos of your changes**
+
+If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
+
+This is **required** for anything that touches the user interface. \ No newline at end of file
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js
index 2d82269f..7852793c 100644
--- a/javascript/contextMenus.js
+++ b/javascript/contextMenus.js
@@ -147,10 +147,6 @@ generateOnRepeatId = appendContextMenuOption('#txt2img_generate','Generate forev
cancelGenerateForever = function(){
clearInterval(window.generateOnRepeatInterval)
- let interruptbutton = gradioApp().querySelector('#txt2img_interrupt');
- if(interruptbutton.offsetParent){
- interruptbutton.click();
- }
}
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
diff --git a/javascript/hints.js b/javascript/hints.js
index 8e352e94..045f2d3c 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -79,6 +79,8 @@ titles = {
"Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
"Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
+ "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
+ "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.",
}
diff --git a/launch.py b/launch.py
index f42f557d..e1000f55 100644
--- a/launch.py
+++ b/launch.py
@@ -127,7 +127,7 @@ def prepare_enviroment():
if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
if platform.system() == "Windows":
- run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
+ run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
elif platform.system() == "Linux":
run_pip("install xformers", "xformers")
diff --git a/modules/devices.py b/modules/devices.py
index 0158b11f..03ef58f1 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32")
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
+dtype_vae = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@@ -59,9 +60,12 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
-def autocast():
+def autocast(disable=False):
from modules import shared
+ if disable:
+ return contextlib.nullcontext()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
diff --git a/modules/processing.py b/modules/processing.py
index 94d2dd62..698b3069 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -207,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -247,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
+ if opts.eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + opts.eta_noise_seed_delta)
+
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -259,6 +262,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def decode_first_stage(model, x):
+ with devices.autocast(disable=x.dtype == devices.dtype_vae):
+ x = model.decode_first_stage(x)
+
+ return x
+
+
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
@@ -294,6 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
+ "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
}
generation_params.update(p.extra_generation_params)
@@ -398,9 +409,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
- samples_ddim = samples_ddim.to(devices.dtype)
-
- x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+ samples_ddim = samples_ddim.to(devices.dtype_vae)
+ x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
@@ -533,7 +543,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
else:
- decoded_samples = self.sd_model.decode_first_stage(samples)
+ decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
diff --git a/modules/safe.py b/modules/safe.py
index 4d06f2a5..05917463 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -12,6 +12,10 @@ import _codecs
import zipfile
+# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
+
+
def encode(*args):
out = _codecs.encode(*args)
return out
@@ -20,7 +24,7 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
- return torch.storage._TypedStorage()
+ return TypedStorage()
def find_class(self, module, name):
if module == 'collections' and name == 'OrderedDict':
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 437acce4..827bf304 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -23,7 +23,7 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)):
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
@@ -43,10 +43,7 @@ def undo_optimizations():
def get_target_prompt_token_count(token_count):
- if token_count < 75:
- return 75
-
- return math.ceil(token_count / 10) * 10
+ return math.ceil(max(token_count, 1) / 75) * 75
class StableDiffusionModelHijack:
@@ -127,7 +124,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
- id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
if opts.enable_emphasis:
@@ -154,7 +150,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
i += 1
else:
emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
+ iteration = len(remade_tokens) // 75
+ if (len(remade_tokens) + emb_len) // 75 != iteration:
+ rem = (75 * (iteration + 1) - len(remade_tokens))
+ remade_tokens += [id_end] * rem
+ multipliers += [1.0] * rem
+ iteration += 1
+ fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
@@ -162,10 +164,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens) + 1
+ tokens_to_add = prompt_target_length - len(remade_tokens)
- remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add
- multipliers = [1.0] + multipliers + [1.0] * tokens_to_add
+ remade_tokens = remade_tokens + [id_end] * tokens_to_add
+ multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
@@ -260,29 +262,55 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
+
def forward(self, text):
-
- if opts.use_old_emphasis_implementation:
+ use_old = opts.use_old_emphasis_implementation
+ if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
- self.hijack.fixes = hijack_fixes
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+ if use_old:
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
+
+ z = None
+ i = 0
+ while max(map(len, remade_batch_tokens)) != 0:
+ rem_tokens = [x[75:] for x in remade_batch_tokens]
+ rem_multipliers = [x[75:] for x in batch_multipliers]
+
+ self.hijack.fixes = []
+ for unfiltered in hijack_fixes:
+ fixes = []
+ for fix in unfiltered:
+ if fix[0] == i:
+ fixes.append(fix[1])
+ self.hijack.fixes.append(fixes)
+
+ z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
+ z = z1 if z is None else torch.cat((z, z1), axis=-2)
+
+ remade_batch_tokens = rem_tokens
+ batch_multipliers = rem_multipliers
+ i += 1
+
+ return z
+
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+ batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+
+ tokens = torch.asarray(remade_batch_tokens).to(device)
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
- target_token_count = get_target_prompt_token_count(token_count) + 2
-
- position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76]
- position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
-
- remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
- tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
-
- outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
@@ -290,7 +318,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
+ batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 634fb4b2..18408e62 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -13,8 +13,6 @@ from modules import shared
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
import xformers.ops
- import functorch
- xformers._is_functorch_available = True
shared.xformers_available = True
except Exception:
print("Cannot import xformers", file=sys.stderr)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e63d3c29..0a55b4c3 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -149,8 +149,13 @@ def load_model_weights(model, checkpoint_info):
model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
+
+ if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
+ vae_file = shared.cmd_opts.vae_path
+
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location="cpu")
@@ -158,6 +163,8 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.load_state_dict(vae_dict)
+ model.first_stage_model.to(devices.dtype_vae)
+
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 6e743f7e..d168b938 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser
+from modules import prompt_parser, devices, processing
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
def sample_to_image(samples):
- x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
+ x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
diff --git a/modules/shared.py b/modules/shared.py
index 2e307809..99a0264c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
+parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
@@ -65,6 +66,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
+parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
@@ -171,6 +173,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
+ "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -259,6 +262,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
if cmd_opts.deepdanbooru:
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index fbd11f84..baa02e3d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -10,6 +10,7 @@ from tqdm import tqdm
from modules import modelloader
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
+from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
precision_scope = (
@@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler):
filename = path
if filename is None or not os.path.exists(filename):
return None
- model = net(
+ if filename.endswith(".v2.pth"):
+ model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
- resi_connection="3conv",
- )
+ resi_connection="1conv",
+ )
+ params = None
+ else:
+ model = net(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
+ embed_dim=240,
+ num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="3conv",
+ )
+ params = "params_ema"
pretrained_model = torch.load(filename)
- model.load_state_dict(pretrained_model["params_ema"], strict=True)
+ if params is not None:
+ model.load_state_dict(pretrained_model[params], strict=True)
+ else:
+ model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py
new file mode 100644
index 00000000..0e28ae6e
--- /dev/null
+++ b/modules/swinir_model_arch_v2.py
@@ -0,0 +1,1017 @@
+# -----------------------------------------------------------------------------------
+# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
+# Written by Conde and Choi et al.
+# -----------------------------------------------------------------------------------
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
+ pretrained_window_size=[0, 0]):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+ else:
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
+ attn = attn * logit_scale
+
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resulotion.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pretrained_window_size (int): Window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ pretrained_window_size=to_2tuple(pretrained_window_size))
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
+ H, W = x_size
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ H, W = x_size
+ B, L, C = x.shape
+ #assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.reduction(x)
+ x = self.norm(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ flops += H * W * self.dim // 2
+ return flops
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ pretrained_window_size=0):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pretrained_window_size=pretrained_window_size)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+ def _init_respostnorm(self):
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, 0)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, 0)
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1],
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ img_size=224, patch_size=4, resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ H, W = self.input_resolution
+ flops += H * W * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ B, HW, C = x.shape
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+class Upsample_hf(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample_hf, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.num_feat * 3 * 9
+ return flops
+
+
+
+class Swin2SR(nn.Module):
+ r""" Swin2SR
+ A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
+ img_range: Image range. 1. or 255.
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
+ **kwargs):
+ super(Swin2SR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+ self.window_size = window_size
+
+ #####################################################################################################
+ ################################### 1, shallow feature extraction ###################################
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ #####################################################################################################
+ ################################### 2, deep feature extraction ######################################
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers.append(layer)
+
+ if self.upsampler == 'pixelshuffle_hf':
+ self.layers_hf = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers_hf.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ #####################################################################################################
+ ################################ 3, high quality image reconstruction ################################
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffle_aux':
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_after_aux = nn.Sequential(
+ nn.Conv2d(3, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffle_hf':
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ self.conv_before_upsample_hf = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ assert self.upscale == 4, 'only support x4 now.'
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def check_image_size(self, x):
+ _, _, h, w = x.size()
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ return x
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward_features_hf(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers_hf:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ H, W = x.shape[2:]
+ x = self.check_image_size(x)
+
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffle_aux':
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
+ bicubic = self.conv_bicubic(bicubic)
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
+ x = self.conv_after_aux(aux)
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
+ x = self.conv_last(x)
+ aux = aux / self.img_range + self.mean
+ elif self.upsampler == 'pixelshuffle_hf':
+ # for classical SR with HF
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x_before = self.conv_before_upsample(x)
+ x_out = self.conv_last(self.upsample(x_before))
+
+ x_hf = self.conv_first_hf(x_before)
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
+ x_hf = self.conv_before_upsample_hf(x_hf)
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
+ x = x_out + x_hf
+ x_hf = x_hf / self.img_range + self.mean
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+ if self.upsampler == "pixelshuffle_aux":
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux
+
+ elif self.upsampler == "pixelshuffle_hf":
+ x_out = x_out / self.img_range + self.mean
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
+
+ else:
+ return x[:, :, :H*self.upscale, :W*self.upscale]
+
+ def flops(self):
+ flops = 0
+ H, W = self.patches_resolution
+ flops += H * W * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = Swin2SR(upscale=2, img_size=(height, width),
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape) \ No newline at end of file
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..bcf772d2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -15,11 +15,10 @@ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
- self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 9f63c9a4..4a2194da 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -10,8 +10,9 @@ from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption, process_caption_deepbooru=False):
- size = 512
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+ width = process_width
+ height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
@@ -69,23 +70,23 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
is_wide = ratio < 1 / 1.35
if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
+ img = img.resize((width, height * img.height // img.width))
- top = img.crop((0, 0, size, size))
+ top = img.crop((0, 0, width, height))
save_pic(top, index)
- bot = img.crop((0, img.height - size, size, img.height))
+ bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
+ img = img.resize((width * img.width // img.height, height))
- left = img.crop((0, 0, size, size))
+ left = img.crop((0, 0, width, height))
save_pic(left, index)
- right = img.crop((img.width - size, 0, img.width, size))
+ right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
else:
- img = images.resize_image(1, img, size, size)
+ img = images.resize_image(1, img, width, height)
save_pic(img, index)
shared.state.nextjob()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..5965c5a0 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -182,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -200,6 +200,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if ititial_step > steps:
return embedding, filename
+ tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
+ epoch_len = (tr_img_len * num_repeats) + tr_img_len
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
@@ -223,7 +226,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
loss.backward()
optimizer.step()
- pbar.set_description(f"loss: {losses.mean():.7f}")
+ epoch_num = embedding.step // epoch_len
+ epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@@ -236,6 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
sd_model=shared.sd_model,
prompt=text,
steps=20,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
diff --git a/modules/ui.py b/modules/ui.py
index f8adafb3..2ad7d864 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -524,7 +524,7 @@ def create_ui(wrap_gradio_gpu_call):
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
@@ -710,7 +710,7 @@ def create_ui(wrap_gradio_gpu_call):
tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group():
@@ -961,7 +961,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_send_to_inpaint.click(
fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
+ _js="extract_image_from_gallery_inpaint",
inputs=[result_images],
outputs=[init_img_with_mask],
)
@@ -1029,6 +1029,8 @@ def create_ui(wrap_gradio_gpu_call):
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
+ process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
@@ -1047,13 +1049,16 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Group():
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
+ num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
@@ -1096,6 +1101,8 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
process_src,
process_dst,
+ process_width,
+ process_height,
process_flip,
process_split,
process_caption,
@@ -1115,7 +1122,10 @@ def create_ui(wrap_gradio_gpu_call):
learn_rate,
dataset_directory,
log_directory,
+ training_width,
+ training_height,
steps,
+ num_repeats,
create_image_every,
save_embedding_every,
template_file,
diff --git a/requirements.txt b/requirements.txt
index 81641d68..631fe616 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,4 +23,3 @@ resize-right
torchdiffeq
kornia
lark
-functorch
diff --git a/requirements_versions.txt b/requirements_versions.txt
index fec3e9d5..fdff2687 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -22,4 +22,3 @@ resize-right==0.0.2
torchdiffeq==0.2.3
kornia==0.6.7
lark==1.1.2
-functorch==0.2.1
diff --git a/script.js b/script.js
index cf989605..a92c0f77 100644
--- a/script.js
+++ b/script.js
@@ -41,6 +41,22 @@ document.addEventListener("DOMContentLoaded", function() {
});
/**
+ * Add a ctrl+enter as a shortcut to start a generation
+ */
+ document.addEventListener('keydown', function(e) {
+ var handled = false;
+ if (e.key !== undefined) {
+ if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true;
+ } else if (e.keyCode !== undefined) {
+ if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true;
+ }
+ if (handled) {
+ gradioApp().querySelector("#txt2img_generate").click();
+ e.preventDefault();
+ }
+})
+
+/**
* checks that a UI element is not in another hidden element or tab content
*/
function uiElementIsVisible(el) {
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 771eb8e4..42e1489c 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -205,7 +205,10 @@ class Script(scripts.Script):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
- p.batch_size = 1
+ if not opts.return_grid:
+ p.batch_size = 1
+
+
CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
def process_axis(opt, vals):
diff --git a/style.css b/style.css
index c0c3f2bb..00a3d07f 100644
--- a/style.css
+++ b/style.css
@@ -1,3 +1,7 @@
+.container {
+ max-width: 100%;
+}
+
.output-html p {margin: 0 0.5em;}
.row > *,
@@ -463,3 +467,10 @@ input[type="range"]{
max-width: 32em;
padding: 0;
}
+
+canvas[key="mask"] {
+ z-index: 12 !important;
+ filter: invert();
+ mix-blend-mode: multiply;
+ pointer-events: none;
+}