aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/deepbooru.py99
-rw-r--r--modules/devices.py6
-rw-r--r--modules/hypernetwork.py98
-rw-r--r--modules/hypernetworks/hypernetwork.py305
-rw-r--r--modules/hypernetworks/ui.py47
-rw-r--r--modules/ngrok.py15
-rw-r--r--modules/processing.py20
-rw-r--r--modules/safe.py23
-rw-r--r--modules/sd_hijack.py108
-rw-r--r--modules/sd_hijack_optimizations.py142
-rw-r--r--modules/sd_models.py7
-rw-r--r--modules/sd_samplers.py28
-rw-r--r--modules/shared.py32
-rw-r--r--modules/swinir_model.py35
-rw-r--r--modules/swinir_model_arch_v2.py1017
-rw-r--r--modules/textual_inversion/dataset.py37
-rw-r--r--modules/textual_inversion/learn_schedule.py34
-rw-r--r--modules/textual_inversion/preprocess.py44
-rw-r--r--modules/textual_inversion/textual_inversion.py39
-rw-r--r--modules/textual_inversion/ui.py2
-rw-r--r--modules/ui.py124
21 files changed, 2003 insertions, 259 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 7e3c0618..29529949 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -1,21 +1,75 @@
import os.path
from concurrent.futures import ProcessPoolExecutor
-from multiprocessing import get_context
+import multiprocessing
+import time
+def get_deepbooru_tags(pil_image):
+ """
+ This method is for running only one image at a time for simple use. Used to the img2img interrogate.
+ """
+ from modules import shared # prevents circular reference
+ create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha)
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process_queue.put(pil_image)
+ while shared.deepbooru_process_return["value"] == -1:
+ time.sleep(0.2)
+ tags = shared.deepbooru_process_return["value"]
+ release_process()
+ return tags
-def _load_tf_and_return_tags(pil_image, threshold):
+
+def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort):
+ model, tags = get_deepbooru_tags_model()
+ while True: # while process is running, keep monitoring queue for new image
+ pil_image = queue.get()
+ if pil_image == "QUIT":
+ break
+ else:
+ deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort)
+
+
+def create_deepbooru_process(threshold, alpha_sort):
+ """
+ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
+ to be processed in a row without reloading the model or creating a new process. To return the data, a shared
+ dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
+ to the dictionary and the method adding the image to the queue should wait for this value to be updated with
+ the tags.
+ """
+ from modules import shared # prevents circular reference
+ shared.deepbooru_process_manager = multiprocessing.Manager()
+ shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
+ shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort))
+ shared.deepbooru_process.start()
+
+
+def release_process():
+ """
+ Stops the deepbooru process to return used memory
+ """
+ from modules import shared # prevents circular reference
+ shared.deepbooru_process_queue.put("QUIT")
+ shared.deepbooru_process.join()
+ shared.deepbooru_process_queue = None
+ shared.deepbooru_process = None
+ shared.deepbooru_process_return = None
+ shared.deepbooru_process_manager = None
+
+def get_deepbooru_tags_model():
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
-
this_folder = os.path.dirname(__file__)
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
if not os.path.exists(os.path.join(model_path, 'project.json')):
# there is no point importing these every time
import zipfile
from basicsr.utils.download_util import load_file_from_url
- load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
- model_path)
+ load_file_from_url(
+ r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
+ model_path)
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
zip_ref.extractall(model_path)
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
@@ -24,7 +78,13 @@ def _load_tf_and_return_tags(pil_image, threshold):
model = dd.project.load_model_from_project(
model_path, compile_model=True
)
+ return model, tags
+
+def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort):
+ import deepdanbooru as dd
+ import tensorflow as tf
+ import numpy as np
width = model.input_shape[2]
height = model.input_shape[1]
image = np.array(pil_image)
@@ -46,28 +106,27 @@ def _load_tf_and_return_tags(pil_image, threshold):
for i, tag in enumerate(tags):
result_dict[tag] = y[i]
- result_tags_out = []
+
+ unsorted_tags_in_theshold = []
result_tags_print = []
for tag in tags:
if result_dict[tag] >= threshold:
if tag.startswith("rating:"):
continue
- result_tags_out.append(tag)
+ unsorted_tags_in_theshold.append((result_dict[tag], tag))
result_tags_print.append(f'{result_dict[tag]} {tag}')
- print('\n'.join(sorted(result_tags_print, reverse=True)))
-
- return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
-
+ # sort tags
+ result_tags_out = []
+ sort_ndx = 0
+ if alpha_sort:
+ sort_ndx = 1
-def subprocess_init_no_cuda():
- import os
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+ # sort by reverse by likelihood and normal for alpha
+ unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
+ for weight, tag in unsorted_tags_in_theshold:
+ result_tags_out.append(tag)
+ print('\n'.join(sorted(result_tags_print, reverse=True)))
-def get_deepbooru_tags(pil_image, threshold=0.5):
- context = get_context('spawn')
- with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
- f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
- ret = f.result() # will rethrow any exceptions
- return ret \ No newline at end of file
+ return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
diff --git a/modules/devices.py b/modules/devices.py
index 0158b11f..03ef58f1 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32")
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16
+dtype_vae = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@@ -59,9 +60,12 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
-def autocast():
+def autocast(disable=False):
from modules import shared
+ if disable:
+ return contextlib.nullcontext()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
deleted file mode 100644
index 498bc9d8..00000000
--- a/modules/hypernetwork.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import glob
-import os
-import sys
-import traceback
-
-import torch
-
-from ldm.util import default
-from modules import devices, shared
-import torch
-from torch import einsum
-from einops import rearrange, repeat
-
-
-class HypernetworkModule(torch.nn.Module):
- def __init__(self, dim, state_dict):
- super().__init__()
-
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
-
- self.load_state_dict(state_dict, strict=True)
- self.to(devices.device)
-
- def forward(self, x):
- return x + (self.linear2(self.linear1(x)))
-
-
-class Hypernetwork:
- filename = None
- name = None
-
- def __init__(self, filename):
- self.filename = filename
- self.name = os.path.splitext(os.path.basename(filename))[0]
- self.layers = {}
-
- state_dict = torch.load(filename, map_location='cpu')
- for size, sd in state_dict.items():
- self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
-
-
-def list_hypernetworks(path):
- res = {}
- for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
- name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
- return res
-
-
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- if path is not None:
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork(path)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
-
- shared.loaded_hypernetwork = None
-
-
-def attention_CrossAttention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
-
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k = self.to_k(hypernetwork_layers[0](context))
- v = self.to_v(hypernetwork_layers[1](context))
- else:
- k = self.to_k(context)
- v = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
- if mask is not None:
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
-
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
-
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
new file mode 100644
index 00000000..470659df
--- /dev/null
+++ b/modules/hypernetworks/hypernetwork.py
@@ -0,0 +1,305 @@
+import datetime
+import glob
+import html
+import os
+import sys
+import traceback
+import tqdm
+
+import torch
+
+from ldm.util import default
+from modules import devices, shared, processing, sd_models
+import torch
+from torch import einsum
+from einops import rearrange, repeat
+import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
+
+
+class HypernetworkModule(torch.nn.Module):
+ def __init__(self, dim, state_dict=None):
+ super().__init__()
+
+ self.linear1 = torch.nn.Linear(dim, dim * 2)
+ self.linear2 = torch.nn.Linear(dim * 2, dim)
+
+ if state_dict is not None:
+ self.load_state_dict(state_dict, strict=True)
+ else:
+
+ self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear1.bias.data.zero_()
+ self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear2.bias.data.zero_()
+
+ self.to(devices.device)
+
+ def forward(self, x):
+ return x + (self.linear2(self.linear1(x)))
+
+
+class Hypernetwork:
+ filename = None
+ name = None
+
+ def __init__(self, name=None, enable_sizes=None):
+ self.filename = None
+ self.name = name
+ self.layers = {}
+ self.step = 0
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+
+ for size in enable_sizes or []:
+ self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+
+ def weights(self):
+ res = []
+
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.train()
+ res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+
+ return res
+
+ def save(self, filename):
+ state_dict = {}
+
+ for k, v in self.layers.items():
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
+
+ state_dict['step'] = self.step
+ state_dict['name'] = self.name
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+
+ torch.save(state_dict, filename)
+
+ def load(self, filename):
+ self.filename = filename
+ if self.name is None:
+ self.name = os.path.splitext(os.path.basename(filename))[0]
+
+ state_dict = torch.load(filename, map_location='cpu')
+
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+
+ self.name = state_dict.get('name', self.name)
+ self.step = state_dict.get('step', 0)
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+
+
+def list_hypernetworks(path):
+ res = {}
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ name = os.path.splitext(os.path.basename(filename))[0]
+ res[name] = filename
+ return res
+
+
+def load_hypernetwork(filename):
+ path = shared.hypernetworks.get(filename, None)
+ if path is not None:
+ print(f"Loading hypernetwork {filename}")
+ try:
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if shared.loaded_hypernetwork is not None:
+ print(f"Unloading hypernetwork")
+
+ shared.loaded_hypernetwork = None
+
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
+ assert hypernetwork_name, 'embedding not selected'
+
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ shared.state.textinfo = "Initializing hypernetwork training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+ unload = shared.opts.unload_models_when_training
+
+ if save_hypernetwork_every > 0:
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
+ os.makedirs(hypernetwork_dir, exist_ok=True)
+ else:
+ hypernetwork_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ hypernetwork = shared.loaded_hypernetwork
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+
+ ititial_step = hypernetwork.step or 0
+ if ititial_step > steps:
+ return hypernetwork, filename
+
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW(weights, lr=learn_rate)
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
+ for i, (x, text, cond) in pbar:
+ hypernetwork.step = i + ititial_step
+
+ if hypernetwork.step > end_step:
+ try:
+ (learn_rate, end_step) = next(schedules)
+ except Exception:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ cond = cond.to(devices.device)
+ x = x.to(devices.device)
+ loss = shared.sd_model(x.unsqueeze(0), cond)[0]
+ del x
+ del cond
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ hypernetwork.save(last_saved_file)
+
+ if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+
+ preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ prompt=preview_text,
+ steps=20,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ processed = processing.process_images(p)
+ image = processed.images[0]
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ shared.state.current_image = image
+ image.save(last_saved_image)
+
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
+<p>
+Loss: {losses.mean():.7f}<br/>
+Step: {hypernetwork.step}<br/>
+Last prompt: {html.escape(text)}<br/>
+Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+"""
+
+ checkpoint = sd_models.select_checkpoint()
+
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.save(filename)
+
+ return hypernetwork, filename
+
+
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
new file mode 100644
index 00000000..dfa599af
--- /dev/null
+++ b/modules/hypernetworks/ui.py
@@ -0,0 +1,47 @@
+import html
+import os
+
+import gradio as gr
+
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
+from modules import sd_hijack, shared, devices
+from modules.hypernetworks import hypernetwork
+
+
+def create_hypernetwork(name, enable_sizes):
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+
+
+def train_hypernetwork(*args):
+
+ initial_hypernetwork = shared.loaded_hypernetwork
+
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+Hypernetwork saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ sd_hijack.apply_optimizations()
+
diff --git a/modules/ngrok.py b/modules/ngrok.py
new file mode 100644
index 00000000..7d03a6df
--- /dev/null
+++ b/modules/ngrok.py
@@ -0,0 +1,15 @@
+from pyngrok import ngrok, conf, exception
+
+
+def connect(token, port):
+ if token == None:
+ token = 'None'
+ conf.get_default().auth_token = token
+ try:
+ public_url = ngrok.connect(port).public_url
+ except exception.PyngrokNgrokError:
+ print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
+ f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
+ else:
+ print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
+ 'You can use this link after the launch is complete.')
diff --git a/modules/processing.py b/modules/processing.py
index 94d2dd62..698b3069 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -207,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -247,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
+ if opts.eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + opts.eta_noise_seed_delta)
+
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -259,6 +262,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def decode_first_stage(model, x):
+ with devices.autocast(disable=x.dtype == devices.dtype_vae):
+ x = model.decode_first_stage(x)
+
+ return x
+
+
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
@@ -294,6 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
+ "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
}
generation_params.update(p.extra_generation_params)
@@ -398,9 +409,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
- samples_ddim = samples_ddim.to(devices.dtype)
-
- x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+ samples_ddim = samples_ddim.to(devices.dtype_vae)
+ x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim
@@ -533,7 +543,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
else:
- decoded_samples = self.sd_model.decode_first_stage(samples)
+ decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
diff --git a/modules/safe.py b/modules/safe.py
index 4d06f2a5..20be16a5 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -10,6 +10,11 @@ import torch
import numpy
import _codecs
import zipfile
+import re
+
+
+# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
@@ -20,7 +25,7 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
- return torch.storage._TypedStorage()
+ return TypedStorage()
def find_class(self, module, name):
if module == 'collections' and name == 'OrderedDict':
@@ -50,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler):
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+allowed_zip_names = ["archive/data.pkl", "archive/version"]
+allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
+
+
+def check_zip_filenames(filename, names):
+ for name in names:
+ if name in allowed_zip_names:
+ continue
+ if allowed_zip_names_re.match(name):
+ continue
+
+ raise Exception(f"bad file inside {filename}: {name}")
+
+
def check_pt(filename):
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
+ check_zip_filenames(filename, z.namelist())
+
with z.open('archive/data.pkl') as file:
unpickler = RestrictedUnpickler(file)
unpickler.load()
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 437acce4..ac70f876 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -8,8 +8,9 @@ from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
+from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -23,30 +24,37 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)):
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+ if not invokeAI_mps_available and shared.device.type == 'mps':
+ print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+ print("Applying v1 cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ else:
+ print("Applying cross attention optimization (InvokeAI).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- print("Applying cross attention optimization.")
+ print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
def undo_optimizations():
+ from modules.hypernetworks import hypernetwork
+
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
def get_target_prompt_token_count(token_count):
- if token_count < 75:
- return 75
-
- return math.ceil(token_count / 10) * 10
+ return math.ceil(max(token_count, 1) / 75) * 75
class StableDiffusionModelHijack:
@@ -110,6 +118,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.tokenizer = wrapped.tokenizer
self.token_mults = {}
+ self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
@@ -127,7 +137,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
- id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
if opts.enable_emphasis:
@@ -140,6 +149,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
fixes = []
remade_tokens = []
multipliers = []
+ last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
@@ -148,13 +158,33 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ if token == self.comma_token:
+ last_comma = len(remade_tokens)
+ elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+ last_comma += 1
+ reloc_tokens = remade_tokens[last_comma:]
+ reloc_mults = multipliers[last_comma:]
+
+ remade_tokens = remade_tokens[:last_comma]
+ length = len(remade_tokens)
+
+ rem = int(math.ceil(length / 75)) * 75 - length
+ remade_tokens += [id_end] * rem + reloc_tokens
+ multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
+ iteration = len(remade_tokens) // 75
+ if (len(remade_tokens) + emb_len) // 75 != iteration:
+ rem = (75 * (iteration + 1) - len(remade_tokens))
+ remade_tokens += [id_end] * rem
+ multipliers += [1.0] * rem
+ iteration += 1
+ fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
@@ -162,10 +192,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens) + 1
+ tokens_to_add = prompt_target_length - len(remade_tokens)
- remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add
- multipliers = [1.0] + multipliers + [1.0] * tokens_to_add
+ remade_tokens = remade_tokens + [id_end] * tokens_to_add
+ multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
@@ -260,29 +290,55 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
+
def forward(self, text):
-
- if opts.use_old_emphasis_implementation:
+ use_old = opts.use_old_emphasis_implementation
+ if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
- self.hijack.fixes = hijack_fixes
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+ if use_old:
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
+
+ z = None
+ i = 0
+ while max(map(len, remade_batch_tokens)) != 0:
+ rem_tokens = [x[75:] for x in remade_batch_tokens]
+ rem_multipliers = [x[75:] for x in batch_multipliers]
+
+ self.hijack.fixes = []
+ for unfiltered in hijack_fixes:
+ fixes = []
+ for fix in unfiltered:
+ if fix[0] == i:
+ fixes.append(fix[1])
+ self.hijack.fixes.append(fixes)
+
+ z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
+ z = z1 if z is None else torch.cat((z, z1), axis=-2)
+
+ remade_batch_tokens = rem_tokens
+ batch_multipliers = rem_multipliers
+ i += 1
+
+ return z
+
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+ batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+
+ tokens = torch.asarray(remade_batch_tokens).to(device)
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
- target_token_count = get_target_prompt_token_count(token_count) + 2
-
- position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76]
- position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))
-
- remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
- tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)
-
- outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
@@ -290,7 +346,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
+ batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 634fb4b2..79405525 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,7 @@
import math
import sys
import traceback
+import importlib
import torch
from torch import einsum
@@ -9,12 +10,12 @@ from ldm.util import default
from einops import rearrange
from modules import shared
+from modules.hypernetworks import hypernetwork
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
import xformers.ops
- import functorch
- xformers._is_functorch_available = True
shared.xformers_available = True
except Exception:
print("Cannot import xformers", file=sys.stderr)
@@ -28,16 +29,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
- del context, x
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+ del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
@@ -61,22 +56,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
return self.to_out(r2)
-# taken from https://github.com/Doggettx/stable-diffusion
+# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
k_in *= self.scale
@@ -128,18 +117,111 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
+
+def check_for_psutil():
+ try:
+ spec = importlib.util.find_spec('psutil')
+ return spec is not None
+ except ModuleNotFoundError:
+ return False
+
+invokeAI_mps_available = check_for_psutil()
+
+# -- Taken from https://github.com/invoke-ai/InvokeAI --
+if invokeAI_mps_available:
+ import psutil
+ mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+def einsum_op_compvis(q, k, v):
+ s = einsum('b i d, b j d -> b i j', q, k)
+ s = s.softmax(dim=-1, dtype=s.dtype)
+ return einsum('b i j, b j d -> b i d', s, v)
+
+def einsum_op_slice_0(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], slice_size):
+ end = i + slice_size
+ r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+ return r
+
+def einsum_op_slice_1(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
+ return r
+
+def einsum_op_mps_v1(q, k, v):
+ if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ return einsum_op_compvis(q, k, v)
+ else:
+ slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ return einsum_op_slice_1(q, k, v, slice_size)
+
+def einsum_op_mps_v2(q, k, v):
+ if mem_total_gb > 8 and q.shape[1] <= 4096:
+ return einsum_op_compvis(q, k, v)
+ else:
+ return einsum_op_slice_0(q, k, v, 1)
+
+def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
+ size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+ if size_mb <= max_tensor_mb:
+ return einsum_op_compvis(q, k, v)
+ div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+ if div <= q.shape[0]:
+ return einsum_op_slice_0(q, k, v, q.shape[0] // div)
+ return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+def einsum_op_cuda(q, k, v):
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ # Divide factor of safety as there's copying and fragmentation
+ return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+def einsum_op(q, k, v):
+ if q.device.type == 'cuda':
+ return einsum_op_cuda(q, k, v)
+
+ if q.device.type == 'mps':
+ if mem_total_gb >= 32:
+ return einsum_op_mps_v1(q, k, v)
+ return einsum_op_mps_v2(q, k, v)
+
+ # Smaller slices are faster due to L2/L3/SLC caches.
+ # Tested on i7 with 8MB L3 cache.
+ return einsum_op_tensor_mem(q, k, v, 32)
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k) * self.scale
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ r = einsum_op(q, k, v)
+ return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+
+# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e63d3c29..0a55b4c3 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -149,8 +149,13 @@ def load_model_weights(model, checkpoint_info):
model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
+
+ if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
+ vae_file = shared.cmd_opts.vae_path
+
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location="cpu")
@@ -158,6 +163,8 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.load_state_dict(vae_dict)
+ model.first_stage_model.to(devices.dtype_vae)
+
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 6e743f7e..20309e06 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser
+from modules import prompt_parser, devices, processing
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -57,7 +57,7 @@ def set_samplers():
global samplers, samplers_for_img2img
hidden = set(opts.hide_samplers)
- hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
+ hidden_img2img = set(opts.hide_samplers + ['PLMS'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
@@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
def sample_to_image(samples):
- x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
+ x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
@@ -365,16 +365,26 @@ class KDiffusionSampler:
else:
sigmas = self.model_wrap.get_sigmas(steps)
- noise = noise * sigmas[steps - t_enc - 1]
- xi = x + noise
-
- extra_params_kwargs = self.initialize(p)
-
sigma_sched = sigmas[steps - t_enc - 1:]
+ xi = x + noise * sigma_sched[0]
+
+ extra_params_kwargs = self.initialize(p)
+ if 'sigma_min' in inspect.signature(self.func).parameters:
+ ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
+ if 'sigma_max' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
+ if 'n' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
+ if 'sigma_sched' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_sched'] = sigma_sched
+ if 'sigmas' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
- return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+ return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
diff --git a/modules/shared.py b/modules/shared.py
index 1995a99a..42e99741 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,7 +13,8 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, hypernetwork
+from modules import sd_samplers
+from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -25,9 +26,11 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
+parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
@@ -35,6 +38,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
+parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
@@ -46,9 +50,10 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
@@ -65,6 +70,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
+parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
@@ -80,10 +86,18 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
-hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
+def reload_hypernetworks():
+ global hypernetworks
+
+ hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+ hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
+
class State:
skipped = False
interrupted = False
@@ -171,6 +185,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
+ "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -214,6 +229,10 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
}))
+options_templates.update(options_section(('training', "Training"), {
+ "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+}))
+
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
@@ -224,6 +243,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
+ "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
@@ -235,7 +255,8 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
- "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
}))
options_templates.update(options_section(('ui', "User interface"), {
@@ -259,6 +280,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index fbd11f84..baa02e3d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -10,6 +10,7 @@ from tqdm import tqdm
from modules import modelloader
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
+from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
precision_scope = (
@@ -57,22 +58,42 @@ class UpscalerSwinIR(Upscaler):
filename = path
if filename is None or not os.path.exists(filename):
return None
- model = net(
+ if filename.endswith(".v2.pth"):
+ model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
- resi_connection="3conv",
- )
+ resi_connection="1conv",
+ )
+ params = None
+ else:
+ model = net(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
+ embed_dim=240,
+ num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="3conv",
+ )
+ params = "params_ema"
pretrained_model = torch.load(filename)
- model.load_state_dict(pretrained_model["params_ema"], strict=True)
+ if params is not None:
+ model.load_state_dict(pretrained_model[params], strict=True)
+ else:
+ model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py
new file mode 100644
index 00000000..0e28ae6e
--- /dev/null
+++ b/modules/swinir_model_arch_v2.py
@@ -0,0 +1,1017 @@
+# -----------------------------------------------------------------------------------
+# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
+# Written by Conde and Choi et al.
+# -----------------------------------------------------------------------------------
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
+ pretrained_window_size=[0, 0]):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+ else:
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
+ attn = attn * logit_scale
+
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resulotion.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pretrained_window_size (int): Window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ pretrained_window_size=to_2tuple(pretrained_window_size))
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
+ H, W = x_size
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ H, W = x_size
+ B, L, C = x.shape
+ #assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.reduction(x)
+ x = self.norm(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ flops += H * W * self.dim // 2
+ return flops
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ pretrained_window_size=0):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pretrained_window_size=pretrained_window_size)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+ def _init_respostnorm(self):
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, 0)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, 0)
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1],
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ img_size=224, patch_size=4, resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ H, W = self.input_resolution
+ flops += H * W * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ B, HW, C = x.shape
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+class Upsample_hf(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample_hf, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.num_feat * 3 * 9
+ return flops
+
+
+
+class Swin2SR(nn.Module):
+ r""" Swin2SR
+ A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
+ img_range: Image range. 1. or 255.
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
+ **kwargs):
+ super(Swin2SR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+ self.window_size = window_size
+
+ #####################################################################################################
+ ################################### 1, shallow feature extraction ###################################
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ #####################################################################################################
+ ################################### 2, deep feature extraction ######################################
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers.append(layer)
+
+ if self.upsampler == 'pixelshuffle_hf':
+ self.layers_hf = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers_hf.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ #####################################################################################################
+ ################################ 3, high quality image reconstruction ################################
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffle_aux':
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_after_aux = nn.Sequential(
+ nn.Conv2d(3, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffle_hf':
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ self.conv_before_upsample_hf = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ assert self.upscale == 4, 'only support x4 now.'
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def check_image_size(self, x):
+ _, _, h, w = x.size()
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ return x
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward_features_hf(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers_hf:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ H, W = x.shape[2:]
+ x = self.check_image_size(x)
+
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffle_aux':
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
+ bicubic = self.conv_bicubic(bicubic)
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
+ x = self.conv_after_aux(aux)
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
+ x = self.conv_last(x)
+ aux = aux / self.img_range + self.mean
+ elif self.upsampler == 'pixelshuffle_hf':
+ # for classical SR with HF
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x_before = self.conv_before_upsample(x)
+ x_out = self.conv_last(self.upsample(x_before))
+
+ x_hf = self.conv_first_hf(x_before)
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
+ x_hf = self.conv_before_upsample_hf(x_hf)
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
+ x = x_out + x_hf
+ x_hf = x_hf / self.img_range + self.mean
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+ if self.upsampler == "pixelshuffle_aux":
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux
+
+ elif self.upsampler == "pixelshuffle_hf":
+ x_out = x_out / self.img_range + self.mean
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
+
+ else:
+ return x[:, :, :H*self.upscale, :W*self.upscale]
+
+ def flops(self):
+ flops = 0
+ H, W = self.patches_resolution
+ flops += H * W * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = Swin2SR(upscale=2, img_size=(height, width),
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape) \ No newline at end of file
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 7c44ea5b..f61f40d3 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -8,18 +8,17 @@ from torchvision import transforms
import random
import tqdm
-from modules import devices
+from modules import devices, shared
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
- def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
self.placeholder_token = placeholder_token
- self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -33,12 +32,15 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
+ cond_model = shared.sd_model.cond_stage_model
+
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
- image = Image.open(path)
- image = image.convert('RGB')
- image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+ try:
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ except Exception:
+ continue
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
@@ -53,7 +55,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
- self.dataset.append((init_latent, filename_tokens))
+ if include_cond:
+ text = self.create_text(filename_tokens)
+ cond = cond_model([text]).to(devices.cpu)
+ else:
+ cond = None
+
+ self.dataset.append((init_latent, filename_tokens, cond))
self.length = len(self.dataset) * repeats
@@ -64,6 +72,12 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ def create_text(self, filename_tokens):
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", ' '.join(filename_tokens))
+ return text
+
def __len__(self):
return self.length
@@ -72,10 +86,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
- x, filename_tokens = self.dataset[index]
-
- text = random.choice(self.lines)
- text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", ' '.join(filename_tokens))
+ x, filename_tokens, cond = self.dataset[index]
- return x, text
+ text = self.create_text(filename_tokens)
+ return x, text, cond
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..db720271
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,34 @@
+
+class LearnSchedule:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index f1c002a2..113cecf1 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -3,12 +3,16 @@ from PIL import Image, ImageOps
import platform
import sys
import tqdm
+import time
from modules import shared, images
+from modules.shared import opts, cmd_opts
+if cmd_opts.deepdanbooru:
+ import modules.deepbooru as deepbooru
-
-def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
- size = 512
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+ width = process_width
+ height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
@@ -24,10 +28,21 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
if process_caption:
shared.interrogator.load()
+ if process_caption_deepbooru:
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha)
+
def save_pic_with_caption(image, index):
if process_caption:
caption = "-" + shared.interrogator.generate_caption(image)
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
+ elif process_caption_deepbooru:
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process_queue.put(image)
+ while shared.deepbooru_process_return["value"] == -1:
+ time.sleep(0.2)
+ caption = "-" + shared.deepbooru_process_return["value"]
+ caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
+ shared.deepbooru_process_return["value"] = -1
else:
caption = filename
caption = os.path.splitext(caption)[0]
@@ -45,7 +60,10 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
filename = os.path.join(src, imagefile)
- img = Image.open(filename).convert("RGB")
+ try:
+ img = Image.open(filename).convert("RGB")
+ except Exception:
+ continue
if shared.state.interrupted:
break
@@ -55,23 +73,23 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
is_wide = ratio < 1 / 1.35
if process_split and is_tall:
- img = img.resize((size, size * img.height // img.width))
+ img = img.resize((width, height * img.height // img.width))
- top = img.crop((0, 0, size, size))
+ top = img.crop((0, 0, width, height))
save_pic(top, index)
- bot = img.crop((0, img.height - size, size, img.height))
+ bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
elif process_split and is_wide:
- img = img.resize((size * img.width // img.height, size))
+ img = img.resize((width * img.width // img.height, height))
- left = img.crop((0, 0, size, size))
+ left = img.crop((0, 0, width, height))
save_pic(left, index)
- right = img.crop((img.width - size, 0, img.width, size))
+ right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
else:
- img = images.resize_image(1, img, size, size)
+ img = images.resize_image(1, img, width, height)
save_pic(img, index)
shared.state.nextjob()
@@ -79,6 +97,10 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
if process_caption:
shared.interrogator.send_blip_to_ram()
+ if process_caption_deepbooru:
+ deepbooru.release_process()
+
+
def sanitize_caption(base_path, original_caption, suffix):
operating_system = platform.system().lower()
if (operating_system == "windows"):
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index cd9f3498..7717837d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -10,6 +10,7 @@ import datetime
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
class Embedding:
@@ -156,7 +157,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -182,15 +183,13 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
-
losses = torch.zeros((32,))
last_saved_file = "<none>"
@@ -200,12 +199,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if ititial_step > steps:
return embedding, filename
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
+ print(f'Training at rate of {learn_rate} until step {end_step}')
+
+ optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text) in pbar:
+ for i, (x, text, _) in pbar:
embedding.step = i + ititial_step
- if embedding.step > steps:
- break
+ if embedding.step > end_step:
+ try:
+ (learn_rate, end_step) = next(schedules)
+ except:
+ break
+ tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
+ for pg in optimizer.param_groups:
+ pg['lr'] = learn_rate
if shared.state.interrupted:
break
@@ -223,7 +234,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
loss.backward()
optimizer.step()
- pbar.set_description(f"loss: {losses.mean():.7f}")
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
@@ -232,10 +246,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+ preview_text = text if preview_image_prompt == "" else preview_image_prompt
+
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
- prompt=text,
+ prompt=preview_text,
steps=20,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
@@ -246,7 +264,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
shared.state.current_image = image
image.save(last_saved_image)
- last_saved_image += f", prompt: {text}"
+ last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
@@ -268,4 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.save(filename)
return embedding, filename
-
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index f19ac5e0..36881e7a 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -23,6 +23,8 @@ def preprocess(*args):
def train_embedding(*args):
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
+
try:
sd_hijack.undo_optimizations()
diff --git a/modules/ui.py b/modules/ui.py
index 1aabe18d..4bfdd275 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -39,6 +39,7 @@ import modules.generation_parameters_copypaste
from modules import prompt_parser
from modules.images import save_image
import modules.textual_inversion.ui
+import modules.hypernetworks.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -50,6 +51,11 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
+if cmd_opts.ngrok != None:
+ import modules.ngrok as ngrok
+ print('ngrok authtoken detected, trying to connect...')
+ ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860)
+
def gr_show(visible=True):
return {"visible": visible, "__type__": "update"}
@@ -125,6 +131,8 @@ def save_files(js_data, images, do_make_zip, index):
images = [images[index]]
start_index = index
+ os.makedirs(opts.outdir_save, exist_ok=True)
+
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0
writer = csv.writer(file)
@@ -175,8 +183,15 @@ def wrap_gradio_call(func, extra_outputs=None):
try:
res = list(func(*args, **kwargs))
except Exception as e:
+ # When printing out our debug argument list, do not print out more than a MB of text
+ max_debug_str_len = 131072 # (1024*1024)/8
+
print("Error completing request", file=sys.stderr)
- print("Arguments:", args, kwargs, file=sys.stderr)
+ argStr = f"Arguments: {str(args)} {str(kwargs)}"
+ print(argStr[:max_debug_str_len], file=sys.stderr)
+ if len(argStr) > max_debug_str_len:
+ print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
+
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
@@ -428,7 +443,10 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=8):
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ with gr.Column(scale=1, elem_id="roll_col"):
+ sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
@@ -524,7 +542,7 @@ def create_ui(wrap_gradio_gpu_call):
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
@@ -555,9 +573,9 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -710,7 +728,7 @@ def create_ui(wrap_gradio_gpu_call):
tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group():
@@ -743,9 +761,9 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -973,7 +991,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_send_to_inpaint.click(
fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
+ _js="extract_image_from_gallery_inpaint",
inputs=[result_images],
outputs=[init_img_with_mask],
)
@@ -1034,18 +1052,37 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="")
with gr.Column():
- create_embedding = gr.Button(value="Create", variant='primary')
+ create_embedding = gr.Button(value="Create embedding", variant='primary')
+
+ with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
+
+ new_hypernetwork_name = gr.Textbox(label="Name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory')
+ process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images into two')
process_caption = gr.Checkbox(label='Use BLIP caption as filename')
+ if cmd_opts.deepdanbooru:
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename')
+ else:
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False)
with gr.Row():
with gr.Column(scale=3):
@@ -1055,24 +1092,25 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Group():
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
- learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
+ learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
+ num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
+ preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
with gr.Row():
- with gr.Column(scale=2):
- gr.HTML(value="")
-
- with gr.Column():
- with gr.Row():
- interrupt_training = gr.Button(value="Interrupt")
- train_embedding = gr.Button(value="Train", variant='primary')
+ interrupt_training = gr.Button(value="Interrupt")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
+ train_embedding = gr.Button(value="Train Embedding", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
@@ -1098,15 +1136,31 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ create_hypernetwork.click(
+ fn=modules.hypernetworks.ui.create_hypernetwork,
+ inputs=[
+ new_hypernetwork_name,
+ new_hypernetwork_sizes,
+ ],
+ outputs=[
+ train_hypernetwork_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
run_preprocess.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
process_src,
process_dst,
+ process_width,
+ process_height,
process_flip,
process_split,
process_caption,
+ process_caption_deepbooru
],
outputs=[
ti_output,
@@ -1122,10 +1176,34 @@ def create_ui(wrap_gradio_gpu_call):
learn_rate,
dataset_directory,
log_directory,
+ training_width,
+ training_height,
+ steps,
+ num_repeats,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ preview_image_prompt,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_hypernetwork.click(
+ fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_hypernetwork_name,
+ learn_rate,
+ dataset_directory,
+ log_directory,
steps,
create_image_every,
save_embedding_every,
template_file,
+ preview_image_prompt,
],
outputs=[
ti_output,
@@ -1139,6 +1217,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[],
)
+
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -1292,6 +1371,7 @@ Requested path was: {f}
shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True
+
restart_gradio.click(
fn=request_restart,
inputs=[],
@@ -1333,7 +1413,7 @@ Requested path was: {f}
with gr.Tabs() as tabs:
for interface, label, ifid in interfaces:
- with gr.TabItem(label, id=ifid):
+ with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
if os.path.exists(os.path.join(script_path, "notification.mp3")):