author     AUTOMATIC1111 <16777216c@gmail.com>       2022-10-15 10:47:26 +0300
committer  GitHub <noreply@github.com>               2022-10-15 10:47:26 +0300
commit     f42e0aae6de6b9a7f8da4eaf13594a13502b4fa9 (patch)
tree       472025101577ff5cbd45a3bcb524e6e4accb75ec /modules
parent     0e77ee24b0b651d6a564245243850e4fb9831e31 (diff)
parent     d13ce89e203d76ab2b54a3406a93a5e4304f529e (diff)
Merge branch 'master' into master
Diffstat (limited to 'modules')
-rw-r--r--  modules/bsrgan_model.py | 8
-rw-r--r--  modules/codeformer_model.py | 12
-rw-r--r--  modules/deepbooru.py | 173
-rw-r--r--  modules/devices.py | 22
-rw-r--r--  modules/esrgan_model.py | 17
-rw-r--r--  modules/esrgan_model_arch.py (renamed from modules/esrgam_model_arch.py) | 0
-rw-r--r--  modules/extras.py | 76
-rw-r--r--  modules/generation_parameters_copypaste.py | 9
-rw-r--r--  modules/gfpgan_model.py | 22
-rw-r--r--  modules/hypernetworks/hypernetwork.py | 347
-rw-r--r--  modules/hypernetworks/ui.py | 47
-rw-r--r--  modules/images.py | 97
-rw-r--r--  modules/images_history.py | 181
-rw-r--r--  modules/img2img.py | 16
-rw-r--r--  modules/interrogate.py | 24
-rw-r--r--  modules/ldsr_model.py | 2
-rw-r--r--  modules/modelloader.py | 21
-rw-r--r--  modules/ngrok.py | 15
-rw-r--r--  modules/paths.py | 3
-rw-r--r--  modules/processing.py | 233
-rw-r--r--  modules/prompt_parser.py | 336
-rw-r--r--  modules/realesrgan_model.py | 2
-rw-r--r--  modules/safe.py | 117
-rw-r--r--  modules/scripts.py | 34
-rw-r--r--  modules/scunet_model.py | 88
-rw-r--r--  modules/scunet_model_arch.py | 265
-rw-r--r--  modules/sd_hijack.py | 466
-rw-r--r--  modules/sd_hijack_optimizations.py | 306
-rw-r--r--  modules/sd_models.py | 118
-rw-r--r--  modules/sd_samplers.py | 173
-rw-r--r--  modules/shared.py | 141
-rw-r--r--  modules/swinir_model.py | 64
-rw-r--r--  modules/swinir_model_arch.py | 2
-rw-r--r--  modules/swinir_model_arch_v2.py | 1017
-rw-r--r--  modules/textual_inversion/dataset.py | 121
-rw-r--r--  modules/textual_inversion/image_embedding.py | 219
-rw-r--r--  modules/textual_inversion/learn_schedule.py | 69
-rw-r--r--  modules/textual_inversion/preprocess.py | 116
-rw-r--r--  modules/textual_inversion/test_embedding.png | bin 0 -> 489220 bytes
-rw-r--r--  modules/textual_inversion/textual_inversion.py | 363
-rw-r--r--  modules/textual_inversion/ui.py | 42
-rw-r--r--  modules/txt2img.py | 12
-rw-r--r--  modules/ui.py | 740
-rw-r--r--  modules/upscaler.py | 7
44 files changed, 5320 insertions, 823 deletions
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
index e62c6657..737e1a76 100644
--- a/modules/bsrgan_model.py
+++ b/modules/bsrgan_model.py
@@ -8,15 +8,13 @@ import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
-from modules import shared, modelloader
+from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet
-from modules.paths import models_path
class UpscalerBSRGAN(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "BSRGAN"
- self.model_path = os.path.join(models_path, self.name)
self.model_name = "BSRGAN 4x"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self.user_path = dirname
@@ -44,13 +42,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
model = self.load_model(selected_file)
if model is None:
return img
- model.to(shared.device)
+ model.to(devices.device_bsrgan)
torch.cuda.empty_cache()
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(devices.device_bsrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index a29f3855..e6d9fa4f 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -69,10 +69,14 @@ def setup_model(dirname):
self.net = net
self.face_helper = face_helper
- self.net.to(devices.device_codeformer)
return net, face_helper
+ def send_model_to(self, device):
+ self.net.to(device)
+ self.face_helper.face_det.to(device)
+ self.face_helper.face_parse.to(device)
+
def restore(self, np_image, w=None):
np_image = np_image[:, :, ::-1]
@@ -82,6 +86,8 @@ def setup_model(dirname):
if self.net is None or self.face_helper is None:
return np_image
+ self.send_model_to(devices.device_codeformer)
+
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -113,8 +119,10 @@ def setup_model(dirname):
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+ self.face_helper.clean_all()
+
if shared.opts.face_restoration_unload:
- self.net.to(devices.cpu)
+ self.send_model_to(devices.cpu)
return restored_img
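Note: the change above replaces the single `self.net.to(...)` call with a `send_model_to` helper so that the CodeFormer network and both FaceXLib helpers (face detection and face parsing) move between devices together, and everything is offloaded to the CPU when `face_restoration_unload` is enabled. A minimal sketch of that on-demand load/unload pattern follows; the wrapper class here is assumed, not taken from the webui.

    import torch

    class FaceRestorer:
        """Hypothetical holder for a restoration net plus its FaceXLib helpers."""
        def __init__(self, net, face_det, face_parse):
            self.net = net
            self.face_det = face_det
            self.face_parse = face_parse

        def send_model_to(self, device):
            # move the main network and both helpers in one call
            self.net.to(device)
            self.face_det.to(device)
            self.face_parse.to(device)

        def restore(self, image_tensor, device, unload_after=True):
            self.send_model_to(device)          # load onto the GPU only while restoring
            try:
                with torch.no_grad():
                    return self.net(image_tensor)
            finally:
                if unload_after:
                    self.send_model_to(torch.device("cpu"))  # give the VRAM back between calls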
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
new file mode 100644
index 00000000..4ad334a1
--- /dev/null
+++ b/modules/deepbooru.py
@@ -0,0 +1,173 @@
+import os.path
+from concurrent.futures import ProcessPoolExecutor
+import multiprocessing
+import time
+import re
+
+re_special = re.compile(r'([\\()])')
+
+def get_deepbooru_tags(pil_image):
+ """
+ This method runs only one image at a time, for simple use. It is used by the img2img interrogate.
+ """
+ from modules import shared # prevents circular reference
+
+ try:
+ create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
+ return get_tags_from_process(pil_image)
+ finally:
+ release_process()
+
+
+OPT_INCLUDE_RANKS = "include_ranks"
+def create_deepbooru_opts():
+ from modules import shared
+
+ return {
+ "use_spaces": shared.opts.deepbooru_use_spaces,
+ "use_escape": shared.opts.deepbooru_escape,
+ "alpha_sort": shared.opts.deepbooru_sort_alpha,
+ OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
+ }
+
+
+def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
+ model, tags = get_deepbooru_tags_model()
+ while True: # while process is running, keep monitoring queue for new image
+ pil_image = queue.get()
+ if pil_image == "QUIT":
+ break
+ else:
+ deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
+
+
+def create_deepbooru_process(threshold, deepbooru_opts):
+ """
+ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
+ to be processed in a row without reloading the model or creating a new process. To return the data, a shared
+ dictionary is created to hold the generated tags. A sentinel value of -1 is assigned to the dictionary, and the
+ method that adds an image to the queue should wait for this value to be replaced with the tags.
+ """
+ from modules import shared # prevents circular reference
+ shared.deepbooru_process_manager = multiprocessing.Manager()
+ shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
+ shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
+ shared.deepbooru_process.start()
+
+
+def get_tags_from_process(image):
+ from modules import shared
+
+ shared.deepbooru_process_return["value"] = -1
+ shared.deepbooru_process_queue.put(image)
+ while shared.deepbooru_process_return["value"] == -1:
+ time.sleep(0.2)
+ caption = shared.deepbooru_process_return["value"]
+ shared.deepbooru_process_return["value"] = -1
+
+ return caption
+
+
+def release_process():
+ """
+ Stops the deepbooru process to release the memory it was using
+ """
+ from modules import shared # prevents circular reference
+ shared.deepbooru_process_queue.put("QUIT")
+ shared.deepbooru_process.join()
+ shared.deepbooru_process_queue = None
+ shared.deepbooru_process = None
+ shared.deepbooru_process_return = None
+ shared.deepbooru_process_manager = None
+
+def get_deepbooru_tags_model():
+ import deepdanbooru as dd
+ import tensorflow as tf
+ import numpy as np
+ this_folder = os.path.dirname(__file__)
+ model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
+ if not os.path.exists(os.path.join(model_path, 'project.json')):
+ # there is no point importing these every time
+ import zipfile
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(
+ r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
+ model_path)
+ with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
+ zip_ref.extractall(model_path)
+ os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
+
+ tags = dd.project.load_tags_from_project(model_path)
+ model = dd.project.load_model_from_project(
+ model_path, compile_model=False
+ )
+ return model, tags
+
+
+def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
+ import deepdanbooru as dd
+ import tensorflow as tf
+ import numpy as np
+
+ alpha_sort = deepbooru_opts['alpha_sort']
+ use_spaces = deepbooru_opts['use_spaces']
+ use_escape = deepbooru_opts['use_escape']
+ include_ranks = deepbooru_opts['include_ranks']
+
+ width = model.input_shape[2]
+ height = model.input_shape[1]
+ image = np.array(pil_image)
+ image = tf.image.resize(
+ image,
+ size=(height, width),
+ method=tf.image.ResizeMethod.AREA,
+ preserve_aspect_ratio=True,
+ )
+ image = image.numpy() # EagerTensor to np.array
+ image = dd.image.transform_and_pad_image(image, width, height)
+ image = image / 255.0
+ image_shape = image.shape
+ image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
+
+ y = model.predict(image)[0]
+
+ result_dict = {}
+
+ for i, tag in enumerate(tags):
+ result_dict[tag] = y[i]
+
+ unsorted_tags_in_theshold = []
+ result_tags_print = []
+ for tag in tags:
+ if result_dict[tag] >= threshold:
+ if tag.startswith("rating:"):
+ continue
+ unsorted_tags_in_theshold.append((result_dict[tag], tag))
+ result_tags_print.append(f'{result_dict[tag]} {tag}')
+
+ # sort tags
+ result_tags_out = []
+ sort_ndx = 0
+ if alpha_sort:
+ sort_ndx = 1
+
+ # sort in reverse by likelihood (or alphabetically when alpha_sort is set), and format tag text as requested
+ unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
+ for weight, tag in unsorted_tags_in_theshold:
+ # note: tag_outformat will still have a colon if include_ranks is True
+ tag_outformat = tag.replace(':', ' ')
+ if use_spaces:
+ tag_outformat = tag_outformat.replace('_', ' ')
+ if use_escape:
+ tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
+ if include_ranks:
+ tag_outformat = f"({tag_outformat}:{weight:.3f})"
+
+ result_tags_out.append(tag_outformat)
+
+ print('\n'.join(sorted(result_tags_print, reverse=True)))
+
+ return ', '.join(result_tags_out)
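Note: taken together, the new module exposes a small worker-process API: `create_deepbooru_process` starts a child process that loads the DeepDanbooru model once, `get_tags_from_process` pushes an image onto the queue and polls the shared dictionary until the -1 sentinel is replaced, and `release_process` sends "QUIT" and joins the worker. A short usage sketch, assuming the webui's `modules` package is importable and `shared.opts` is configured; the image paths are placeholders.

    from PIL import Image
    from modules import shared, deepbooru

    images = [Image.open(p) for p in ("example1.png", "example2.png")]  # hypothetical files

    deepbooru.create_deepbooru_process(
        shared.opts.interrogate_deepbooru_score_threshold,
        deepbooru.create_deepbooru_opts(),
    )
    try:
        for img in images:
            tags = deepbooru.get_tags_from_process(img)   # blocks until the worker replies
            print(tags)
    finally:
        deepbooru.release_process()                       # sends "QUIT" and frees the model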
diff --git a/modules/devices.py b/modules/devices.py
index 07bb2339..eb422583 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,8 +1,10 @@
+import contextlib
+
import torch
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
from modules import errors
+# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
@@ -32,10 +34,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-
-device = get_optimal_device()
-device_codeformer = cpu if has_mps else device
-
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+dtype = torch.float16
+dtype_vae = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@@ -58,3 +59,14 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
+
+def autocast(disable=False):
+ from modules import shared
+
+ if disable:
+ return contextlib.nullcontext()
+
+ if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+ return contextlib.nullcontext()
+
+ return torch.autocast("cuda")
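Note: the new `devices.autocast()` helper centralises the precision decision: it returns a real `torch.autocast("cuda")` context when half precision is active, and a `contextlib.nullcontext()` when the global dtype is float32 or `--precision full` was passed. A small sketch of how callers are expected to use it; the model call is a placeholder.

    import torch
    from modules import devices

    x = torch.randn(1, 4, 64, 64, device=devices.device)

    with devices.autocast():              # no-op at full precision, autocast otherwise
        y = model(x)                      # hypothetical model call

    with devices.autocast(disable=True):  # force full precision for a sensitive step
        y = y.float()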
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index ea91abfe..46ad0da3 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -5,10 +5,8 @@ import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
-import modules.esrgam_model_arch as arch
-from modules import shared, modelloader, images
-from modules.devices import has_mps
-from modules.paths import models_path
+import modules.esrgan_model_arch as arch
+from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
@@ -73,11 +71,10 @@ def fix_model_layers(crt_model, pretrained_net):
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
- self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
- self.model_name = "ESRGAN 4x"
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
+ self.model_name = "ESRGAN_4x"
self.scalers = []
self.user_path = dirname
- self.model_path = os.path.join(models_path, self.name)
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
@@ -97,7 +94,7 @@ class UpscalerESRGAN(Upscaler):
model = self.load_model(selected_model)
if model is None:
return img
- model.to(shared.device)
+ model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
@@ -112,7 +109,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+ pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
@@ -127,7 +124,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(shared.device)
+ img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/esrgam_model_arch.py b/modules/esrgan_model_arch.py
index e413d36e..e413d36e 100644
--- a/modules/esrgam_model_arch.py
+++ b/modules/esrgan_model_arch.py
diff --git a/modules/extras.py b/modules/extras.py
index 6a0d5cb0..f2f5a7b0 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,3 +1,4 @@
+import math
import os
import numpy as np
@@ -19,7 +20,7 @@ import gradio as gr
cached_images = {}
-def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
devices.torch_gc()
imageArr = []
@@ -29,7 +30,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
if extras_mode == 1:
#convert file to pillow image
for img in image_folder:
- image = Image.fromarray(np.array(Image.open(img)))
+ image = Image.open(img)
imageArr.append(image)
imageNameArr.append(os.path.splitext(img.orig_name)[0])
else:
@@ -67,8 +68,13 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res
+ if resize_mode == 1:
+ upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
+ crop_info = " (crop)" if upscaling_crop else ""
+ info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
+
if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize):
+ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist())
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
@@ -77,15 +83,19 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
if c is None:
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ if mode == 1 and crop:
+ cropped = Image.new("RGB", (resize_w, resize_h))
+ cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
+ c = cropped
cached_images[key] = c
return c
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize)
+ res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
res = Image.blend(res, res2, extras_upscaler_2_visibility)
@@ -98,8 +108,14 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
forced_filename=image_name if opts.use_original_name_batch else None)
+ if opts.enable_pnginfo:
+ image.info = existing_pnginfo
+ image.info["extras"] = info
+
outputs.append(image)
+ devices.torch_gc()
+
return outputs, plaintext_to_html(info), ''
@@ -143,48 +159,52 @@ def run_pnginfo(image):
return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
- # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
- def weighted_sum(theta0, theta1, alpha):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+ def weighted_sum(theta0, theta1, theta2, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
- # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def sigmoid(theta0, theta1, alpha):
- alpha = alpha * alpha * (3 - (2 * alpha))
- return theta0 + ((theta1 - theta0) * alpha)
-
- # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
- def inv_sigmoid(theta0, theta1, alpha):
- import math
- alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
- return theta0 + ((theta1 - theta0) * alpha)
+ def add_difference(theta0, theta1, theta2, alpha):
+ return theta0 + (theta1 - theta2) * alpha
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
+ teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
+ theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-
- theta_0 = primary_model['state_dict']
- theta_1 = secondary_model['state_dict']
+ theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+
+ if teritary_model_info is not None:
+ print(f"Loading {teritary_model_info.filename}...")
+ teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
+ theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
+ else:
+ theta_2 = None
theta_funcs = {
- "Weighted Sum": weighted_sum,
- "Sigmoid": sigmoid,
- "Inverse Sigmoid": inv_sigmoid,
+ "Weighted sum": weighted_sum,
+ "Add difference": add_difference,
}
theta_func = theta_funcs[interp_method]
print(f"Merging...")
+
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
- theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
+ t2 = (theta_2 or {}).get(key)
+ if t2 is None:
+ t2 = torch.zeros_like(theta_0[key])
+
+ theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier)
+
if save_as_half:
theta_0[key] = theta_0[key].half()
-
+
+ # I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
@@ -193,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
+ filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(ckpt_dir, filename)
@@ -203,4 +223,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
sd_models.list_models()
print(f"Checkpoint saved.")
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)]
+ return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index ac1ba7f4..c27826b6 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,5 +1,8 @@
+import os
import re
import gradio as gr
+from modules.shared import script_path
+from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code)
@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, js=None):
def paste_func(prompt):
+ if not prompt and not shared.cmd_opts.hide_ui_dir_config:
+ filename = os.path.join(script_path, "params.txt")
+ if os.path.exists(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ prompt = file.read()
+
params = parse_generation_parameters(prompt)
res = []
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index bb30d733..a9452dce 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -21,7 +21,7 @@ def gfpgann():
global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(shared.device)
+ loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model
if gfpgan_constructor is None:
@@ -37,22 +37,32 @@ def gfpgann():
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
- model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
return model
+def send_model_to(model, device):
+ model.gfpgan.to(device)
+ model.face_helper.face_det.to(device)
+ model.face_helper.face_parse.to(device)
+
+
def gfpgan_fix_faces(np_image):
model = gfpgann()
if model is None:
return np_image
+
+ send_model_to(model, devices.device_gfpgan)
+
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
+ model.face_helper.clean_all()
+
if shared.opts.face_restoration_unload:
- model.gfpgan.to(devices.cpu)
+ send_model_to(model, devices.cpu)
return np_image
@@ -97,11 +107,7 @@ def setup_model(dirname):
return "GFPGAN"
def restore(self, np_image):
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- return np_image
+ return gfpgan_fix_faces(np_image)
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
new file mode 100644
index 00000000..a2b3bc0a
--- /dev/null
+++ b/modules/hypernetworks/hypernetwork.py
@@ -0,0 +1,347 @@
+import datetime
+import glob
+import html
+import os
+import sys
+import traceback
+import tqdm
+import csv
+
+import torch
+
+from ldm.util import default
+from modules import devices, shared, processing, sd_models
+import torch
+from torch import einsum
+from einops import rearrange, repeat
+import modules.textual_inversion.dataset
+from modules.textual_inversion import textual_inversion
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+
+
+class HypernetworkModule(torch.nn.Module):
+ multiplier = 1.0
+
+ def __init__(self, dim, state_dict=None):
+ super().__init__()
+
+ self.linear1 = torch.nn.Linear(dim, dim * 2)
+ self.linear2 = torch.nn.Linear(dim * 2, dim)
+
+ if state_dict is not None:
+ self.load_state_dict(state_dict, strict=True)
+ else:
+
+ self.linear1.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear1.bias.data.zero_()
+ self.linear2.weight.data.normal_(mean=0.0, std=0.01)
+ self.linear2.bias.data.zero_()
+
+ self.to(devices.device)
+
+ def forward(self, x):
+ return x + (self.linear2(self.linear1(x))) * self.multiplier
+
+
+def apply_strength(value=None):
+ HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
+
+
+class Hypernetwork:
+ filename = None
+ name = None
+
+ def __init__(self, name=None, enable_sizes=None):
+ self.filename = None
+ self.name = name
+ self.layers = {}
+ self.step = 0
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+
+ for size in enable_sizes or []:
+ self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+
+ def weights(self):
+ res = []
+
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.train()
+ res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+
+ return res
+
+ def save(self, filename):
+ state_dict = {}
+
+ for k, v in self.layers.items():
+ state_dict[k] = (v[0].state_dict(), v[1].state_dict())
+
+ state_dict['step'] = self.step
+ state_dict['name'] = self.name
+ state_dict['sd_checkpoint'] = self.sd_checkpoint
+ state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
+
+ torch.save(state_dict, filename)
+
+ def load(self, filename):
+ self.filename = filename
+ if self.name is None:
+ self.name = os.path.splitext(os.path.basename(filename))[0]
+
+ state_dict = torch.load(filename, map_location='cpu')
+
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+
+ self.name = state_dict.get('name', self.name)
+ self.step = state_dict.get('step', 0)
+ self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
+ self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
+
+
+def list_hypernetworks(path):
+ res = {}
+ for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+ name = os.path.splitext(os.path.basename(filename))[0]
+ res[name] = filename
+ return res
+
+
+def load_hypernetwork(filename):
+ path = shared.hypernetworks.get(filename, None)
+ if path is not None:
+ print(f"Loading hypernetwork {filename}")
+ try:
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if shared.loaded_hypernetwork is not None:
+ print(f"Unloading hypernetwork")
+
+ shared.loaded_hypernetwork = None
+
+
+def find_closest_hypernetwork_name(search: str):
+ if not search:
+ return None
+ search = search.lower()
+ applicable = [name for name in shared.hypernetworks if search in name.lower()]
+ if not applicable:
+ return None
+ applicable = sorted(applicable, key=lambda name: len(name))
+ return applicable[0]
+
+
+def apply_hypernetwork(hypernetwork, context, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ if layer is not None:
+ layer.hyper_k = hypernetwork_layers[0]
+ layer.hyper_v = hypernetwork_layers[1]
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+def stack_conds(conds):
+ if len(conds) == 1:
+ return torch.stack(conds)
+
+ # same as in reconstruct_multicond_batch
+ token_count = max([x.shape[0] for x in conds])
+ for i in range(len(conds)):
+ if conds[i].shape[0] != token_count:
+ last_vector = conds[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
+ conds[i] = torch.vstack([conds[i], last_vector_repeated])
+
+ return torch.stack(conds)
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ assert hypernetwork_name, 'hypernetwork not selected'
+
+ path = shared.hypernetworks.get(hypernetwork_name, None)
+ shared.loaded_hypernetwork = Hypernetwork()
+ shared.loaded_hypernetwork.load(path)
+
+ shared.state.textinfo = "Initializing hypernetwork training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
+ unload = shared.opts.unload_models_when_training
+
+ if save_hypernetwork_every > 0:
+ hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
+ os.makedirs(hypernetwork_dir, exist_ok=True)
+ else:
+ hypernetwork_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ hypernetwork = shared.loaded_hypernetwork
+ weights = hypernetwork.weights()
+ for weight in weights:
+ weight.requires_grad = True
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+
+ ititial_step = hypernetwork.step or 0
+ if ititial_step > steps:
+ return hypernetwork, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
+ for i, entries in pbar:
+ hypernetwork.step = i + ititial_step
+
+ scheduler.apply(optimizer, hypernetwork.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = stack_conds([entry.cond for entry in entries]).to(devices.device)
+# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+ del c
+
+ losses[hypernetwork.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pbar.set_description(f"loss: {losses.mean():.7f}")
+
+ if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ hypernetwork.save(last_saved_file)
+
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+
+ optimizer.zero_grad()
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images)>0 else None
+
+ if unload:
+ shared.sd_model.cond_stage_model.to(devices.cpu)
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ image.save(last_saved_image)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = hypernetwork.step
+
+ shared.state.textinfo = f"""
+<p>
+Loss: {losses.mean():.7f}<br/>
+Step: {hypernetwork.step}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
+Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+"""
+
+ checkpoint = sd_models.select_checkpoint()
+
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.save(filename)
+
+ return hypernetwork, filename
+
+
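Note: the core of the new hypernetwork support is small: each enabled layer size gets a pair of residual MLPs (dim -> 2*dim -> dim), and `apply_hypernetwork` runs the pair over the cross-attention context so that keys and values are rewritten while queries are untouched. A standalone sketch of that shape flow, importing nothing from the webui and assuming the 768-wide CLIP context of SD 1.x:

    import torch

    class HypernetworkModule(torch.nn.Module):
        multiplier = 1.0

        def __init__(self, dim):
            super().__init__()
            self.linear1 = torch.nn.Linear(dim, dim * 2)
            self.linear2 = torch.nn.Linear(dim * 2, dim)

        def forward(self, x):
            # residual MLP, scaled by the global strength multiplier
            return x + self.linear2(self.linear1(x)) * self.multiplier

    dim = 768                               # assumed CLIP context width for SD 1.x
    context = torch.randn(2, 77, dim)       # (batch, tokens, dim)
    hyper_k, hyper_v = HypernetworkModule(dim), HypernetworkModule(dim)

    context_k = hyper_k(context)            # what attention's to_k now sees
    context_v = hyper_v(context)            # what attention's to_v now sees
    print(context_k.shape, context_v.shape) # both torch.Size([2, 77, 768])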
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
new file mode 100644
index 00000000..dfa599af
--- /dev/null
+++ b/modules/hypernetworks/ui.py
@@ -0,0 +1,47 @@
+import html
+import os
+
+import gradio as gr
+
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
+from modules import sd_hijack, shared, devices
+from modules.hypernetworks import hypernetwork
+
+
+def create_hypernetwork(name, enable_sizes):
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+
+
+def train_hypernetwork(*args):
+
+ initial_hypernetwork = shared.loaded_hypernetwork
+
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.
+Hypernetwork saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ sd_hijack.apply_optimizations()
+
diff --git a/modules/images.py b/modules/images.py
index f1aed5d6..b9589563 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,4 +1,5 @@
import datetime
+import io
import math
import os
from collections import namedtuple
@@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None):
rows = opts.n_rows
elif opts.n_rows == 0:
rows = batch_size
+ elif opts.grid_prevent_empty_spots:
+ rows = math.floor(math.sqrt(len(imgs)))
+ while len(imgs) % rows != 0:
+ rows -= 1
else:
rows = math.sqrt(len(imgs))
rows = round(rows)
@@ -287,6 +292,20 @@ def apply_filename_pattern(x, p, seed, prompt):
if seed is not None:
x = x.replace("[seed]", str(seed))
+ if p is not None:
+ x = x.replace("[steps]", str(p.steps))
+ x = x.replace("[cfg]", str(p.cfg_scale))
+ x = x.replace("[width]", str(p.width))
+ x = x.replace("[height]", str(p.height))
+ x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
+ x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
+
+ x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
+ x = x.replace("[date]", datetime.date.today().isoformat())
+ x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
+ x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
+
+ # Apply [prompt] at last. Because it may contain any replacement word.^M
if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x:
@@ -295,7 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
if len(style) > 0:
style_parts = [y for y in style.split("{prompt}")]
for part in style_parts:
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
@@ -306,19 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt):
words = ["empty"]
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
- if p is not None:
- x = x.replace("[steps]", str(p.steps))
- x = x.replace("[cfg]", str(p.cfg_scale))
- x = x.replace("[width]", str(p.width))
- x = x.replace("[height]", str(p.height))
- x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
- x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
-
- x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
- x = x.replace("[date]", datetime.date.today().isoformat())
- x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
- x = x.replace("[job_timestamp]", shared.state.job_timestamp)
-
if cmd_opts.hide_ui_dir_config:
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
@@ -347,7 +353,39 @@ def get_next_sequence_number(path, basename):
return result + 1
-def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
+def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
+ '''Save an image.
+
+ Args:
+ image (`PIL.Image`):
+ The image to be saved.
+ path (`str`):
+ The directory to save the image in. Note that the option `save_to_dirs` will cause the image to be saved into a subdirectory.
+ basename (`str`):
+ The base filename which will be applied to `filename pattern`.
+ seed, prompt, short_filename,
+ extension (`str`):
+ Image file extension, default is `png`.
+ pnginfo_section_name (`str`):
+ Specify the name of the section which `info` will be saved in.
+ info (`str` or `PngImagePlugin.iTXt`):
+ PNG info chunks.
+ existing_info (`dict`):
+ Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
+ no_prompt:
+ TODO I don't know its meaning.
+ p (`StableDiffusionProcessing`)
+ forced_filename (`str`):
+ If specified, `basename` and filename pattern will be ignored.
+ save_to_dirs (bool):
+ If true, the image will be saved into a subdirectory of `path`.
+
+ Returns: (fullfn, txt_fullfn)
+ fullfn (`str`):
+ The full path of the saved image.
+ txt_fullfn (`str` or None):
+ If a text file is saved for this image, this will be its full path. Otherwise None.
+ '''
if short_filename or prompt is None or seed is None:
file_decoration = ""
elif opts.save_to_dirs:
@@ -371,10 +409,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
pnginfo = None
- save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
+ if save_to_dirs is None:
+ save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs:
- dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
+ dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True)
@@ -422,7 +461,29 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")
if opts.save_txt and info is not None:
- with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
+ txt_fullfn = f"{fullfn_without_extension}.txt"
+ with open(txt_fullfn, "w", encoding="utf8") as file:
file.write(info + "\n")
+ else:
+ txt_fullfn = None
+
+ return fullfn, txt_fullfn
+def image_data(data):
+ try:
+ image = Image.open(io.BytesIO(data))
+ textinfo = image.text["parameters"]
+ return textinfo, None
+ except Exception:
+ pass
+
+ try:
+ text = data.decode('utf8')
+ assert len(text) < 10000
+ return text, None
+
+ except Exception:
+ pass
+
+ return '', None
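Note: with this change `save_image` gains a docstring, an explicit per-call `save_to_dirs` override, and a second return value: the path of the `.txt` sidecar that holds the generation info (or None when no sidecar is written). A brief usage sketch; the output directory and info text are placeholders.

    from PIL import Image
    from modules import images

    img = Image.new("RGB", (512, 512))

    fullfn, txt_fullfn = images.save_image(
        img,
        "outputs/extras-images",            # hypothetical output directory
        "extras",
        info="Postprocess upscale by: 2",   # written to the sidecar if opts.save_txt is enabled
        save_to_dirs=False,                 # skip the [prompt_words]-style subdirectory for this call
    )
    print(fullfn, txt_fullfn)               # txt_fullfn is None when no sidecar was saved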
diff --git a/modules/images_history.py b/modules/images_history.py
new file mode 100644
index 00000000..9260df8a
--- /dev/null
+++ b/modules/images_history.py
@@ -0,0 +1,181 @@
+import os
+import shutil
+
+
+def traverse_all_files(output_dir, image_list, curr_dir=None):
+ curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
+ try:
+ f_list = os.listdir(curr_path)
+ except:
+ if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
+ image_list.append(curr_dir)
+ return image_list
+ for file in f_list:
+ file = file if curr_dir is None else os.path.join(curr_dir, file)
+ file_path = os.path.join(curr_path, file)
+ if file[-4:] == ".txt":
+ pass
+ elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
+ image_list.append(file)
+ else:
+ image_list = traverse_all_files(output_dir, image_list, file)
+ return image_list
+
+
+def get_recent_images(dir_name, page_index, step, image_index, tabname):
+ page_index = int(page_index)
+ f_list = os.listdir(dir_name)
+ image_list = []
+ image_list = traverse_all_files(dir_name, image_list)
+ image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
+ num = 48 if tabname != "extras" else 12
+ max_page_index = len(image_list) // num + 1
+ page_index = max_page_index if page_index == -1 else page_index + step
+ page_index = 1 if page_index < 1 else page_index
+ page_index = max_page_index if page_index > max_page_index else page_index
+ idx_frm = (page_index - 1) * num
+ image_list = image_list[idx_frm:idx_frm + num]
+ image_index = int(image_index)
+ if image_index < 0 or image_index > len(image_list) - 1:
+ current_file = None
+ hidden = None
+ else:
+ current_file = image_list[int(image_index)]
+ hidden = os.path.join(dir_name, current_file)
+ return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
+
+
+def first_page_click(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, 1, 0, image_index, tabname)
+
+
+def end_page_click(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, -1, 0, image_index, tabname)
+
+
+def prev_page_click(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, page_index, -1, image_index, tabname)
+
+
+def next_page_click(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, page_index, 1, image_index, tabname)
+
+
+def page_index_change(dir_name, page_index, image_index, tabname):
+ return get_recent_images(dir_name, page_index, 0, image_index, tabname)
+
+
+def show_image_info(num, image_path, filenames):
+ # print(f"select image {num}")
+ file = filenames[int(num)]
+ return file, num, os.path.join(image_path, file)
+
+
+def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
+ if name == "":
+ return filenames, delete_num
+ else:
+ delete_num = int(delete_num)
+ index = list(filenames).index(name)
+ i = 0
+ new_file_list = []
+ for name in filenames:
+ if i >= index and i < index + delete_num:
+ path = os.path.join(dir_name, name)
+ if os.path.exists(path):
+ print(f"Delete file {path}")
+ os.remove(path)
+ txt_file = os.path.splitext(path)[0] + ".txt"
+ if os.path.exists(txt_file):
+ os.remove(txt_file)
+ else:
+ print(f"Not exists file {path}")
+ else:
+ new_file_list.append(name)
+ i += 1
+ return new_file_list, 1
+
+
+def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
+ if opts.outdir_samples != "":
+ dir_name = opts.outdir_samples
+ elif tabname == "txt2img":
+ dir_name = opts.outdir_txt2img_samples
+ elif tabname == "img2img":
+ dir_name = opts.outdir_img2img_samples
+ elif tabname == "extras":
+ dir_name = opts.outdir_extras_samples
+ d = dir_name.split("/")
+ dir_name = "/" if dir_name.startswith("/") else d[0]
+ for p in d[1:]:
+ dir_name = os.path.join(dir_name, p)
+ with gr.Row():
+ renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
+ first_page = gr.Button('First Page')
+ prev_page = gr.Button('Prev Page')
+ page_index = gr.Number(value=1, label="Page Index")
+ next_page = gr.Button('Next Page')
+ end_page = gr.Button('End Page')
+ with gr.Row(elem_id=tabname + "_images_history"):
+ with gr.Row():
+ with gr.Column(scale=2):
+ history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
+ with gr.Row():
+ delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
+ delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
+ with gr.Column():
+ with gr.Row():
+ pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
+ pnginfo_send_to_img2img = gr.Button('Send to img2img')
+ with gr.Row():
+ with gr.Column():
+ img_file_info = gr.Textbox(label="Generate Info", interactive=False)
+ img_file_name = gr.Textbox(label="File Name", interactive=False)
+ with gr.Row():
+ # hidden items
+
+ img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
+ tabname_box = gr.Textbox(tabname, visible=False)
+ image_index = gr.Textbox(value=-1, visible=False)
+ set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
+ filenames = gr.State()
+ hidden = gr.Image(type="pil", visible=False)
+ info1 = gr.Textbox(visible=False)
+ info2 = gr.Textbox(visible=False)
+
+ # turn pages
+ gallery_inputs = [img_path, page_index, image_index, tabname_box]
+ gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
+
+ first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
+ # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
+
+ # other functions
+ set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
+ img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
+ delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
+ hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
+
+ # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
+ switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
+ switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
+
+
+def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
+ with gr.Blocks(analytics_enabled=False) as images_history:
+ with gr.Tabs() as tabs:
+ with gr.Tab("txt2img history"):
+ with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
+ show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
+ with gr.Tab("img2img history"):
+ with gr.Blocks(analytics_enabled=False) as images_history_img2img:
+ show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
+ with gr.Tab("extras history"):
+ with gr.Blocks(analytics_enabled=False) as images_history_img2img:
+ show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
+ return images_history
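Note: most of the new file is Gradio wiring; the interesting part is the paging arithmetic in `get_recent_images`: 48 thumbnails per page (12 on the extras tab), `page_index == -1` meaning "jump to the last page", and `step` moving one page forward or back. That arithmetic, isolated as a standalone sketch:

    def page_slice(total_images, page_index, step, per_page=48):
        # same clamping rules as get_recent_images
        max_page_index = total_images // per_page + 1
        page_index = max_page_index if page_index == -1 else page_index + step
        page_index = max(1, min(page_index, max_page_index))
        start = (page_index - 1) * per_page
        return page_index, start, start + per_page

    print(page_slice(100, 1, 1))    # (2, 48, 96)  -> second page of a 100-image folder
    print(page_slice(100, -1, 0))   # (3, 96, 144) -> "End Page" jumps to the last page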
diff --git a/modules/img2img.py b/modules/img2img.py
index 03e934e9..24126774 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -23,13 +23,17 @@ def process_batch(p, input_dir, output_dir, args):
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
+ save_normally = output_dir == ''
+
p.do_not_save_grid = True
- p.do_not_save_samples = True
+ p.do_not_save_samples = not save_normally
state.job_count = len(images) * p.n_iter
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
+ if state.skipped:
+ state.skipped = False
if state.interrupted:
break
@@ -48,7 +52,8 @@ def process_batch(p, input_dir, output_dir, args):
left, right = os.path.splitext(filename)
filename = f"{left}-{n}{right}"
- processed_image.save(os.path.join(output_dir, filename))
+ if not save_normally:
+ processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
@@ -103,7 +108,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert,
)
- print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
+
+ if shared.cmd_opts.enable_console_prompts:
+ print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
p.extra_generation_params["Mask blur"] = mask_blur
@@ -124,4 +131,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.samples_log_stdout:
print(generation_info_js)
+ if opts.do_not_show_images:
+ processed.images = []
+
return processed.images, generation_info_js, plaintext_to_html(processed.info)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index f62a4745..9263d65a 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
+
class InterrogateModels:
blip_model = None
clip_model = None
@@ -54,7 +55,7 @@ class InterrogateModels:
model, preprocess = clip.load(clip_model_name)
model.eval()
- model = model.to(shared.device)
+ model = model.to(devices.device_interrogate)
return model, preprocess
@@ -64,14 +65,14 @@ class InterrogateModels:
if not shared.cmd_opts.no_half:
self.blip_model = self.blip_model.half()
- self.blip_model = self.blip_model.to(shared.device)
+ self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half:
self.clip_model = self.clip_model.half()
- self.clip_model = self.clip_model.to(shared.device)
+ self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = next(self.clip_model.parameters()).dtype
@@ -98,11 +99,11 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
- similarity = torch.zeros((1, len(text_array))).to(shared.device)
+ similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
@@ -115,14 +116,14 @@ class InterrogateModels:
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
- ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
return caption[0]
- def interrogate(self, pil_image):
+ def interrogate(self, pil_image, include_ranks=False):
res = None
try:
@@ -139,11 +140,11 @@ class InterrogateModels:
res = caption
- cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+ clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"):
- image_features = self.clip_model.encode_image(cilp_image).type(self.dtype)
+ image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True)
@@ -155,7 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
-                res += ", " + match
+                if include_ranks:
+                    res += f", ({match}:{score})"
+                else:
+                    res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py
index 1c1070fc..8c4db44a 100644
--- a/modules/ldsr_model.py
+++ b/modules/ldsr_model.py
@@ -7,13 +7,11 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR
from modules import shared
-from modules.paths import models_path
class UpscalerLDSR(Upscaler):
def __init__(self, user_path):
self.name = "LDSR"
- self.model_path = os.path.join(models_path, self.name)
self.user_path = user_path
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 8c862b42..b0f2f33d 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -5,7 +5,6 @@ import importlib
from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
-
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
@@ -43,7 +42,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
for place in places:
if os.path.exists(place):
for file in glob.iglob(place + '**/**', recursive=True):
- full_path = os.path.join(place, file)
+ full_path = file
if os.path.isdir(full_path):
continue
if len(ext_filter) != 0:
@@ -121,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers():
+ sd = shared.script_path
+ # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
+ # so we'll try to import any _model.py files before looking in __subclasses__
+ modules_dir = os.path.join(sd, "modules")
+ for file in os.listdir(modules_dir):
+ if "_model.py" in file:
+ model_name = file.replace("_model.py", "")
+ full_model = f"modules.{model_name}_model"
+ try:
+ importlib.import_module(full_model)
+ except:
+ pass
datas = []
+ c_o = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
name = cls.__name__
module_name = cls.__module__
module = importlib.import_module(module_name)
class_ = getattr(module, name)
- cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
+ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
opt_string = None
try:
- opt_string = shared.opts.__getattr__(cmd_name)
+ if cmd_name in c_o:
+ opt_string = c_o[cmd_name]
except:
pass
scaler = class_(opt_string)
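Note: the comment in load_upscalers() describes why the *_model.py files are imported first: a subclass only appears in Upscaler.__subclasses__() once its defining module has been imported. A minimal sketch of that discovery pattern, with an illustrative class name:

    class Upscaler:
        pass

    class UpscalerExample(Upscaler):              # would normally live in modules/example_model.py
        def __init__(self, dirname=None):
            self.user_path = dirname

    scalers = [cls(None) for cls in Upscaler.__subclasses__()]
    print([type(s).__name__ for s in scalers])    # ['UpscalerExample']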
diff --git a/modules/ngrok.py b/modules/ngrok.py
new file mode 100644
index 00000000..7d03a6df
--- /dev/null
+++ b/modules/ngrok.py
@@ -0,0 +1,15 @@
+from pyngrok import ngrok, conf, exception
+
+
+def connect(token, port):
+ if token == None:
+ token = 'None'
+ conf.get_default().auth_token = token
+ try:
+ public_url = ngrok.connect(port).public_url
+ except exception.PyngrokNgrokError:
+ print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
+ f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
+ else:
+ print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
+ 'You can use this link after the launch is complete.')
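Note: a hedged usage sketch of the new module: launch code would call connect() with the configured authtoken and the port the web UI listens on (7860 is illustrative); pyngrok must be installed, since importing this module pulls it in.

    from modules import ngrok

    # token=None falls back to the unauthenticated default handled inside connect()
    ngrok.connect(token=None, port=7860)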
diff --git a/modules/paths.py b/modules/paths.py
index ceb80417..1e7a2fbc 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,6 +1,7 @@
import argparse
import os
import sys
+import modules.safe
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
@@ -12,6 +13,7 @@ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'),
for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path)
+ break
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
@@ -20,7 +22,6 @@ path_dirs = [
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
- (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
diff --git a/modules/processing.py b/modules/processing.py
index 1da753a2..7e2a416d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,3 @@
-import contextlib
import json
import math
import os
@@ -12,9 +11,8 @@ import cv2
from skimage import exposure
import modules.sd_hijack
-from modules import devices, prompt_parser, masking
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram
from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
@@ -48,6 +46,12 @@ def apply_color_correction(correction, image):
return image
+def get_correct_sampler(p):
+ if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
+ return sd_samplers.samplers
+ elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
+ return sd_samplers.samplers_for_img2img
+
class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
self.sd_model = sd_model
@@ -56,7 +60,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
- self.styles: str = styles
+ self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@@ -85,7 +89,7 @@ class StableDiffusionProcessing:
self.s_tmin = opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise
-
+
if not seed_enable_extras:
self.subseed = -1
self.subseed_strength = 0
@@ -111,7 +115,7 @@ class Processed:
self.width = p.width
self.height = p.height
self.sampler_index = p.sampler_index
- self.sampler = samplers[p.sampler_index].name
+ self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale
self.steps = p.steps
self.batch_size = p.batch_size
@@ -123,6 +127,9 @@ class Processed:
self.denoising_strength = getattr(p, 'denoising_strength', None)
self.extra_generation_params = p.extra_generation_params
self.index_of_first_image = index_of_first_image
+ self.styles = p.styles
+ self.job_timestamp = state.job_timestamp
+ self.clip_skip = opts.CLIP_stop_at_last_layers
self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
@@ -133,7 +140,7 @@ class Processed:
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
- self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
+ self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt]
@@ -167,6 +174,9 @@ class Processed:
"extra_generation_params": self.extra_generation_params,
"index_of_first_image": self.index_of_first_image,
"infotexts": self.infotexts,
+ "styles": self.styles,
+ "job_timestamp": self.job_timestamp,
+ "clip_skip": self.clip_skip,
}
return json.dumps(obj)
@@ -197,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds:
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -237,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
+ if opts.eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + opts.eta_noise_seed_delta)
+
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -249,29 +262,49 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+def decode_first_stage(model, x):
+ with devices.autocast(disable=x.dtype == devices.dtype_vae):
+ x = model.decode_first_stage(x)
+
+ return x
+
+
+def get_fixed_seed(seed):
+ if seed is None or seed == '' or seed == -1:
+ return int(random.randrange(4294967294))
+
+ return seed
+
+
def fix_seed(p):
- p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed
- p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed
+ p.seed = get_fixed_seed(p.seed)
+ p.subseed = get_fixed_seed(p.subseed)
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
+ clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+
generation_params = {
"Steps": p.steps,
- "Sampler": samplers[p.sampler_index].name,
+ "Sampler": get_correct_sampler(p)[p.sampler_index].name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
+ "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
- "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+ "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
+ "Clip skip": None if clip_skip <= 1 else clip_skip,
+ "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
}
generation_params.update(p.extra_generation_params)
@@ -290,15 +323,24 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
assert(len(p.prompt) > 0)
else:
assert p.prompt is not None
-
+
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
devices.torch_gc()
- fix_seed(p)
+ seed = get_fixed_seed(p.seed)
+ subseed = get_fixed_seed(p.subseed)
- os.makedirs(p.outpath_samples, exist_ok=True)
- os.makedirs(p.outpath_grids, exist_ok=True)
+ if p.outpath_samples is not None:
+ os.makedirs(p.outpath_samples, exist_ok=True)
+
+ if p.outpath_grids is not None:
+ os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
+ modules.sd_hijack.model_hijack.clear_comments()
comments = {}
@@ -309,33 +351,36 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else:
all_prompts = p.batch_size * p.n_iter * [p.prompt]
- if type(p.seed) == list:
- all_seeds = p.seed
+ if type(seed) == list:
+ all_seeds = seed
else:
- all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
- if type(p.subseed) == list:
- all_subseeds = p.subseed
+ if type(subseed) == list:
+ all_subseeds = subseed
else:
- all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))]
+ all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
- model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+ model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []
- precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
- ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
- with torch.no_grad(), precision_scope("cuda"), ema_scope():
- p.init(all_prompts, all_seeds, all_subseeds)
+
+ with torch.no_grad(), p.sd_model.ema_scope():
+ with devices.autocast():
+ p.init(all_prompts, all_seeds, all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ if state.skipped:
+ state.skipped = False
+
if state.interrupted:
break
@@ -348,8 +393,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts)
- uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
- c = prompt_parser.get_learned_conditioning(prompts, p.steps)
+ with devices.autocast():
+ uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
+ c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -358,16 +404,26 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
- if state.interrupted:
+ with devices.autocast():
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+
+ if state.interrupted or state.skipped:
- # if we are interruped, sample returns just noise
+ # if we are interrupted, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
- x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+ samples_ddim = samples_ddim.to(devices.dtype_vae)
+ x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ del samples_ddim
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+
+ devices.torch_gc()
+
if opts.filter_nsfw:
import modules.safety as safety
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
@@ -383,6 +439,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample)
+ devices.torch_gc()
image = Image.fromarray(x_sample)
@@ -408,9 +465,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
- infotexts.append(infotext(n, i))
+ text = infotext(n, i)
+ infotexts.append(text)
+ if opts.enable_pnginfo:
+ image.info["parameters"] = text
output_images.append(image)
+ del x_samples_ddim
+
+ devices.torch_gc()
+
state.nextjob()
p.color_corrections = None
@@ -421,7 +485,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
- infotexts.insert(0, infotext())
+ text = infotext()
+ infotexts.insert(0, text)
+ if opts.enable_pnginfo:
+ grid.info["parameters"] = text
output_images.insert(0, grid)
index_of_first_image = 1
@@ -434,16 +501,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- firstphase_width = 0
- firstphase_height = 0
- firstphase_width_truncated = 0
- firstphase_height_truncated = 0
- def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
+ def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
- self.scale_latent = scale_latent
self.denoising_strength = denoising_strength
+ self.firstphase_width = firstphase_width
+ self.firstphase_height = firstphase_height
+ self.truncate_x = 0
+ self.truncate_y = 0
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
@@ -452,17 +518,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ if self.firstphase_width == 0 or self.firstphase_height == 0:
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = self.width * self.height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ self.firstphase_width = math.ceil(scale * self.width / 64) * 64
+ self.firstphase_height = math.ceil(scale * self.height / 64) * 64
+ firstphase_width_truncated = int(scale * self.width)
+ firstphase_height_truncated = int(scale * self.height)
+
+ else:
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
+ width_ratio = self.width / self.firstphase_width
+ height_ratio = self.height / self.firstphase_height
+
+ if width_ratio > height_ratio:
+ firstphase_width_truncated = self.firstphase_width
+ firstphase_height_truncated = self.firstphase_width * self.height / self.width
+ else:
+ firstphase_width_truncated = self.firstphase_height * self.width / self.height
+ firstphase_height_truncated = self.firstphase_height
+
+ self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
+ self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- self.firstphase_width_truncated = int(scale * self.width)
- self.firstphase_height_truncated = int(scale * self.height)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -472,46 +555,41 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
- truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
- truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
- samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
+ decoded_samples = decode_first_stage(self.sd_model, samples)
- if self.scale_latent:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
+ decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
else:
- decoded_samples = self.sd_model.decode_first_stage(samples)
+ lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
- decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
- else:
- lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- batch_images = []
- for i, x_sample in enumerate(lowres_samples):
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- image = Image.fromarray(x_sample)
- image = images.resize_image(0, image, self.width, self.height)
- image = np.array(image).astype(np.float32) / 255.0
- image = np.moveaxis(image, 2, 0)
- batch_images.append(image)
+ batch_images = []
+ for i, x_sample in enumerate(lowres_samples):
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ image = images.resize_image(0, image, self.width, self.height)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ batch_images.append(image)
- decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
+ decoded_samples = torch.from_numpy(np.array(batch_images))
+ decoded_samples = decoded_samples.to(shared.device)
+ decoded_samples = 2. * decoded_samples - 1.
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
shared.state.nextjob()
- self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
-
+
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
return samples
@@ -540,7 +618,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds):
- self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
crop_region = None
if self.image_mask is not None:
@@ -647,4 +725,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
+ del x
+ devices.torch_gc()
+
return samples
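Note: a minimal self-contained sketch of the first-pass sizing arithmetic introduced in StableDiffusionProcessingTxt2Img.init() above, with illustrative target dimensions; opt_f = 8 is the latent downscale factor the module uses.

    import math

    opt_f = 8
    width, height = 960, 512                     # final size requested by the user
    firstphase_width = firstphase_height = 0     # 0 means "derive from a 512x512 pixel budget"

    if firstphase_width == 0 or firstphase_height == 0:
        scale = math.sqrt((512 * 512) / (width * height))
        firstphase_width = math.ceil(scale * width / 64) * 64
        firstphase_height = math.ceil(scale * height / 64) * 64
        firstphase_width_truncated = int(scale * width)
        firstphase_height_truncated = int(scale * height)
    else:
        # keep the first pass at the user-provided size but truncate to the final aspect ratio
        if width / firstphase_width > height / firstphase_height:
            firstphase_width_truncated = firstphase_width
            firstphase_height_truncated = firstphase_width * height / width
        else:
            firstphase_width_truncated = firstphase_height * width / height
            firstphase_height_truncated = firstphase_height

    truncate_x = int(firstphase_width - firstphase_width_truncated) // opt_f
    truncate_y = int(firstphase_height - firstphase_height_truncated) // opt_f
    print(firstphase_width, firstphase_height, truncate_x, truncate_y)   # 704 384 0 1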
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index e811eb9e..919d5d31 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,19 +1,7 @@
import re
from collections import namedtuple
-import torch
-
-import modules.shared as shared
-
-re_prompt = re.compile(r'''
-(.*?)
-\[
- ([^]:]+):
- (?:([^]:]*):)?
- ([0-9]*\.?[0-9]+)
-]
-|
-(.+)
-''', re.X)
+from typing import List
+import lark
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
@@ -23,71 +11,117 @@ re_prompt = re.compile(r'''
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
+schedule_parser = lark.Lark(r"""
+!start: (prompt | /[][():]/+)*
+prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
+!emphasized: "(" prompt ")"
+ | "(" prompt ":" prompt ")"
+ | "[" prompt "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+alternate: "[" prompt ("|" prompt)+ "]"
+WHITESPACE: /\s+/
+plain: /([^\\\[\]():|]|\\.)+/
+%import common.SIGNED_NUMBER -> NUMBER
+""")
def get_learned_conditioning_prompt_schedules(prompts, steps):
- res = []
- cache = {}
-
- for prompt in prompts:
- prompt_schedule: list[list[str | int]] = [[steps, ""]]
-
- cached = cache.get(prompt, None)
- if cached is not None:
- res.append(cached)
- continue
-
- for m in re_prompt.finditer(prompt):
- plaintext = m.group(1) if m.group(5) is None else m.group(5)
- concept_from = m.group(2)
- concept_to = m.group(3)
- if concept_to is None:
- concept_to = concept_from
- concept_from = ""
- swap_position = float(m.group(4)) if m.group(4) is not None else None
-
- if swap_position is not None:
- if swap_position < 1:
- swap_position = swap_position * steps
- swap_position = int(min(swap_position, steps))
-
- swap_index = None
- found_exact_index = False
- for i in range(len(prompt_schedule)):
- end_step = prompt_schedule[i][0]
- prompt_schedule[i][1] += plaintext
-
- if swap_position is not None and swap_index is None:
- if swap_position == end_step:
- swap_index = i
- found_exact_index = True
-
- if swap_position < end_step:
- swap_index = i
-
- if swap_index is not None:
- if not found_exact_index:
- prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
-
- for i in range(len(prompt_schedule)):
- end_step = prompt_schedule[i][0]
- must_replace = swap_position < end_step
-
- prompt_schedule[i][1] += concept_to if must_replace else concept_from
-
- res.append(prompt_schedule)
- cache[prompt] = prompt_schedule
- #for t in prompt_schedule:
- # print(t)
+ """
+ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
+ >>> g("test")
+ [[10, 'test']]
+ >>> g("a [b:3]")
+ [[3, 'a '], [10, 'a b']]
+ >>> g("a [b: 3]")
+ [[3, 'a '], [10, 'a b']]
+ >>> g("a [[[b]]:2]")
+ [[2, 'a '], [10, 'a [[b]]']]
+ >>> g("[(a:2):3]")
+ [[3, ''], [10, '(a:2)']]
+ >>> g("a [b : c : 1] d")
+ [[1, 'a b d'], [10, 'a c d']]
+ >>> g("a[b:[c:d:2]:1]e")
+ [[1, 'abe'], [2, 'ace'], [10, 'ade']]
+ >>> g("a [unbalanced")
+ [[10, 'a [unbalanced']]
+ >>> g("a [b:.5] c")
+ [[5, 'a c'], [10, 'a b c']]
+ >>> g("a [{b|d{:.5] c") # not handling this right now
+ [[5, 'a c'], [10, 'a {b|d{ c']]
+ >>> g("((a][:b:c [d:3]")
+ [[3, '((a][:b:c '], [10, '((a][:b:c d']]
+ """
- return res
+ def collect_steps(steps, tree):
+ l = [steps]
+ class CollectSteps(lark.Visitor):
+ def scheduled(self, tree):
+ tree.children[-1] = float(tree.children[-1])
+ if tree.children[-1] < 1:
+ tree.children[-1] *= steps
+ tree.children[-1] = min(steps, int(tree.children[-1]))
+ l.append(tree.children[-1])
+ def alternate(self, tree):
+ l.extend(range(1, steps+1))
+ CollectSteps().visit(tree)
+ return sorted(set(l))
+
+ def at_step(step, tree):
+ class AtStep(lark.Transformer):
+ def scheduled(self, args):
+ before, after, _, when = args
+ yield before or () if step <= when else after
+ def alternate(self, args):
+ yield next(args[(step - 1)%len(args)])
+ def start(self, args):
+ def flatten(x):
+ if type(x) == str:
+ yield x
+ else:
+ for gen in x:
+ yield from flatten(gen)
+ return ''.join(flatten(args))
+ def plain(self, args):
+ yield args[0].value
+ def __default__(self, data, children, meta):
+ for child in children:
+ yield from child
+ return AtStep().transform(tree)
+
+ def get_schedule(prompt):
+ try:
+ tree = schedule_parser.parse(prompt)
+ except lark.exceptions.LarkError as e:
+ if 0:
+ import traceback
+ traceback.print_exc()
+ return [[steps, prompt]]
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
+
+ promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
+ return [promptdict[prompt] for prompt in prompts]
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
-def get_learned_conditioning(prompts, steps):
+def get_learned_conditioning(model, prompts, steps):
+    """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
+ and the sampling step at which this condition is to be replaced by the next one.
+
+ Input:
+ (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
+ Output:
+ [
+ [
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
+ ],
+ [
+ ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
+ ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
+ ]
+ ]
+ """
res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
@@ -101,7 +135,7 @@ def get_learned_conditioning(prompts, steps):
continue
texts = [x[1] for x in prompt_schedule]
- conds = shared.sd_model.get_learned_conditioning(texts)
+ conds = model.get_learned_conditioning(texts)
cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule):
@@ -110,22 +144,118 @@ def get_learned_conditioning(prompts, steps):
cache[prompt] = cond_schedule
res.append(cond_schedule)
- return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
+ return res
+
+
+re_AND = re.compile(r"\bAND\b")
+re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
+
+def get_multicond_prompt_list(prompts):
+ res_indexes = []
+
+ prompt_flat_list = []
+ prompt_indexes = {}
+
+ for prompt in prompts:
+ subprompts = re_AND.split(prompt)
+
+ indexes = []
+ for subprompt in subprompts:
+ match = re_weight.search(subprompt)
+
+ text, weight = match.groups() if match is not None else (subprompt, 1.0)
+
+ weight = float(weight) if weight is not None else 1.0
+
+ index = prompt_indexes.get(text, None)
+ if index is None:
+ index = len(prompt_flat_list)
+ prompt_flat_list.append(text)
+ prompt_indexes[text] = index
+
+ indexes.append((index, weight))
+ res_indexes.append(indexes)
-def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
- res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
- for i, cond_schedule in enumerate(c.schedules):
+ return res_indexes, prompt_flat_list, prompt_indexes
+
+
+class ComposableScheduledPromptConditioning:
+ def __init__(self, schedules, weight=1.0):
+ self.schedules: List[ScheduledPromptConditioning] = schedules
+ self.weight: float = weight
+
+
+class MulticondLearnedConditioning:
+ def __init__(self, shape, batch):
+ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
+ self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+
+def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+ """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
+ For each prompt, the list is obtained by splitting the prompt using the AND separator.
+
+ https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
+ """
+
+ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
+
+ learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+
+ res = []
+ for indexes in res_indexes:
+ res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
+
+ return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+
+
+def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
+ param = c[0][0].cond
+ res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+ for i, cond_schedule in enumerate(c):
target_index = 0
- for curret_index, (end_at, cond) in enumerate(cond_schedule):
+ for current, (end_at, cond) in enumerate(cond_schedule):
if current_step <= end_at:
- target_index = curret_index
+ target_index = current
break
res[i] = cond_schedule[target_index].cond
return res
+def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
+ param = c.batch[0][0].schedules[0].cond
+
+ tensors = []
+ conds_list = []
+
+ for batch_no, composable_prompts in enumerate(c.batch):
+ conds_for_batch = []
+
+ for cond_index, composable_prompt in enumerate(composable_prompts):
+ target_index = 0
+ for current, (end_at, cond) in enumerate(composable_prompt.schedules):
+ if current_step <= end_at:
+ target_index = current
+ break
+
+ conds_for_batch.append((len(tensors), composable_prompt.weight))
+ tensors.append(composable_prompt.schedules[target_index].cond)
+
+ conds_list.append(conds_for_batch)
+
+    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+ # and won't be able to torch.stack them. So this fixes that.
+ token_count = max([x.shape[0] for x in tensors])
+ for i in range(len(tensors)):
+ if tensors[i].shape[0] != token_count:
+ last_vector = tensors[i][-1:]
+ last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+ tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+ return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+
+
re_attention = re.compile(r"""
\\\(|
\\\)|
@@ -157,23 +287,26 @@ def parse_prompt_attention(text):
\\ - literal character '\'
anything else - just text
- Example:
-
- 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
-
- produces:
-
- [
- ['a ', 1.0],
- ['house', 1.5730000000000004],
- [' ', 1.1],
- ['on', 1.0],
- [' a ', 1.1],
- ['hill', 0.55],
- [', sun, ', 1.1],
- ['sky', 1.4641000000000006],
- ['.', 1.1]
- ]
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\(literal\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
"""
res = []
@@ -215,4 +348,19 @@ def parse_prompt_attention(text):
if len(res) == 0:
res = [["", 1.0]]
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
return res
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
+else:
+ import torch # doctest faster
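Note: a minimal sketch of the AND splitting performed by get_multicond_prompt_list(): each subprompt may end with ':weight', and unweighted subprompts default to 1.0. The prompt text is illustrative.

    import re

    re_AND = re.compile(r"\bAND\b")
    re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")

    prompt = "a castle on a hill AND thick fog :0.4"
    for subprompt in re_AND.split(prompt):
        match = re_weight.search(subprompt)
        text, weight = match.groups() if match is not None else (subprompt, 1.0)
        weight = float(weight) if weight is not None else 1.0
        print(repr(text.strip()), weight)
    # 'a castle on a hill' 1.0
    # 'thick fog' 0.4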
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index dc0123e0..3ac0b97a 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -8,14 +8,12 @@ from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
-from modules.paths import models_path
from modules.shared import cmd_opts, opts
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
self.name = "RealESRGAN"
- self.model_path = os.path.join(models_path, self.name)
self.user_path = path
super().__init__()
try:
diff --git a/modules/safe.py b/modules/safe.py
new file mode 100644
index 00000000..399165a1
--- /dev/null
+++ b/modules/safe.py
@@ -0,0 +1,117 @@
+# this code is adapted from the script contributed by anon from /h/
+
+import io
+import pickle
+import collections
+import sys
+import traceback
+
+import torch
+import numpy
+import _codecs
+import zipfile
+import re
+
+
+# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
+
+
+def encode(*args):
+ out = _codecs.encode(*args)
+ return out
+
+
+class RestrictedUnpickler(pickle.Unpickler):
+ def persistent_load(self, saved_id):
+ assert saved_id[0] == 'storage'
+ return TypedStorage()
+
+ def find_class(self, module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return getattr(collections, name)
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
+ return getattr(torch._utils, name)
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
+ return getattr(torch, name)
+ if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
+ return getattr(torch.nn.modules.container, name)
+ if module == 'numpy.core.multiarray' and name == 'scalar':
+ return numpy.core.multiarray.scalar
+ if module == 'numpy' and name == 'dtype':
+ return numpy.dtype
+ if module == '_codecs' and name == 'encode':
+ return encode
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
+ import pytorch_lightning.callbacks
+ return pytorch_lightning.callbacks.model_checkpoint
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
+ import pytorch_lightning.callbacks.model_checkpoint
+ return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
+ if module == "__builtin__" and name == 'set':
+ return set
+
+ # Forbid everything else.
+ raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+
+
+allowed_zip_names = ["archive/data.pkl", "archive/version"]
+allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
+
+
+def check_zip_filenames(filename, names):
+ for name in names:
+ if name in allowed_zip_names:
+ continue
+ if allowed_zip_names_re.match(name):
+ continue
+
+ raise Exception(f"bad file inside {filename}: {name}")
+
+
+def check_pt(filename):
+ try:
+
+ # new pytorch format is a zip file
+ with zipfile.ZipFile(filename) as z:
+ check_zip_filenames(filename, z.namelist())
+
+ with z.open('archive/data.pkl') as file:
+ unpickler = RestrictedUnpickler(file)
+ unpickler.load()
+
+ except zipfile.BadZipfile:
+
+        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
+ with open(filename, "rb") as file:
+ unpickler = RestrictedUnpickler(file)
+ for i in range(5):
+ unpickler.load()
+
+
+def load(filename, *args, **kwargs):
+ from modules import shared
+
+ try:
+ if not shared.cmd_opts.disable_safe_unpickle:
+ check_pt(filename)
+
+ except pickle.UnpicklingError:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ return None
+
+ except Exception:
+ print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
+ print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ return None
+
+ return unsafe_torch_load(filename, *args, **kwargs)
+
+
+unsafe_torch_load = torch.load
+torch.load = load
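Note: a minimal sketch of the whitelist behaviour, assuming the webui repository is on sys.path (importing modules.safe also replaces torch.load, as the last two lines above do): anything whose reconstruction needs a global outside the allow list is rejected.

    import collections
    import io
    import pickle

    from modules.safe import RestrictedUnpickler

    ok = pickle.dumps(collections.OrderedDict(a=1))        # collections.OrderedDict is whitelisted
    print(RestrictedUnpickler(io.BytesIO(ok)).load())      # OrderedDict([('a', 1)])

    bad = pickle.dumps(collections.Counter())              # collections.Counter is not whitelisted
    try:
        RestrictedUnpickler(io.BytesIO(bad)).load()
    except pickle.UnpicklingError as e:
        print(e)                                           # global 'collections/Counter' is forbidden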
diff --git a/modules/scripts.py b/modules/scripts.py
index 7c3bd5e7..45230f9a 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -162,6 +162,40 @@ class ScriptRunner:
return processed
+ def reload_sources(self):
+ for si, script in list(enumerate(self.scripts)):
+ with open(script.filename, "r", encoding="utf8") as file:
+ args_from = script.args_from
+ args_to = script.args_to
+ filename = script.filename
+ text = file.read()
+
+ from types import ModuleType
+
+ compiled = compile(text, filename, 'exec')
+ module = ModuleType(script.filename)
+ exec(compiled, module.__dict__)
+
+ for key, script_class in module.__dict__.items():
+ if type(script_class) == type and issubclass(script_class, Script):
+ self.scripts[si] = script_class()
+ self.scripts[si].filename = filename
+ self.scripts[si].args_from = args_from
+ self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
+def reload_script_body_only():
+ scripts_txt2img.reload_sources()
+ scripts_img2img.reload_sources()
+
+
+def reload_scripts(basedir):
+ global scripts_txt2img, scripts_img2img
+
+ scripts_data.clear()
+ load_scripts(basedir)
+
+ scripts_txt2img = ScriptRunner()
+ scripts_img2img = ScriptRunner()
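Note: a minimal sketch of the reload pattern reload_sources() uses: the script's source text is compiled and executed into a fresh ModuleType, and Script subclasses are then picked out of that module's namespace. The source string here is illustrative.

    from types import ModuleType

    source = "class Script:\n    title = 'example'\n"

    module = ModuleType("reloaded_script")
    exec(compile(source, "reloaded_script.py", "exec"), module.__dict__)

    script_class = module.__dict__["Script"]
    print(script_class.title)   # example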
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
new file mode 100644
index 00000000..36a996bf
--- /dev/null
+++ b/modules/scunet_model.py
@@ -0,0 +1,88 @@
+import os.path
+import sys
+import traceback
+
+import PIL.Image
+import numpy as np
+import torch
+from basicsr.utils.download_util import load_file_from_url
+
+import modules.upscaler
+from modules import devices, modelloader
+from modules.scunet_model_arch import SCUNet as net
+
+
+class UpscalerScuNET(modules.upscaler.Upscaler):
+ def __init__(self, dirname):
+ self.name = "ScuNET"
+ self.model_name = "ScuNET GAN"
+ self.model_name2 = "ScuNET PSNR"
+ self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
+ self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
+ self.user_path = dirname
+ super().__init__()
+ model_paths = self.find_models(ext_filter=[".pth"])
+ scalers = []
+ add_model2 = True
+ for file in model_paths:
+ if "http" in file:
+ name = self.model_name
+ else:
+ name = modelloader.friendly_name(file)
+ if name == self.model_name2 or file == self.model_url2:
+ add_model2 = False
+ try:
+ scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
+ scalers.append(scaler_data)
+ except Exception:
+ print(f"Error loading ScuNET model: {file}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ if add_model2:
+ scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
+ scalers.append(scaler_data2)
+ self.scalers = scalers
+
+ def do_upscale(self, img: PIL.Image, selected_file):
+ torch.cuda.empty_cache()
+
+ model = self.load_model(selected_file)
+ if model is None:
+ return img
+
+ device = devices.device_scunet
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.moveaxis(img, 2, 0) / 255
+ img = torch.from_numpy(img).float()
+ img = img.unsqueeze(0).to(device)
+
+ img = img.to(device)
+ with torch.no_grad():
+ output = model(img)
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output = 255. * np.moveaxis(output, 0, 2)
+ output = output.astype(np.uint8)
+ output = output[:, :, ::-1]
+ torch.cuda.empty_cache()
+ return PIL.Image.fromarray(output, 'RGB')
+
+ def load_model(self, path: str):
+ device = devices.device_scunet
+ if "http" in path:
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
+ progress=True)
+ else:
+ filename = path
+ if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
+ print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
+ return None
+
+ model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+ model.load_state_dict(torch.load(filename), strict=True)
+ model.eval()
+ for k, v in model.named_parameters():
+ v.requires_grad = False
+ model = model.to(device)
+
+ return model
+
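Note: a minimal sketch of the image round trip do_upscale() performs around the model call: an RGB PIL image becomes a BGR float CHW tensor in [0, 1] and is converted back afterwards. The model is replaced by an identity here, and the image is illustrative.

    import numpy as np
    import PIL.Image
    import torch

    pil = PIL.Image.new("RGB", (8, 8), color=(255, 0, 0))

    img = np.array(pil)
    img = img[:, :, ::-1]                    # RGB -> BGR
    img = np.moveaxis(img, 2, 0) / 255       # HWC -> CHW, scale to [0, 1]
    img = torch.from_numpy(img).float().unsqueeze(0)

    output = img                             # stand-in for model(img)
    output = output.squeeze().float().clamp_(0, 1).numpy()
    output = 255. * np.moveaxis(output, 0, 2)
    output = output.astype(np.uint8)[:, :, ::-1]   # CHW -> HWC, BGR -> RGB
    print(PIL.Image.fromarray(output, "RGB").getpixel((0, 0)))   # (255, 0, 0)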
diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py
new file mode 100644
index 00000000..43ca8d36
--- /dev/null
+++ b/modules/scunet_model_arch.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from timm.models.layers import trunc_normal_, DropPath
+
+
+class WMSA(nn.Module):
+ """ Self-attention module in Swin Transformer
+ """
+
+ def __init__(self, input_dim, output_dim, head_dim, window_size, type):
+ super(WMSA, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.head_dim = head_dim
+ self.scale = self.head_dim ** -0.5
+ self.n_heads = input_dim // head_dim
+ self.window_size = window_size
+ self.type = type
+ self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
+
+ self.relative_position_params = nn.Parameter(
+ torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
+
+ self.linear = nn.Linear(self.input_dim, self.output_dim)
+
+ trunc_normal_(self.relative_position_params, std=.02)
+ self.relative_position_params = torch.nn.Parameter(
+ self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
+ 2).transpose(
+ 0, 1))
+
+ def generate_mask(self, h, w, p, shift):
+ """ generating the mask of SW-MSA
+ Args:
+ shift: shift parameters in CyclicShift.
+ Returns:
+ attn_mask: should be (1 1 w p p),
+ """
+ # supporting square.
+ attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
+ if self.type == 'W':
+ return attn_mask
+
+ s = p - shift
+ attn_mask[-1, :, :s, :, s:, :] = True
+ attn_mask[-1, :, s:, :, :s, :] = True
+ attn_mask[:, -1, :, :s, :, s:] = True
+ attn_mask[:, -1, :, s:, :, :s] = True
+ attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
+ return attn_mask
+
+ def forward(self, x):
+ """ Forward pass of Window Multi-head Self-attention module.
+ Args:
+ x: input tensor with shape of [b h w c];
+ attn_mask: attention mask, fill -inf where the value is True;
+ Returns:
+ output: tensor shape [b h w c]
+ """
+ if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+ x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
+ h_windows = x.size(1)
+ w_windows = x.size(2)
+ # square validation
+ # assert h_windows == w_windows
+
+ x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
+ qkv = self.embedding_layer(x)
+ q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
+ sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
+ # Adding learnable relative embedding
+ sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
+ # Using Attn Mask to distinguish different subwindows.
+ if self.type != 'W':
+ attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
+ sim = sim.masked_fill_(attn_mask, float("-inf"))
+
+ probs = nn.functional.softmax(sim, dim=-1)
+ output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
+ output = rearrange(output, 'h b w p c -> b w p (h c)')
+ output = self.linear(output)
+ output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
+
+ if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
+ dims=(1, 2))
+ return output
+
+ def relative_embedding(self):
+ cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
+ relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
+ # negative is allowed
+ return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
+
+
+class Block(nn.Module):
+ def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
+ """ SwinTransformer Block
+ """
+ super(Block, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ assert type in ['W', 'SW']
+ self.type = type
+ if input_resolution <= window_size:
+ self.type = 'W'
+
+ self.ln1 = nn.LayerNorm(input_dim)
+ self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.ln2 = nn.LayerNorm(input_dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(input_dim, 4 * input_dim),
+ nn.GELU(),
+ nn.Linear(4 * input_dim, output_dim),
+ )
+
+ def forward(self, x):
+ x = x + self.drop_path(self.msa(self.ln1(x)))
+ x = x + self.drop_path(self.mlp(self.ln2(x)))
+ return x
+
+
+class ConvTransBlock(nn.Module):
+ def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
+ """ SwinTransformer and Conv Block
+ """
+ super(ConvTransBlock, self).__init__()
+ self.conv_dim = conv_dim
+ self.trans_dim = trans_dim
+ self.head_dim = head_dim
+ self.window_size = window_size
+ self.drop_path = drop_path
+ self.type = type
+ self.input_resolution = input_resolution
+
+ assert self.type in ['W', 'SW']
+ if self.input_resolution <= self.window_size:
+ self.type = 'W'
+
+ self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
+ self.type, self.input_resolution)
+ self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
+ self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
+
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
+ nn.ReLU(True),
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
+ )
+
+ def forward(self, x):
+ conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
+ conv_x = self.conv_block(conv_x) + conv_x
+ trans_x = Rearrange('b c h w -> b h w c')(trans_x)
+ trans_x = self.trans_block(trans_x)
+ trans_x = Rearrange('b h w c -> b c h w')(trans_x)
+ res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
+ x = x + res
+
+ return x
+
+
+class SCUNet(nn.Module):
+ # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
+ def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
+ super(SCUNet, self).__init__()
+ if config is None:
+ config = [2, 2, 2, 2, 2, 2, 2]
+ self.config = config
+ self.dim = dim
+ self.head_dim = 32
+ self.window_size = 8
+
+ # drop path rate for each layer
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
+
+ self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
+
+ begin = 0
+ self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
+ 'W' if not i % 2 else 'SW', input_resolution)
+ for i in range(config[0])] + \
+ [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
+
+ begin += config[0]
+ self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
+ 'W' if not i % 2 else 'SW', input_resolution // 2)
+ for i in range(config[1])] + \
+ [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
+
+ begin += config[1]
+ self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
+ 'W' if not i % 2 else 'SW', input_resolution // 4)
+ for i in range(config[2])] + \
+ [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
+
+ begin += config[2]
+ self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
+ 'W' if not i % 2 else 'SW', input_resolution // 8)
+ for i in range(config[3])]
+
+ begin += config[3]
+ self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
+ [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
+ 'W' if not i % 2 else 'SW', input_resolution // 4)
+ for i in range(config[4])]
+
+ begin += config[4]
+ self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
+ [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
+ 'W' if not i % 2 else 'SW', input_resolution // 2)
+ for i in range(config[5])]
+
+ begin += config[5]
+ self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
+ [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
+ 'W' if not i % 2 else 'SW', input_resolution)
+ for i in range(config[6])]
+
+ self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
+
+ self.m_head = nn.Sequential(*self.m_head)
+ self.m_down1 = nn.Sequential(*self.m_down1)
+ self.m_down2 = nn.Sequential(*self.m_down2)
+ self.m_down3 = nn.Sequential(*self.m_down3)
+ self.m_body = nn.Sequential(*self.m_body)
+ self.m_up3 = nn.Sequential(*self.m_up3)
+ self.m_up2 = nn.Sequential(*self.m_up2)
+ self.m_up1 = nn.Sequential(*self.m_up1)
+ self.m_tail = nn.Sequential(*self.m_tail)
+ # self.apply(self._init_weights)
+
+ def forward(self, x0):
+
+ h, w = x0.size()[-2:]
+ paddingBottom = int(np.ceil(h / 64) * 64 - h)
+ paddingRight = int(np.ceil(w / 64) * 64 - w)
+ x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
+
+ x1 = self.m_head(x0)
+ x2 = self.m_down1(x1)
+ x3 = self.m_down2(x2)
+ x4 = self.m_down3(x3)
+ x = self.m_body(x4)
+ x = self.m_up3(x + x4)
+ x = self.m_up2(x + x3)
+ x = self.m_up1(x + x2)
+ x = self.m_tail(x + x1)
+
+ x = x[..., :h, :w]
+
+ return x
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0) \ No newline at end of file
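Note: a minimal sketch of the replication padding SCUNet.forward() applies so both spatial dimensions become multiples of 64 before the encoder/decoder runs, plus the crop back to the original size afterwards; the input size is illustrative.

    import numpy as np
    import torch
    import torch.nn as nn

    x = torch.zeros(1, 3, 70, 130)
    h, w = x.size()[-2:]
    padding_bottom = int(np.ceil(h / 64) * 64 - h)   # 58
    padding_right = int(np.ceil(w / 64) * 64 - w)    # 62

    x = nn.ReplicationPad2d((0, padding_right, 0, padding_bottom))(x)
    print(x.shape)                                   # torch.Size([1, 3, 128, 192])

    # ... the U-shaped body would run here ...

    x = x[..., :h, :w]                               # crop back to the original size
    print(x.shape)                                   # torch.Size([1, 3, 70, 130])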
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index fa7eaeb8..c81722a0 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,245 +5,66 @@ import traceback
import torch
import numpy as np
from torch import einsum
+from torch.nn.functional import silu
-from modules import prompt_parser
+import modules.textual_inversion.textual_inversion
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
+from modules.sd_hijack_optimizations import invokeAI_mps_available
-from ldm.util import default
-from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
+diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
+def apply_optimizations():
+ undo_optimizations()
+
+ ldm.modules.diffusionmodules.model.nonlinearity = silu
+
+ if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)):
+ print("Applying xformers cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ elif cmd_opts.opt_split_attention_v1:
+ print("Applying v1 cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+ if not invokeAI_mps_available and shared.device.type == 'mps':
+ print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+ print("Applying v1 cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ else:
+ print("Applying cross attention optimization (InvokeAI).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
+ print("Applying cross attention optimization (Doggettx).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
-# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
- h = self.heads
-
- q = self.to_q(x)
- context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
- for i in range(0, q.shape[0], 2):
- end = i + 2
- s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
- s1 *= self.scale
-
- s2 = s1.softmax(dim=-1)
- del s1
-
- r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
- del s2
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-
-# taken from https://github.com/Doggettx/stable-diffusion
-def split_cross_attention_forward(self, x, context=None, mask=None):
- h = self.heads
-
- q_in = self.to_q(x)
- context = default(context, x)
- k_in = self.to_k(context) * self.scale
- v_in = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
- modifier = 3 if q.element_size() == 2 else 2.5
- mem_required = tensor_size * modifier
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
- # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
- # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
- if steps > 64:
- max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
- raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
- f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
- s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
- s2 = s1.softmax(dim=-1, dtype=q.dtype)
- del s1
-
- r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
- del s2
-
- del q, k, v
-
- r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
- del r1
-
- return self.to_out(r2)
-
-def nonlinearity_hijack(x):
- # swish
- t = torch.sigmoid(x)
- x *= t
- del t
-
- return x
-
-def cross_attention_attnblock_forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q1 = self.q(h_)
- k1 = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q1.shape
-
- q2 = q1.reshape(b, c, h*w)
- del q1
-
- q = q2.permute(0, 2, 1) # b,hw,c
- del q2
-
- k = k1.reshape(b, c, h*w) # b,c,hw
- del k1
-
- h_ = torch.zeros_like(k, device=q.device)
-
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
-
- tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
- mem_required = tensor_size * 2.5
- steps = 1
-
- if mem_required > mem_free_total:
- steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
- slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
- for i in range(0, q.shape[1], slice_size):
- end = i + slice_size
-
- w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w2 = w1 * (int(c)**(-0.5))
- del w1
- w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
- del w2
-
- # attend to values
- v1 = v.reshape(b, c, h*w)
- w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
- del w3
- h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- del v1, w4
+def undo_optimizations():
+ from modules.hypernetworks import hypernetwork
- h2 = h_.reshape(b, c, h, w)
- del h_
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
- h3 = self.proj_out(h2)
- del h2
- h3 += x
+def get_target_prompt_token_count(token_count):
+ return math.ceil(max(token_count, 1) / 75) * 75
- return h3
class StableDiffusionModelHijack:
- ids_lookup = {}
- word_embeddings = {}
- word_embeddings_checksums = {}
fixes = None
comments = []
- dir_mtime = None
layers = None
circular_enabled = False
clip = None
- def load_textual_inversion_embeddings(self, dirname, model):
- mt = os.path.getmtime(dirname)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
- return
-
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
-
- tokenizer = model.cond_stage_model.tokenizer
-
- def const_hash(a):
- r = 0
- for v in a:
- r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
- return r
-
- def process_file(path, filename):
- name = os.path.splitext(filename)[0]
-
- data = torch.load(path, map_location="cpu")
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
-
- self.word_embeddings[name] = emb.detach().to(device)
- self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
-
- ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
-
- first_id = ids[0]
- if first_id not in self.ids_lookup:
- self.ids_lookup[first_id] = []
- self.ids_lookup[first_id].append((ids, name))
-
- for fn in os.listdir(dirname):
- try:
- fullfn = os.path.join(dirname, fn)
-
- if os.stat(fullfn).st_size == 0:
- continue
-
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
-
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@@ -253,12 +74,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model
- if cmd_opts.opt_split_attention_v1:
- ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
- ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
- ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+ apply_optimizations()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
@@ -286,21 +102,24 @@ class StableDiffusionModelHijack:
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros'
+ def clear_comments(self):
+ self.comments = []
+
def tokenize(self, text):
- max_length = self.clip.max_length - 2
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
- return remade_batch_tokens[0], token_count, max_length
+ return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
- self.hijack = hijack
+ self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
- self.max_length = wrapped.max_length
self.token_mults = {}
+ self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
@@ -317,11 +136,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
-
def tokenize_line(self, line, used_custom_terms, hijack_comments):
- id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
- maxlen = self.wrapped.max_length
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
@@ -333,48 +149,53 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
fixes = []
remade_tokens = []
multipliers = []
+ last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
- possible_matches = self.hijack.ids_lookup.get(token, None)
-
- if possible_matches is None:
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ if token == self.comma_token:
+ last_comma = len(remade_tokens)
+ elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+ last_comma += 1
+ reloc_tokens = remade_tokens[last_comma:]
+ reloc_mults = multipliers[last_comma:]
+
+ remade_tokens = remade_tokens[:last_comma]
+ length = len(remade_tokens)
+
+ rem = int(math.ceil(length / 75)) * 75 - length
+ remade_tokens += [id_end] * rem + reloc_tokens
+ multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+
+ if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
+ i += 1
else:
- found = False
- for ids, word in possible_matches:
- if tokens[i:i + len(ids)] == ids:
- emb_len = int(self.hijack.word_embeddings[word].shape[0])
- fixes.append((len(remade_tokens), word))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- i += len(ids) - 1
- found = True
- used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
- break
-
- if not found:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+ emb_len = int(embedding.vec.shape[0])
+ iteration = len(remade_tokens) // 75
+ if (len(remade_tokens) + emb_len) // 75 != iteration:
+ rem = (75 * (iteration + 1) - len(remade_tokens))
+ remade_tokens += [id_end] * rem
+ multipliers += [1.0] * rem
+ iteration += 1
+ fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
+ remade_tokens += [0] * emb_len
+ multipliers += [weight] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ prompt_target_length = get_target_prompt_token_count(token_count)
+ tokens_to_add = prompt_target_length - len(remade_tokens)
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+ remade_tokens = remade_tokens + [id_end] * tokens_to_add
+ multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
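
tokenize_line above no longer truncates the prompt at the CLIP limit; it pads the remade token list up to the next multiple of 75 so that forward() can later encode the prompt in 75-token chunks. A toy illustration of the target-length computation, using the same formula as get_target_prompt_token_count above:

import math

def target_length(token_count: int) -> int:
    # Round up to the next multiple of 75, with a minimum of one chunk.
    return math.ceil(max(token_count, 1) / 75) * 75

for n in (0, 40, 75, 76, 151):
    print(n, "->", target_length(n))   # 75, 75, 75, 150, 225
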
@@ -391,7 +212,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
- remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
@@ -405,7 +227,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
- maxlen = self.wrapped.max_length
+ maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
@@ -431,32 +253,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- possible_matches = self.hijack.ids_lookup.get(token, None)
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
- elif possible_matches is None:
+ i += 1
+ elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
+ i += 1
else:
- found = False
- for ids, word in possible_matches:
- if tokens[i:i+len(ids)] == ids:
- emb_len = int(self.hijack.word_embeddings[word].shape[0])
- fixes.append((len(remade_tokens), word))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- i += len(ids) - 1
- found = True
- used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
- break
-
- if not found:
- remade_tokens.append(token)
- multipliers.append(mult)
-
- i += 1
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -464,6 +277,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
@@ -476,27 +290,74 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
+
def forward(self, text):
-
- if opts.use_old_emphasis_implementation:
+ use_old = opts.use_old_emphasis_implementation
+ if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
- self.hijack.fixes = hijack_fixes
- self.hijack.comments = hijack_comments
+ self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
+
+ if use_old:
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
+
+ z = None
+ i = 0
+ while max(map(len, remade_batch_tokens)) != 0:
+ rem_tokens = [x[75:] for x in remade_batch_tokens]
+ rem_multipliers = [x[75:] for x in batch_multipliers]
+
+ self.hijack.fixes = []
+ for unfiltered in hijack_fixes:
+ fixes = []
+ for fix in unfiltered:
+ if fix[0] == i:
+ fixes.append(fix[1])
+ self.hijack.fixes.append(fixes)
+
+ tokens = []
+ multipliers = []
+ for j in range(len(remade_batch_tokens)):
+ if len(remade_batch_tokens[j]) > 0:
+ tokens.append(remade_batch_tokens[j][:75])
+ multipliers.append(batch_multipliers[j][:75])
+ else:
+ tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
+ multipliers.append([1.0] * 75)
+
+ z1 = self.process_tokens(tokens, multipliers)
+ z = z1 if z is None else torch.cat((z, z1), axis=-2)
+
+ remade_batch_tokens = rem_tokens
+ batch_multipliers = rem_multipliers
+ i += 1
+
+ return z
+
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ if not opts.use_old_emphasis_implementation:
+ remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+ batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+
tokens = torch.asarray(remade_batch_tokens).to(device)
- outputs = self.wrapped.transformer(input_ids=tokens)
- z = outputs.last_hidden_state
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+
+ if opts.CLIP_stop_at_last_layers > 1:
+ z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
+ z = self.wrapped.transformer.text_model.final_layer_norm(z)
+ else:
+ z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers = torch.asarray(batch_multipliers).to(device)
+ batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
+ batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
@@ -517,14 +378,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids)
- if batch_fixes is not None:
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
- for offset, word in fixes:
- emb = self.embeddings.word_embeddings[word]
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
+ if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
+ return inputs_embeds
+
+ vecs = []
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
+ for offset, embedding in fixes:
+ emb = embedding.vec
+ emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+
+ vecs.append(tensor)
- return inputs_embeds
+ return torch.stack(vecs)
def add_circular_option_to_conv_2d():
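
The new forward() in this file removes the 77-token cap by splitting a long prompt into consecutive 75-token chunks, encoding each chunk through CLIP separately (BOS/EOS are re-added per chunk in process_tokens), and concatenating the resulting hidden states along the sequence axis. A hedged sketch of the chunk-and-concatenate idea, where encode_chunk is a stand-in for a single CLIP pass and is not a function in this codebase:

import torch

def encode_long_prompt(tokens, encode_chunk, chunk_size=75):
    # tokens: flat list of token ids already padded to a multiple of chunk_size.
    pieces = []
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        pieces.append(encode_chunk(chunk))   # one CLIP pass per chunk
    return torch.cat(pieces, dim=-2)         # concatenate along the sequence axis

# usage with a dummy encoder that mimics the (1, 75+2, dim) output per chunk:
dummy = lambda chunk: torch.zeros(1, len(chunk) + 2, 768)
z = encode_long_prompt(list(range(150)), dummy)
print(z.shape)   # torch.Size([1, 154, 768])
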
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
new file mode 100644
index 00000000..79405525
--- /dev/null
+++ b/modules/sd_hijack_optimizations.py
@@ -0,0 +1,306 @@
+import math
+import sys
+import traceback
+import importlib
+
+import torch
+from torch import einsum
+
+from ldm.util import default
+from einops import rearrange
+
+from modules import shared
+from modules.hypernetworks import hypernetwork
+
+
+if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
+ try:
+ import xformers.ops
+ shared.xformers_available = True
+ except Exception:
+ print("Cannot import xformers", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
+def split_cross_attention_forward_v1(self, x, context=None, mask=None):
+ h = self.heads
+
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ for i in range(0, q.shape[0], 2):
+ end = i + 2
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+
+ s2 = s1.softmax(dim=-1)
+ del s1
+
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+ del q, k, v
+
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
+
+
+# taken from https://github.com/Doggettx/stable-diffusion and modified
+def split_cross_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
+ k_in *= self.scale
+
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ gb = 1024 ** 3
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+ if steps > 64:
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+ f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
+ del s1
+
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+
+ del q, k, v
+
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+
+ return self.to_out(r2)
+
+
+def check_for_psutil():
+ try:
+ spec = importlib.util.find_spec('psutil')
+ return spec is not None
+ except ModuleNotFoundError:
+ return False
+
+invokeAI_mps_available = check_for_psutil()
+
+# -- Taken from https://github.com/invoke-ai/InvokeAI --
+if invokeAI_mps_available:
+ import psutil
+ mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+def einsum_op_compvis(q, k, v):
+ s = einsum('b i d, b j d -> b i j', q, k)
+ s = s.softmax(dim=-1, dtype=s.dtype)
+ return einsum('b i j, b j d -> b i d', s, v)
+
+def einsum_op_slice_0(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[0], slice_size):
+ end = i + slice_size
+ r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+ return r
+
+def einsum_op_slice_1(q, k, v, slice_size):
+ r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
+ return r
+
+def einsum_op_mps_v1(q, k, v):
+ if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ return einsum_op_compvis(q, k, v)
+ else:
+ slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ return einsum_op_slice_1(q, k, v, slice_size)
+
+def einsum_op_mps_v2(q, k, v):
+ if mem_total_gb > 8 and q.shape[1] <= 4096:
+ return einsum_op_compvis(q, k, v)
+ else:
+ return einsum_op_slice_0(q, k, v, 1)
+
+def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
+ size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+ if size_mb <= max_tensor_mb:
+ return einsum_op_compvis(q, k, v)
+ div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+ if div <= q.shape[0]:
+ return einsum_op_slice_0(q, k, v, q.shape[0] // div)
+ return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+def einsum_op_cuda(q, k, v):
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ # Divide factor of safety as there's copying and fragmentation
+ return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+
+def einsum_op(q, k, v):
+ if q.device.type == 'cuda':
+ return einsum_op_cuda(q, k, v)
+
+ if q.device.type == 'mps':
+ if mem_total_gb >= 32:
+ return einsum_op_mps_v1(q, k, v)
+ return einsum_op_mps_v2(q, k, v)
+
+ # Smaller slices are faster due to L2/L3/SLC caches.
+ # Tested on i7 with 8MB L3 cache.
+ return einsum_op_tensor_mem(q, k, v, 32)
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k) * self.scale
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ r = einsum_op(q, k, v)
+ return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
+
+# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
+def xformers_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+def cross_attention_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_)
+ k1 = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q1.shape
+
+ q2 = q1.reshape(b, c, h*w)
+ del q1
+
+ q = q2.permute(0, 2, 1) # b,hw,c
+ del q2
+
+ k = k1.reshape(b, c, h*w) # b,c,hw
+ del k1
+
+ h_ = torch.zeros_like(k, device=q.device)
+
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+ mem_required = tensor_size * 2.5
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w2 = w1 * (int(c)**(-0.5))
+ del w1
+ w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+ del w2
+
+ # attend to values
+ v1 = v.reshape(b, c, h*w)
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ del w3
+
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ del v1, w4
+
+ h2 = h_.reshape(b, c, h, w)
+ del h_
+
+ h3 = self.proj_out(h2)
+ del h2
+
+ h3 += x
+
+ return h3
+
+def xformers_attnblock_forward(self, x):
+ try:
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_).contiguous()
+ k1 = self.k(h_).contiguous()
+ v = self.v(h_).contiguous()
+ out = xformers.ops.memory_efficient_attention(q1, k1, v)
+ out = self.proj_out(out)
+ return x + out
+ except NotImplementedError:
+ return cross_attention_attnblock_forward(self, x)
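
Both split-attention paths in this file size their work by estimating the full attention matrix in bytes and doubling the number of slices until the estimate fits into the free memory they measured. A small standalone sketch of that step calculation under assumed values (not tied to any particular GPU):

import math

def attention_steps(batch_heads, q_len, k_len, element_size, mem_free_total):
    # Same safety modifier as above: 3 for fp16 tensors, 2.5 for fp32.
    tensor_size = batch_heads * q_len * k_len * element_size
    modifier = 3 if element_size == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
    return steps

# e.g. 16 heads, fp16, with roughly 2 GiB free:
print(attention_steps(16, 4096, 4096, 2, 2 * 1024**3))    # -> 1
print(attention_steps(16, 16384, 16384, 2, 2 * 1024**3))  # -> 16
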
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 2539f14c..3aa21ec1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,24 +1,21 @@
-import glob
+import collections
import os.path
import sys
from collections import namedtuple
import torch
from omegaconf import OmegaConf
-
from ldm.util import instantiate_from_config
-from modules import shared, modelloader
+from modules import shared, modelloader, devices
from modules.paths import models_path
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-model_name = "sd-v1-4.ckpt"
-model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
-user_dir = None
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
checkpoints_list = {}
+checkpoints_loaded = collections.OrderedDict()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -30,12 +27,10 @@ except Exception:
pass
-def setup_model(dirname):
- global user_dir
- user_dir = dirname
+def setup_model():
if not os.path.exists(model_path):
os.makedirs(model_path)
- checkpoints_list.clear()
+
list_models()
@@ -45,13 +40,13 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
- model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
+ model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
- if user_dir is not None and abspath.startswith(user_dir):
- name = abspath.replace(user_dir, '')
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
@@ -68,14 +63,20 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
- shared.opts.sd_model_checkpoint = title
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+ shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+
+ basename, _ = os.path.splitext(filename)
+ config = basename + ".yaml"
+ if not os.path.exists(config):
+ config = shared.cmd_opts.config
+
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
def get_closet_checkpoint_match(searchString):
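
list_models above now pairs each .ckpt with a .yaml config sitting next to it when one exists, otherwise falling back to the global --config. A minimal sketch of that lookup; the function name and paths here are illustrative only:

import os

def config_for_checkpoint(ckpt_path: str, default_config: str) -> str:
    # Prefer "<checkpoint>.yaml" located next to the checkpoint file.
    basename, _ = os.path.splitext(ckpt_path)
    candidate = basename + ".yaml"
    return candidate if os.path.exists(candidate) else default_config

print(config_for_checkpoint("models/Stable-diffusion/my-model.ckpt",
                            "configs/stable-diffusion/v1-inference.yaml"))
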
@@ -106,8 +107,11 @@ def select_checkpoint():
if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
+ if shared.cmd_opts.ckpt is not None:
+ print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
+ print(f" - directory {model_path}", file=sys.stderr)
+ if shared.cmd_opts.ckpt_dir is not None:
+ print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
@@ -118,33 +122,72 @@ def select_checkpoint():
return checkpoint_info
-def load_model_weights(model, checkpoint_file, sd_model_hash):
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+def get_state_dict_from_checkpoint(pl_sd):
+ if "state_dict" in pl_sd:
+ return pl_sd["state_dict"]
+
+ return pl_sd
+
+
+def load_model_weights(model, checkpoint_info):
+ checkpoint_file = checkpoint_info.filename
+ sd_model_hash = checkpoint_info.hash
+
+ if checkpoint_info not in checkpoints_loaded:
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
- pl_sd = torch.load(checkpoint_file, map_location="cpu")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd["state_dict"]
+ pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
- model.load_state_dict(sd, strict=False)
+ sd = get_state_dict_from_checkpoint(pl_sd)
+ model.load_state_dict(sd, strict=False)
- if shared.cmd_opts.opt_channelslast:
- model.to(memory_format=torch.channels_last)
+ if shared.cmd_opts.opt_channelslast:
+ model.to(memory_format=torch.channels_last)
- if not shared.cmd_opts.no_half:
- model.half()
+ if not shared.cmd_opts.no_half:
+ model.half()
+
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+
+ vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
+
+ if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
+ vae_file = shared.cmd_opts.vae_path
+
+ if os.path.exists(vae_file):
+ print(f"Loading VAE weights from: {vae_file}")
+ vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+ model.first_stage_model.load_state_dict(vae_dict)
+
+ model.first_stage_model.to(devices.dtype_vae)
+
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False) # LRU
+ else:
+ print(f"Loading weights [{sd_model_hash}] from cache")
+ checkpoints_loaded.move_to_end(checkpoint_info)
+ model.load_state_dict(checkpoints_loaded[checkpoint_info])
model.sd_model_hash = sd_model_hash
- model.sd_model_checkpint = checkpoint_file
+ model.sd_model_checkpoint = checkpoint_file
+ model.sd_checkpoint_info = checkpoint_info
def load_model():
from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint()
- sd_config = OmegaConf.load(shared.cmd_opts.config)
+ if checkpoint_info.config != shared.cmd_opts.config:
+ print(f"Loading config from: {checkpoint_info.config}")
+
+ sd_config = OmegaConf.load(checkpoint_info.config)
sd_model = instantiate_from_config(sd_config.model)
- load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
+ load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
@@ -163,9 +206,14 @@ def reload_model_weights(sd_model, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
- if sd_model.sd_model_checkpint == checkpoint_info.filename:
+ if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
+ if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+ checkpoints_loaded.clear()
+ shared.sd_model = load_model()
+ return shared.sd_model
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
@@ -173,7 +221,7 @@ def reload_model_weights(sd_model, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
+ load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model)
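
load_model_weights above keeps recently loaded state dicts in a collections.OrderedDict and evicts the oldest entry once the cache exceeds opts.sd_checkpoint_cache; move_to_end marks an entry as most recently used. A minimal sketch of that LRU pattern with a made-up cache size standing in for the option:

import collections

cache = collections.OrderedDict()
CACHE_SIZE = 2   # stand-in for opts.sd_checkpoint_cache

def get_or_load(key, load):
    if key in cache:
        cache.move_to_end(key)        # mark as most recently used
        return cache[key]
    value = load(key)
    cache[key] = value
    while len(cache) > CACHE_SIZE:
        cache.popitem(last=False)     # evict the least recently used entry
    return value

for name in ["a", "b", "a", "c"]:
    get_or_load(name, lambda k: f"weights-{k}")
print(list(cache))   # ['a', 'c'] -- 'b' was evicted
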
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 92522214..20309e06 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,37 +7,63 @@ import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser
+from modules import prompt_parser, devices, processing
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases'])
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a']),
- ('Euler', 'sample_euler', ['k_euler']),
- ('LMS', 'sample_lms', ['k_lms']),
- ('Heun', 'sample_heun', ['k_heun']),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2']),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
+ ('Euler', 'sample_euler', ['k_euler'], {}),
+ ('LMS', 'sample_lms', ['k_lms'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
]
samplers_data_k_diffusion = [
- SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
- for label, funcname, aliases in samplers_k_diffusion
+ SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_k_diffusion
if hasattr(k_diffusion.sampling, funcname)
]
-samplers = [
+all_samplers = [
*samplers_data_k_diffusion,
- SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
- SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
+ SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+ SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
]
-samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
+
+samplers = []
+samplers_for_img2img = []
+
+
+def create_sampler_with_index(list_of_configs, index, model):
+ config = list_of_configs[index]
+ sampler = config.constructor(model)
+ sampler.config = config
+
+ return sampler
+
+
+def set_samplers():
+ global samplers, samplers_for_img2img
+
+ hidden = set(opts.hide_samplers)
+ hidden_img2img = set(opts.hide_samplers + ['PLMS'])
+
+ samplers = [x for x in all_samplers if x.name not in hidden]
+ samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+
+
+set_samplers()
sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
@@ -57,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
def sample_to_image(samples):
- x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
+ x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
@@ -77,8 +103,10 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
state.sampling_step = 0
- for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
- if state.interrupted:
+ seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
+
+ for x in seq:
+ if state.interrupted or state.skipped:
break
yield x
@@ -102,14 +130,28 @@ class VanillaStableDiffusionSampler:
self.step = 0
self.eta = None
self.default_eta = 0.0
+ self.config = None
def number_of_needed_noises(self, p):
return 0
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
- cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+ assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+ cond = tensor
+
+ # for DDIM, shapes must match, we can't just process cond and uncond independently;
+ # filling unconditional_conditioning with repeats of the last vector to match length is
+ # not 100% correct but should work well enough
+ if unconditional_conditioning.shape[1] < cond.shape[1]:
+ last_vector = unconditional_conditioning[:, -1:]
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
+ unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
+ elif unconditional_conditioning.shape[1] > cond.shape[1]:
+ unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
+
if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
@@ -125,7 +167,7 @@ class VanillaStableDiffusionSampler:
return res
def initialize(self, p):
- self.eta = p.eta or opts.eta_ddim
+ self.eta = p.eta if p.eta is not None else opts.eta_ddim
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
@@ -139,7 +181,7 @@ class VanillaStableDiffusionSampler:
self.initialize(p)
- # existing code fails with cetain step counts, like 9
+ # existing code fails with certain step counts, like 9
try:
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception:
@@ -162,7 +204,7 @@ class VanillaStableDiffusionSampler:
steps = steps or p.steps
- # existing code fails with cetin step counts, like 9
+ # existing code fails with certain step counts, like 9
try:
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
except Exception:
@@ -181,19 +223,42 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
- cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- if shared.batch_cond_uncond:
- x_in = torch.cat([x] * 2)
- sigma_in = torch.cat([sigma] * 2)
- cond_in = torch.cat([uncond, cond])
- uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
- denoised = uncond + (cond - uncond) * cond_scale
+ batch_size = len(conds_list)
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+
+ if tensor.shape[1] == uncond.shape[1]:
+ cond_in = torch.cat([tensor, uncond])
+
+ if shared.batch_cond_uncond:
+ x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
else:
- uncond = self.inner_model(x, sigma, cond=uncond)
- cond = self.inner_model(x, sigma, cond=cond)
- denoised = uncond + (cond - uncond) * cond_scale
+ x_out = torch.zeros_like(x_in)
+ batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
+ a = batch_offset
+ b = min(a + batch_size, tensor.shape[0])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
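
The reworked CFGDenoiser above generalizes classifier-free guidance to prompts combined with AND: each sub-prompt's denoised output pushes the result away from the unconditional output by its own weight times cfg_scale. A toy numeric sketch of that combination for one image with two weighted sub-prompts (all values made up):

import torch

uncond = torch.tensor([0.0, 0.0])
cond_a = torch.tensor([1.0, 0.0])         # model output for the first AND sub-prompt
cond_b = torch.tensor([0.0, 1.0])         # model output for the second AND sub-prompt
cfg_scale = 7.0
weighted = [(cond_a, 1.0), (cond_b, 0.5)] # per-sub-prompt weights from the prompt

denoised = uncond.clone()
for cond_out, weight in weighted:
    denoised += (cond_out - uncond) * (weight * cfg_scale)

print(denoised)   # tensor([7.0000, 3.5000])
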
@@ -207,8 +272,10 @@ def extended_trange(sampler, count, *args, **kwargs):
state.sampling_steps = count
state.sampling_step = 0
- for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
- if state.interrupted:
+ seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
+
+ for x in seq:
+ if state.interrupted or state.skipped:
break
if sampler.stop_at is not None and x > sampler.stop_at:
@@ -246,6 +313,7 @@ class KDiffusionSampler:
self.stop_at = None
self.eta = None
self.default_eta = 1.0
+ self.config = None
def callback_state(self, d):
store_latent(d["denoised"])
@@ -291,28 +359,43 @@ class KDiffusionSampler:
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
+ sigmas = p.sampler_noise_scheduler_override(steps)
+ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
- sigmas = self.model_wrap.get_sigmas(steps)
-
- noise = noise * sigmas[steps - t_enc - 1]
- xi = x + noise
-
- extra_params_kwargs = self.initialize(p)
+ sigmas = self.model_wrap.get_sigmas(steps)
sigma_sched = sigmas[steps - t_enc - 1:]
+ xi = x + noise * sigma_sched[0]
+
+ extra_params_kwargs = self.initialize(p)
+ if 'sigma_min' in inspect.signature(self.func).parameters:
+ # the last sigma is zero, which isn't allowed by DPM Fast & Adaptive, so take the value before last
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
+ if 'sigma_max' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
+ if 'n' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
+ if 'sigma_sched' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigma_sched'] = sigma_sched
+ if 'sigmas' in inspect.signature(self.func).parameters:
+ extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
- return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+ return self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
+
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
+ sigmas = p.sampler_noise_scheduler_override(steps)
+ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
- sigmas = self.model_wrap.get_sigmas(steps)
+ sigmas = self.model_wrap.get_sigmas(steps)
+
x = x * sigmas[0]
extra_params_kwargs = self.initialize(p)
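
The img2img sampling path above only passes the sigma-related keyword arguments that the selected k-diffusion function actually accepts, by inspecting its signature. A small sketch of that pattern with hypothetical sampler functions (foo/bar are not real k-diffusion names):

import inspect

def foo(model, x, sigmas, callback=None):      # hypothetical sampler taking a full schedule
    return "foo got sigmas"

def bar(model, x, sigma_min, sigma_max, n):    # hypothetical sampler taking a sigma range
    return "bar got a sigma range"

def call_sampler(func, model, x, sigma_sched):
    params = inspect.signature(func).parameters
    kwargs = {}
    if 'sigma_min' in params:
        kwargs['sigma_min'] = sigma_sched[-2]  # last sigma is zero, use the one before it
    if 'sigma_max' in params:
        kwargs['sigma_max'] = sigma_sched[0]
    if 'n' in params:
        kwargs['n'] = len(sigma_sched) - 1
    if 'sigmas' in params:
        kwargs['sigmas'] = sigma_sched
    return func(model, x, **kwargs)

sigma_sched = [10.0, 5.0, 1.0, 0.0]
print(call_sampler(foo, None, None, sigma_sched))
print(call_sampler(bar, None, None, sigma_sched))
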
diff --git a/modules/shared.py b/modules/shared.py
index ac968b2d..aa69bedf 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -12,12 +12,13 @@ import modules.interrogate
import modules.memmon
import modules.sd_models
import modules.styles
-from modules.devices import get_optimal_device
-from modules.paths import script_path, sd_path
+import modules.devices as devices
+from modules import sd_samplers, sd_models
+from modules.hypernetworks import hypernetwork
+from modules.paths import models_path, script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
-model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
@@ -25,26 +26,36 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
+parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
+parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
-parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
-parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
-parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
-parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
-parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
-parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
-parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
+parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
+parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
+parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
+parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
+parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
+parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
+parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
+parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
+parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
+parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
+parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -53,21 +64,44 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
+parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
+parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
+parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
+
cmd_opts = parser.parse_args()
-device = get_optimal_device()
+
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
+
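# Sketch (not part of the patch): the tuple-unpacked generator above is a compact way of
# choosing one torch device per module based on --use-cpu. Roughly equivalent, using a
# hypothetical helper name:
#
#   def pick_device(module_name):
#       if 'all' in cmd_opts.use_cpu or module_name in cmd_opts.use_cpu:
#           return devices.cpu
#       return devices.get_optimal_device()
#
#   devices.device_esrgan = pick_device('esrgan')   # and likewise for each name in the list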
+device = devices.device
+weight_load_location = None if cmd_opts.lowram else "cpu"
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
-
+xformers_available = False
config_filename = cmd_opts.ui_settings_file
+os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
+hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+loaded_hypernetwork = None
+
+
+def reload_hypernetworks():
+ global hypernetworks
+
+ hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+ hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+
class State:
+ skipped = False
interrupted = False
job = ""
job_no = 0
@@ -78,6 +112,10 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
+ textinfo = None
+
+ def skip(self):
+ self.skipped = True
def interrupt(self):
self.interrupted = True
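# Sketch (illustrative only, not part of the patch): a generation loop is expected to poll
# these flags roughly like below -- `skipped` drops only the current image and is reset,
# while `interrupted` ends the whole job. The loop itself is hypothetical.
#
#   for n in range(job_count):
#       if state.interrupted:
#           break
#       if state.skipped:
#           state.skipped = False
#           continue
#       ...  # render one image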
@@ -88,7 +126,7 @@ class State:
self.current_image_sampling_step = 0
def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
state = State()
@@ -101,8 +139,6 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
-# This was moved to webui.py with the other model "setup" calls.
-# modules.sd_models.list_models()
def realesrgan_models_names():
@@ -111,18 +147,19 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = None
+ self.refresh = refresh
-def options_section(section_identifer, options_dict):
+def options_section(section_identifier, options_dict):
for k, v in options_dict.items():
- v.section = section_identifer
+ v.section = section_identifier
return options_dict
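# Sketch (not part of the patch): options_section only tags every OptionInfo in a dict with
# the same section tuple before the dict is merged into options_templates. A hypothetical
# example:
#
#   example = options_section(('demo', "Demo"), {
#       "demo_flag": OptionInfo(False, "An example checkbox"),
#   })
#   assert example["demo_flag"].section == ('demo', "Demo")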
@@ -140,6 +177,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
+ "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
@@ -150,6 +188,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
+ "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@@ -165,9 +204,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
- "grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"),
+ "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
+ "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
"directories_filename_pattern": OptionInfo("", "Directory name pattern"),
- "directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
+ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
@@ -177,7 +217,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
- "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
+ "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -189,50 +229,75 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
options_templates.update(options_section(('system', "System"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
- "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
+ "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
+}))
+
+options_templates.update(options_section(('training', "Training"), {
+ "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+ "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
+ "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+ "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
+ "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
- "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
+ "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
+ "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
+ "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
+    'CLIP_stop_at_last_layers': OptionInfo(1, "Stop at last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+ 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
+ "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
- "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
+ "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
+ "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"return_grid": OptionInfo(True, "Show grid in results for web"),
+ "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
+ "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
- "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
+ "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
+ "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
- "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
- 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
+ "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
+ 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
+
class Options:
data = None
data_labels = options_templates
@@ -289,6 +354,8 @@ class Options:
item = self.data_labels.get(key)
item.onchange = func
+ func()
+
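# Sketch (hedged, not part of the patch): with func() now invoked at registration time,
# an onchange handler also runs once immediately, e.g.
#
#   opts.onchange("sd_hypernetwork",
#                 lambda: hypernetwork.load_hypernetwork(opts.sd_hypernetwork))
#
# loads the currently selected hypernetwork right away instead of waiting for the first
# settings change.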
def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
@@ -318,14 +385,14 @@ class TotalTQDM:
)
def update(self):
- if not opts.multiple_tqdm:
+ if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.update()
def updateTotal(self, new_total):
- if not opts.multiple_tqdm:
+ if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index 41fda5a7..baa02e3d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -5,11 +5,12 @@ import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
+from tqdm import tqdm
from modules import modelloader
-from modules.paths import models_path
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
+from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
precision_scope = (
@@ -24,7 +25,6 @@ class UpscalerSwinIR(Upscaler):
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth "
self.model_name = "SwinIR 4x"
- self.model_path = os.path.join(models_path, self.name)
self.user_path = dirname
super().__init__()
scalers = []
@@ -58,22 +58,42 @@ class UpscalerSwinIR(Upscaler):
filename = path
if filename is None or not os.path.exists(filename):
return None
- model = net(
+ if filename.endswith(".v2.pth"):
+ model = net2(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler="nearest+conv",
- resi_connection="3conv",
- )
+ resi_connection="1conv",
+ )
+ params = None
+ else:
+ model = net(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
+ embed_dim=240,
+ num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="3conv",
+ )
+ params = "params_ema"
pretrained_model = torch.load(filename)
- model.load_state_dict(pretrained_model["params_ema"], strict=True)
+ if params is not None:
+ model.load_state_dict(pretrained_model[params], strict=True)
+ else:
+ model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
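# Sketch (not part of the patch): the change above keys both the architecture and the
# checkpoint layout off the filename suffix, roughly:
#
#   if filename.endswith(".v2.pth"):
#       model, params = net2(...), None           # Swin2SR checkpoints store weights flat
#   else:
#       model, params = net(...), "params_ema"    # classic SwinIR wraps them in "params_ema"
#   state = torch.load(filename)
#   model.load_state_dict(state if params is None else state[params], strict=True)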
@@ -122,18 +142,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device)
- for h_idx in h_idx_list:
- for w_idx in w_idx_list:
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
+ with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
+ for h_idx in h_idx_list:
+ for w_idx in w_idx_list:
+ in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
+ out_patch = model(in_patch)
+ out_patch_mask = torch.ones_like(out_patch)
+
+ E[
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+ ].add_(out_patch)
+ W[
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
+ ].add_(out_patch_mask)
+ pbar.update(1)
output = E.div_(W)
return output
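# Sketch (not part of the patch): E accumulates the upscaled patches and W counts how many
# tiles cover each output pixel, so E.div_(W) averages the overlapping regions:
#
#   output[px] = sum(model(tile)[px] for every tile covering px) / number_of_covering_tiles
#
# The added tqdm bar simply advances once per (h_idx, w_idx) tile.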
diff --git a/modules/swinir_model_arch.py b/modules/swinir_model_arch.py
index 461fb354..863f42db 100644
--- a/modules/swinir_model_arch.py
+++ b/modules/swinir_model_arch.py
@@ -166,7 +166,7 @@ class SwinTransformerBlock(nn.Module):
Args:
dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resulotion.
+ input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
diff --git a/modules/swinir_model_arch_v2.py b/modules/swinir_model_arch_v2.py
new file mode 100644
index 00000000..0e28ae6e
--- /dev/null
+++ b/modules/swinir_model_arch_v2.py
@@ -0,0 +1,1017 @@
+# -----------------------------------------------------------------------------------
+# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
+# Written by Conde and Choi et al.
+# -----------------------------------------------------------------------------------
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
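# Sketch (not part of the patch): window_partition and window_reverse are exact inverses
# whenever H and W are multiples of window_size, since both are pure view/permute ops:
#
#   import torch
#   from modules.swinir_model_arch_v2 import window_partition, window_reverse
#
#   x = torch.randn(2, 16, 16, 32)                       # B, H, W, C
#   w = window_partition(x, 8)                           # -> (8, 8, 8, 32): 4 windows per image, batch of 2
#   assert torch.equal(window_reverse(w, 8, 16, 16), x)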
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
+ pretrained_window_size=[0, 0]):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
+
+ # mlp to generate continuous relative position bias
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False))
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+ relative_coords_table = torch.stack(
+ torch.meshgrid([relative_coords_h,
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
+ else:
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
+ attn = attn * logit_scale
+
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, ' \
+ f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
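# Note (not part of the patch): this is SwinV2-style cosine attention -- q and k are
# L2-normalised so q @ k.T lies in [-1, 1], then scaled by a learned per-head temperature
# exp(logit_scale) clamped to at most 1/0.01 = 100, with the continuous relative-position
# bias from cpb_mlp added before softmax.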
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pretrained_window_size (int): Window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ pretrained_window_size=to_2tuple(pretrained_window_size))
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
+ H, W = x_size
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ H, W = x_size
+ B, L, C = x.shape
+ #assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are multiples of the window size)
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.reduction(x)
+ x = self.norm(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ flops += H * W * self.dim // 2
+ return flops
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ pretrained_window_size=0):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pretrained_window_size=pretrained_window_size)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+ def _init_respostnorm(self):
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, 0)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, 0)
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ # assert H == self.img_size[0] and W == self.img_size[1],
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ img_size=224, patch_size=4, resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
+ norm_layer=None)
+
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ H, W = self.input_resolution
+ flops += H * W * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ B, HW, C = x.shape
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
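# Sketch (not part of the patch): (scale & (scale - 1)) == 0 is the standard power-of-two
# test, so an x4 upsampler expands to two PixelShuffle(2) stages:
#
#   up = Upsample(scale=4, num_feat=64)
#   # -> [Conv2d(64, 256), PixelShuffle(2)] repeated twice (log2(4) = 2 stages)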
+class Upsample_hf(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample_hf, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.num_feat * 3 * 9
+ return flops
+
+
+
+class Swin2SR(nn.Module):
+ r""" Swin2SR
+    A PyTorch impl of: `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
+ img_range: Image range. 1. or 255.
+        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
+ **kwargs):
+ super(Swin2SR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+ self.window_size = window_size
+
+ #####################################################################################################
+ ################################### 1, shallow feature extraction ###################################
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ #####################################################################################################
+ ################################### 2, deep feature extraction ######################################
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers.append(layer)
+
+ if self.upsampler == 'pixelshuffle_hf':
+ self.layers_hf = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers_hf.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ #####################################################################################################
+ ################################ 3, high quality image reconstruction ################################
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffle_aux':
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_after_aux = nn.Sequential(
+ nn.Conv2d(3, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffle_hf':
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ self.conv_before_upsample_hf = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ assert self.upscale == 4, 'only support x4 now.'
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def check_image_size(self, x):
+ _, _, h, w = x.size()
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ return x
+
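# Sketch (not part of the patch): check_image_size reflect-pads H and W up to the next
# multiple of window_size so window partitioning always divides evenly, e.g. with
# window_size=8:
#
#   x = torch.randn(1, 3, 70, 70)
#   # mod_pad = (8 - 70 % 8) % 8 = 2  ->  padded to (1, 3, 72, 72)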
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward_features_hf(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers_hf:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ H, W = x.shape[2:]
+ x = self.check_image_size(x)
+
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffle_aux':
+ bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
+ bicubic = self.conv_bicubic(bicubic)
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
+ x = self.conv_after_aux(aux)
+ x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
+ x = self.conv_last(x)
+ aux = aux / self.img_range + self.mean
+ elif self.upsampler == 'pixelshuffle_hf':
+ # for classical SR with HF
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x_before = self.conv_before_upsample(x)
+ x_out = self.conv_last(self.upsample(x_before))
+
+ x_hf = self.conv_first_hf(x_before)
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
+ x_hf = self.conv_before_upsample_hf(x_hf)
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
+ x = x_out + x_hf
+ x_hf = x_hf / self.img_range + self.mean
+
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+ if self.upsampler == "pixelshuffle_aux":
+ return x[:, :, :H*self.upscale, :W*self.upscale], aux
+
+ elif self.upsampler == "pixelshuffle_hf":
+ x_out = x_out / self.img_range + self.mean
+ return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
+
+ else:
+ return x[:, :, :H*self.upscale, :W*self.upscale]
+
+ def flops(self):
+ flops = 0
+ H, W = self.patches_resolution
+ flops += H * W * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = Swin2SR(upscale=2, img_size=(height, width),
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape)
\ No newline at end of file
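Note on the Swin2SR forward pass above: check_image_size() reflect-pads the input up to a multiple of window_size before the transformer layers run, and forward() crops the result back to H*upscale x W*upscale at the end. A minimal sketch of just that padding arithmetic, using illustrative sizes rather than anything taken from this patch:

import torch
import torch.nn.functional as F

window_size = 8
x = torch.randn(1, 3, 61, 45)  # height and width deliberately not multiples of window_size
_, _, h, w = x.size()
mod_pad_h = (window_size - h % window_size) % window_size
mod_pad_w = (window_size - w % window_size) % window_size
padded = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
assert padded.shape[2] % window_size == 0 and padded.shape[3] % window_size == 0
# the model later crops its output back to (h * upscale, w * upscale), discarding the padded rows/columns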
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
new file mode 100644
index 00000000..23bb4b6a
--- /dev/null
+++ b/modules/textual_inversion/dataset.py
@@ -0,0 +1,121 @@
+import os
+import numpy as np
+import PIL
+import torch
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+import random
+import tqdm
+from modules import devices, shared
+import re
+
+re_numbers_at_start = re.compile(r"^[-\d]+\s*")
+
+
+class DatasetEntry:
+ def __init__(self, filename=None, latent=None, filename_text=None):
+ self.filename = filename
+ self.latent = latent
+ self.filename_text = filename_text
+ self.cond = None
+ self.cond_text = None
+
+
+class PersonalizedBase(Dataset):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
+
+ self.placeholder_token = placeholder_token
+
+ self.batch_size = batch_size
+ self.width = width
+ self.height = height
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+ self.dataset = []
+
+ with open(template_file, "r") as file:
+ lines = [x.strip() for x in file.readlines()]
+
+ self.lines = lines
+
+ assert data_root, 'dataset directory not specified'
+
+ cond_model = shared.sd_model.cond_stage_model
+
+ self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+ print("Preparing dataset...")
+ for path in tqdm.tqdm(self.image_paths):
+ try:
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ except Exception:
+ continue
+
+ text_filename = os.path.splitext(path)[0] + ".txt"
+ filename = os.path.basename(path)
+
+ if os.path.exists(text_filename):
+ with open(text_filename, "r", encoding="utf8") as file:
+ filename_text = file.read()
+ else:
+ filename_text = os.path.splitext(filename)[0]
+ filename_text = re.sub(re_numbers_at_start, '', filename_text)
+ if re_word:
+ tokens = re_word.findall(filename_text)
+ filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
+
+ npimage = np.array(image).astype(np.uint8)
+ npimage = (npimage / 127.5 - 1.0).astype(np.float32)
+
+ torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
+ torchdata = torch.moveaxis(torchdata, 2, 0)
+
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+ init_latent = init_latent.to(devices.cpu)
+
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
+
+ if include_cond:
+ entry.cond_text = self.create_text(filename_text)
+ entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
+
+ self.dataset.append(entry)
+
+ assert len(self.dataset) > 1, "The dataset needs at least two images."
+ self.length = len(self.dataset) * repeats // batch_size
+
+ self.initial_indexes = np.arange(len(self.dataset))
+ self.indexes = None
+ self.shuffle()
+
+ def shuffle(self):
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+
+ def create_text(self, filename_text):
+ text = random.choice(self.lines)
+ text = text.replace("[name]", self.placeholder_token)
+ text = text.replace("[filewords]", filename_text)
+ return text
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, i):
+ res = []
+
+ for j in range(self.batch_size):
+ position = i * self.batch_size + j
+ if position % len(self.indexes) == 0:
+ self.shuffle()
+
+ index = self.indexes[position % len(self.indexes)]
+ entry = self.dataset[index]
+
+ if entry.cond is None:
+ entry.cond_text = self.create_text(entry.filename_text)
+
+ res.append(entry)
+
+ return res
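For context on how dataset.py turns filenames into training prompts: leading numbers are stripped from the filename, the optional dataset_filename_word_regex/dataset_filename_join_string options rewrite it, and create_text() substitutes the result into a random template line via the [name] and [filewords] placeholders. A small standalone sketch of that flow; the template line and filename here are made up for illustration:

import re

re_numbers_at_start = re.compile(r"^[-\d]+\s*")

filename = "00012-a photo of a cat.png"               # hypothetical dataset file
filename_text = re.sub(re_numbers_at_start, '', filename.rsplit('.', 1)[0])

template_line = "a rendering of [name], [filewords]"  # hypothetical line from the template file
prompt = template_line.replace("[name]", "my-embedding").replace("[filewords]", filename_text)
print(prompt)  # a rendering of my-embedding, a photo of a cat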
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
new file mode 100644
index 00000000..898ce3b3
--- /dev/null
+++ b/modules/textual_inversion/image_embedding.py
@@ -0,0 +1,219 @@
+import base64
+import json
+import numpy as np
+import zlib
+from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
+from fonts.ttf import Roboto
+import torch
+from modules.shared import opts
+
+
+class EmbeddingEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, torch.Tensor):
+ return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
+ return json.JSONEncoder.default(self, obj)
+
+
+class EmbeddingDecoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+
+ def object_hook(self, d):
+ if 'TORCHTENSOR' in d:
+ return torch.from_numpy(np.array(d['TORCHTENSOR']))
+ return d
+
+
+def embedding_to_b64(data):
+ d = json.dumps(data, cls=EmbeddingEncoder)
+ return base64.b64encode(d.encode())
+
+
+def embedding_from_b64(data):
+ d = base64.b64decode(data)
+ return json.loads(d, cls=EmbeddingDecoder)
+
+
+def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
+ while True:
+ seed = (a * seed + c) % m
+ yield seed % 255
+
+
+def xor_block(block):
+ g = lcg()
+ randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
+ return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
+
+
+def style_block(block, sequence):
+ im = Image.new('RGB', (block.shape[1], block.shape[0]))
+ draw = ImageDraw.Draw(im)
+ i = 0
+ for x in range(-6, im.size[0], 8):
+ for yi, y in enumerate(range(-6, im.size[1], 8)):
+ offset = 0
+ if yi % 2 == 0:
+ offset = 4
+ shade = sequence[i % len(sequence)]
+ i += 1
+ draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
+
+ fg = np.array(im).astype(np.uint8) & 0xF0
+
+ return block ^ fg
+
+
+def insert_image_data_embed(image, data):
+ d = 3
+ data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
+ data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
+ data_np_high = data_np_ >> 4
+ data_np_low = data_np_ & 0x0F
+
+ h = image.size[1]
+ next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
+ next_size = next_size + ((h*d)-(next_size % (h*d)))
+
+ data_np_low.resize(next_size)
+ data_np_low = data_np_low.reshape((h, -1, d))
+
+ data_np_high.resize(next_size)
+ data_np_high = data_np_high.reshape((h, -1, d))
+
+ edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
+ edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
+
+ data_np_low = style_block(data_np_low, sequence=edge_style)
+ data_np_low = xor_block(data_np_low)
+ data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
+ data_np_high = xor_block(data_np_high)
+
+ im_low = Image.fromarray(data_np_low, mode='RGB')
+ im_high = Image.fromarray(data_np_high, mode='RGB')
+
+ background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
+ background.paste(im_low, (0, 0))
+ background.paste(image, (im_low.size[0]+1, 0))
+ background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
+
+ return background
+
+
+def crop_black(img, tol=0):
+ mask = (img > tol).all(2)
+ mask0, mask1 = mask.any(0), mask.any(1)
+ col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
+ row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
+ return img[row_start:row_end, col_start:col_end]
+
+
+def extract_image_data_embed(image):
+ d = 3
+ outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
+ black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
+ if black_cols[0].shape[0] < 2:
+ print('No image data blocks found.')
+ return None
+
+ data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
+ data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
+
+ data_block_lower = xor_block(data_block_lower)
+ data_block_upper = xor_block(data_block_upper)
+
+ data_block = (data_block_upper << 4) | (data_block_lower)
+ data_block = data_block.flatten().tobytes()
+
+ data = zlib.decompress(data_block)
+ return json.loads(data, cls=EmbeddingDecoder)
+
+
+def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
+ from math import cos
+
+ image = srcimage.copy()
+
+ if textfont is None:
+ try:
+ textfont = opts.font or Roboto
+ except Exception:
+ textfont = Roboto
+
+ factor = 1.5
+ gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
+ for y in range(image.size[1]):
+ mag = 1-cos(y/image.size[1]*factor)
+ mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
+ gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
+ image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
+
+ draw = ImageDraw.Draw(image)
+ fontsize = 32
+ font = ImageFont.truetype(textfont, fontsize)
+ padding = 10
+
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
+ font = ImageFont.truetype(textfont, fontsize)
+ _, _, w, h = draw.textbbox((0, 0), title, font=font)
+ draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
+
+ _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
+ fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
+ fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+ _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
+ fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
+
+ font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
+
+ draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
+ draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
+
+ return image
+
+
+if __name__ == '__main__':
+
+ testEmbed = Image.open('test_embedding.png')
+ data = extract_image_data_embed(testEmbed)
+ assert data is not None
+
+ data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
+ assert data is not None
+
+ image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
+ cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
+
+ test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
+
+ embedded_image = insert_image_data_embed(cap_image, test_embed)
+
+ retrieved_embed = extract_image_data_embed(embedded_image)
+
+ assert str(retrieved_embed) == str(test_embed)
+
+ embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)
+
+ assert embedded_image == embedded_image2
+
+ g = lcg()
+ shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
+
+ reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
+ 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
+ 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
+ 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
+ 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
+ 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
+ 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
+ 204, 86, 73, 222, 44, 198, 118, 240, 97]
+
+ assert shared_random == reference_random
+
+ hundred_k_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
+
+ assert 12731374 == hundred_k_random_sum
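The embedded payload is obfuscated rather than encrypted: xor_block() XORs each low nibble with a keystream from the fixed-seed LCG above, so applying it a second time recovers the original block. A self-contained sketch of that round trip (the helpers are repeated here so the snippet runs on its own):

import numpy as np

def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    while True:
        seed = (a * seed + c) % m
        yield seed % 255

def xor_block(block):
    g = lcg()
    randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)

payload = np.random.randint(0, 16, size=(4, 4, 3), dtype=np.uint8)   # 4-bit nibbles, as in the real payload
assert np.array_equal(payload, xor_block(xor_block(payload)))        # XOR with the same keystream is its own inverse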
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..2062726a
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,69 @@
+import tqdm
+
+
+class LearnScheduleIterator:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until step 1000, and 1e-5 until step 10000
+ """
+
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number <= self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+
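The learn_rate string accepted by LearnScheduleIterator is a comma-separated list of rate:step pairs. A short usage sketch of how such a string expands into (rate, end_step) pairs, assuming the import path introduced by this patch:

from modules.textual_inversion.learn_schedule import LearnScheduleIterator

# illustrative schedule: 1e-3 until step 100, then 1e-5 until step 1000
schedule = LearnScheduleIterator("0.001:100, 0.00001:1000", max_steps=10000)
print(list(schedule))  # [(0.001, 100), (1e-05, 1000)]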
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
new file mode 100644
index 00000000..886cf0c3
--- /dev/null
+++ b/modules/textual_inversion/preprocess.py
@@ -0,0 +1,116 @@
+import os
+from PIL import Image, ImageOps
+import platform
+import sys
+import tqdm
+import time
+
+from modules import shared, images
+from modules.shared import opts, cmd_opts
+if cmd_opts.deepdanbooru:
+ import modules.deepbooru as deepbooru
+
+
+def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+ try:
+ if process_caption:
+ shared.interrogator.load()
+
+ if process_caption_deepbooru:
+ db_opts = deepbooru.create_deepbooru_opts()
+ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
+ deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
+
+ preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+
+ finally:
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
+
+ if process_caption_deepbooru:
+ deepbooru.release_process()
+
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+ width = process_width
+ height = process_height
+ src = os.path.abspath(process_src)
+ dst = os.path.abspath(process_dst)
+
+ assert src != dst, 'same directory specified as source and destination'
+
+ os.makedirs(dst, exist_ok=True)
+
+ files = os.listdir(src)
+
+ shared.state.textinfo = "Preprocessing..."
+ shared.state.job_count = len(files)
+
+ def save_pic_with_caption(image, index):
+ caption = ""
+
+ if process_caption:
+ caption += shared.interrogator.generate_caption(image)
+
+ if process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.get_tags_from_process(image)
+
+ filename_part = filename
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{subindex[0]}-{filename_part}"
+ image.save(os.path.join(dst, f"{basename}.png"))
+
+ if len(caption) > 0:
+ with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
+
+ subindex[0] += 1
+
+ def save_pic(image, index):
+ save_pic_with_caption(image, index)
+
+ if process_flip:
+ save_pic_with_caption(ImageOps.mirror(image), index)
+
+ for index, imagefile in enumerate(tqdm.tqdm(files)):
+ subindex = [0]
+ filename = os.path.join(src, imagefile)
+ try:
+ img = Image.open(filename).convert("RGB")
+ except Exception:
+ continue
+
+ if shared.state.interrupted:
+ break
+
+ ratio = img.height / img.width
+ is_tall = ratio > 1.35
+ is_wide = ratio < 1 / 1.35
+
+ if process_split and is_tall:
+ img = img.resize((width, height * img.height // img.width))
+
+ top = img.crop((0, 0, width, height))
+ save_pic(top, index)
+
+ bot = img.crop((0, img.height - height, width, img.height))
+ save_pic(bot, index)
+ elif process_split and is_wide:
+ img = img.resize((width * img.width // img.height, height))
+
+ left = img.crop((0, 0, width, height))
+ save_pic(left, index)
+
+ right = img.crop((img.width - width, 0, img.width, height))
+ save_pic(right, index)
+ else:
+ img = images.resize_image(1, img, width, height)
+ save_pic(img, index)
+
+ shared.state.nextjob()
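preprocess_work() only splits an image into two crops when its aspect ratio is more extreme than 1.35 (or its reciprocal); everything else is resized whole. A tiny sketch of just that decision, with a hypothetical helper name and illustrative sizes:

def split_plan(img_w, img_h, threshold=1.35):
    ratio = img_h / img_w
    if ratio > threshold:
        return "split into top and bottom crops"
    if ratio < 1 / threshold:
        return "split into left and right crops"
    return "single resized image"

print(split_plan(512, 1024))  # split into top and bottom crops
print(split_plan(1024, 512))  # split into left and right crops
print(split_plan(640, 640))   # single resized image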
diff --git a/modules/textual_inversion/test_embedding.png b/modules/textual_inversion/test_embedding.png
new file mode 100644
index 00000000..07e2d9af
--- /dev/null
+++ b/modules/textual_inversion/test_embedding.png
Binary files differ
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
new file mode 100644
index 00000000..e754747e
--- /dev/null
+++ b/modules/textual_inversion/textual_inversion.py
@@ -0,0 +1,363 @@
+import os
+import sys
+import traceback
+
+import torch
+import tqdm
+import html
+import datetime
+import csv
+
+from PIL import Image, PngImagePlugin
+
+from modules import shared, devices, sd_hijack, processing, sd_models
+import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+
+from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
+ insert_image_data_embed, extract_image_data_embed,
+ caption_image_overlay)
+
+class Embedding:
+ def __init__(self, vec, name, step=None):
+ self.vec = vec
+ self.name = name
+ self.step = step
+ self.cached_checksum = None
+ self.sd_checkpoint = None
+ self.sd_checkpoint_name = None
+
+ def save(self, filename):
+ embedding_data = {
+ "string_to_token": {"*": 265},
+ "string_to_param": {"*": self.vec},
+ "name": self.name,
+ "step": self.step,
+ "sd_checkpoint": self.sd_checkpoint,
+ "sd_checkpoint_name": self.sd_checkpoint_name,
+ }
+
+ torch.save(embedding_data, filename)
+
+ def checksum(self):
+ if self.cached_checksum is not None:
+ return self.cached_checksum
+
+ def const_hash(a):
+ r = 0
+ for v in a:
+ r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+ return r
+
+ self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
+ return self.cached_checksum
+
+
+class EmbeddingDatabase:
+ def __init__(self, embeddings_dir):
+ self.ids_lookup = {}
+ self.word_embeddings = {}
+ self.dir_mtime = None
+ self.embeddings_dir = embeddings_dir
+
+ def register_embedding(self, embedding, model):
+
+ self.word_embeddings[embedding.name] = embedding
+
+ ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+
+ first_id = ids[0]
+ if first_id not in self.ids_lookup:
+ self.ids_lookup[first_id] = []
+
+ self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
+
+ return embedding
+
+ def load_textual_inversion_embeddings(self):
+ mt = os.path.getmtime(self.embeddings_dir)
+ if self.dir_mtime is not None and mt <= self.dir_mtime:
+ return
+
+ self.dir_mtime = mt
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+
+ def process_file(path, filename):
+ name = os.path.splitext(filename)[0]
+
+ data = []
+
+ if filename.upper().endswith('.PNG'):
+ embed_image = Image.open(path)
+ if 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
+ else:
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ else:
+ data = torch.load(path, map_location="cpu")
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('hash', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ self.register_embedding(embedding, shared.sd_model)
+
+ for fn in os.listdir(self.embeddings_dir):
+ try:
+ fullfn = os.path.join(self.embeddings_dir, fn)
+
+ if os.stat(fullfn).st_size == 0:
+ continue
+
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
+
+ print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
+ def find_embedding_at_position(self, tokens, offset):
+ token = tokens[offset]
+ possible_matches = self.ids_lookup.get(token, None)
+
+ if possible_matches is None:
+ return None, None
+
+ for ids, embedding in possible_matches:
+ if tokens[offset:offset + len(ids)] == ids:
+ return embedding, len(ids)
+
+ return None, None
+
+
+def create_embedding(name, num_vectors_per_token, init_text='*'):
+ cond_model = shared.sd_model.cond_stage_model
+ embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+
+ ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+ vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
+
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+
+ fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ embedding = Embedding(vec, name)
+ embedding.step = 0
+ embedding.save(fn)
+
+ return fn
+
+
+def write_loss(log_directory, filename, step, epoch_len, values):
+ if shared.opts.training_write_csv_every == 0:
+ return
+
+ if step % shared.opts.training_write_csv_every != 0:
+ return
+
+ write_csv_header = not os.path.exists(os.path.join(log_directory, filename))
+
+ with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ epoch = step // epoch_len
+ epoch_step = step - epoch * epoch_len
+
+ csv_writer.writerow({
+ "step": step + 1,
+ "epoch": epoch + 1,
+ "epoch_step": epoch_step + 1,
+ **values,
+ })
+
+
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ assert embedding_name, 'embedding not selected'
+
+ shared.state.textinfo = "Initializing textual inversion training..."
+ shared.state.job_count = steps
+
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+
+ log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+
+ if save_embedding_every > 0:
+ embedding_dir = os.path.join(log_directory, "embeddings")
+ os.makedirs(embedding_dir, exist_ok=True)
+ else:
+ embedding_dir = None
+
+ if create_image_every > 0:
+ images_dir = os.path.join(log_directory, "images")
+ os.makedirs(images_dir, exist_ok=True)
+ else:
+ images_dir = None
+
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
+ cond_model = shared.sd_model.cond_stage_model
+
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ with torch.autocast("cuda"):
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+
+ hijack = sd_hijack.model_hijack
+
+ embedding = hijack.embedding_db.word_embeddings[embedding_name]
+ embedding.vec.requires_grad = True
+
+ losses = torch.zeros((32,))
+
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+
+ initial_step = embedding.step or 0
+ if initial_step > steps:
+ return embedding, filename
+
+ scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
+
+ pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
+ for i, entries in pbar:
+ embedding.step = i + initial_step
+
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+
+ if shared.state.interrupted:
+ break
+
+ with torch.autocast("cuda"):
+ c = cond_model([entry.cond_text for entry in entries])
+ x = torch.stack([entry.latent for entry in entries]).to(devices.device)
+ loss = shared.sd_model(x, c)[0]
+ del x
+
+ losses[embedding.step % losses.shape[0]] = loss.item()
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step - (epoch_num * len(ds)) + 1
+
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
+
+ if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+ embedding.save(last_saved_file)
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
+ last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_index = preview_sampler_index
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = entries[0].cond_text
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0]
+
+ shared.state.current_image = image
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+ title = "<{}>".format(data.get('name', '???'))
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}'.format(embedding.step)
+
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+
+ image.save(last_saved_image)
+
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
+<p>
+Loss: {losses.mean():.7f}<br/>
+Step: {embedding.step}<br/>
+Last prompt: {html.escape(entries[0].cond_text)}<br/>
+Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved image: {html.escape(last_saved_image)}<br/>
+</p>
+"""
+
+ checkpoint = sd_models.select_checkpoint()
+
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ embedding.cached_checksum = None
+ embedding.save(filename)
+
+ return embedding, filename
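Embedding.save() above writes a plain torch pickle whose keys mirror the original textual-inversion format, so a saved file can be inspected directly. A minimal sketch; the file path is illustrative:

import torch

data = torch.load("embeddings/my-embedding.pt", map_location="cpu")  # hypothetical path
vec = data["string_to_param"]["*"]
print(data["name"], data["step"], data["sd_checkpoint_name"], tuple(vec.shape))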
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
new file mode 100644
index 00000000..36881e7a
--- /dev/null
+++ b/modules/textual_inversion/ui.py
@@ -0,0 +1,42 @@
+import html
+
+import gradio as gr
+
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
+from modules import sd_hijack, shared
+
+
+def create_embedding(name, initialization_text, nvpt):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+ return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+
+
+def preprocess(*args):
+ modules.textual_inversion.preprocess.preprocess(*args)
+
+ return "Preprocessing finished.", ""
+
+
+def train_embedding(*args):
+
+ assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
+
+ try:
+ sd_hijack.undo_optimizations()
+
+ embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
+
+ res = f"""
+Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
+Embedding saved to {html.escape(filename)}
+"""
+ return res, ""
+ except Exception:
+ raise
+ finally:
+ sd_hijack.apply_optimizations()
+
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 5368e4d0..2381347f 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -30,11 +30,14 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
restore_faces=restore_faces,
tiling=tiling,
enable_hr=enable_hr,
- scale_latent=scale_latent if enable_hr else None,
denoising_strength=denoising_strength if enable_hr else None,
+ firstphase_width=firstphase_width if enable_hr else None,
+ firstphase_height=firstphase_height if enable_hr else None,
)
- print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
+ if cmd_opts.enable_console_prompts:
+ print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
+
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
@@ -46,5 +49,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.samples_log_stdout:
print(generation_info_js)
+ if opts.do_not_show_images:
+ processed.images = []
+
return processed.images, generation_info_js, plaintext_to_html(processed.info)
diff --git a/modules/ui.py b/modules/ui.py
index 15572bb0..baf4c397 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -11,6 +11,7 @@ import time
import traceback
import platform
import subprocess as sp
+from functools import reduce
import numpy as np
import torch
@@ -21,8 +22,11 @@ import gradio as gr
import gradio.utils
import gradio.routes
+from modules import sd_hijack, sd_models
from modules.paths import script_path
from modules.shared import opts, cmd_opts
+if cmd_opts.deepdanbooru:
+ from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.sd_hijack import model_hijack
@@ -32,8 +36,13 @@ import modules.gfpgan_model
import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste
+from modules import prompt_parser
+from modules.images import save_image
+import modules.textual_inversion.ui
+import modules.hypernetworks.ui
+import modules.images_history as img_his
-# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
+# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
@@ -43,6 +52,11 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
+if cmd_opts.ngrok is not None:
+ import modules.ngrok as ngrok
+ print('ngrok authtoken detected, trying to connect...')
+ ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port is not None else 7860)
+
def gr_show(visible=True):
return {"visible": visible, "__type__": "update"}
@@ -64,7 +78,9 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
-folder_symbol = '\uD83D\uDCC2'
+folder_symbol = '\U0001f4c2' # 📂
+refresh_symbol = '\U0001f504' # 🔄
+
def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
@@ -93,19 +109,32 @@ def send_gradio_gallery_to_image(x):
return image_from_url_text(x[0])
-def save_files(js_data, images, index):
+def save_files(js_data, images, do_make_zip, index):
import csv
-
- os.makedirs(opts.outdir_save, exist_ok=True)
-
filenames = []
+ fullfns = []
+
+ # quick dictionary to class object conversion; it's necessary because apply_filename_pattern requires it
+ class MyObject:
+ def __init__(self, d=None):
+ if d is not None:
+ for key, value in d.items():
+ setattr(self, key, value)
data = json.loads(js_data)
+
+ p = MyObject(data)
+ path = opts.outdir_save
+ save_to_dirs = opts.use_save_to_dirs_for_ui
+ extension: str = opts.samples_format
+ start_index = 0
+
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
+
images = [images[index]]
- infotexts = [data["infotexts"][index]]
- else:
- infotexts = data["infotexts"]
+ start_index = index
+
+ os.makedirs(opts.outdir_save, exist_ok=True)
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0
@@ -113,37 +142,42 @@ def save_files(js_data, images, index):
if at_start:
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
- filename_base = str(int(time.time() * 1000))
- extension = opts.samples_format.lower()
- for i, filedata in enumerate(images):
- filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}"
- filepath = os.path.join(opts.outdir_save, filename)
-
+ for image_index, filedata in enumerate(images, start_index):
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
- if opts.enable_pnginfo and extension == 'png':
- pnginfo = PngImagePlugin.PngInfo()
- pnginfo.add_text('parameters', infotexts[i])
- image.save(filepath, pnginfo=pnginfo)
- else:
- image.save(filepath, quality=opts.jpeg_quality)
- if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"):
- piexif.insert(piexif.dump({"Exif": {
- piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
- }}), filepath)
+ is_grid = image_index < p.index_of_first_image
+ i = 0 if is_grid else (image_index - p.index_of_first_image)
+
+ fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
+ filename = os.path.relpath(fullfn, path)
filenames.append(filename)
+ fullfns.append(fullfn)
+ if txt_fullfn:
+ filenames.append(os.path.basename(txt_fullfn))
+ fullfns.append(txt_fullfn)
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
- return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
+ # Make Zip
+ if do_make_zip:
+ zip_filepath = os.path.join(path, "images.zip")
+ from zipfile import ZipFile
+ with ZipFile(zip_filepath, "w") as zip_file:
+ for i in range(len(fullfns)):
+ with open(fullfns[i], mode="rb") as f:
+ zip_file.writestr(filenames[i], f.read())
+ fullfns.insert(0, zip_filepath)
-def wrap_gradio_call(func):
- def f(*args, **kwargs):
+ return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
+
+
+def wrap_gradio_call(func, extra_outputs=None):
+ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon:
shared.mem_mon.monitor()
@@ -152,16 +186,31 @@ def wrap_gradio_call(func):
try:
res = list(func(*args, **kwargs))
except Exception as e:
+ # When printing out our debug argument list, do not print out more than 128 KiB of text
+ max_debug_str_len = 131072 # (1024*1024)/8
+
print("Error completing request", file=sys.stderr)
- print("Arguments:", args, kwargs, file=sys.stderr)
+ argStr = f"Arguments: {str(args)} {str(kwargs)}"
+ print(argStr[:max_debug_str_len], file=sys.stderr)
+ if len(argStr) > max_debug_str_len:
+ print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
+
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
shared.state.job_count = 0
- res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
+ if extra_outputs_array is None:
+ extra_outputs_array = [None, '']
+
+ res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
+ elapsed_m = int(elapsed // 60)
+ elapsed_s = elapsed % 60
+ elapsed_text = f"{elapsed_s:.2f}s"
+ if (elapsed_m > 0):
+ elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
@@ -176,9 +225,11 @@ def wrap_gradio_call(func):
vram_html = ''
# last item is always HTML
- res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+ shared.state.skipped = False
shared.state.interrupted = False
+ shared.state.job_count = 0
return tuple(res)
@@ -187,7 +238,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part):
if shared.state.job_count == 0:
- return "", gr_show(False), gr_show(False)
+ return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0
@@ -219,13 +270,19 @@ def check_progress_call(id_part):
else:
preview_visibility = gr_show(True)
- return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
+ if shared.state.textinfo is not None:
+ textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
+ else:
+ textinfo_result = gr_show(False)
+
+ return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part):
shared.state.job_count = -1
shared.state.current_latent = None
shared.state.current_image = None
+ shared.state.textinfo = None
return check_progress_call(id_part)
@@ -271,6 +328,11 @@ def interrogate(image):
return gr_show(True) if prompt is None else prompt
+def interrogate_deepbooru(image):
+ prompt = get_deepbooru_tags(image)
+ return gr_show(True) if prompt is None else prompt
+
+
def create_seed_inputs():
with gr.Row():
with gr.Box():
@@ -345,11 +407,24 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component]
)
-def update_token_counter(text):
- tokens, token_count, max_length = model_hijack.tokenize(text)
+
+def update_token_counter(text, steps):
+ try:
+ _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+ prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
+ except Exception:
+ # a parsing error can happen here during typing, and we don't want to bother the user with
+ # messages related to it in console
+ prompt_schedules = [[[steps, text]]]
+
+ flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+ prompts = [prompt_text for step, prompt_text in flat_prompts]
+ tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
+
def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
@@ -358,54 +433,76 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
with gr.Column(scale=1, elem_id="roll_col"):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
- hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
with gr.Row():
with gr.Column(scale=8):
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+ with gr.Row():
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+ )
+
+ with gr.Column(scale=1, elem_id="roll_col"):
+ sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
with gr.Column(scale=1):
with gr.Row():
+ skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+ skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
+
interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
- with gr.Row():
+ with gr.Row(scale=1):
if is_img2img:
- interrogate = gr.Button('Interrogate', elem_id="interrogate")
+ interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ if cmd_opts.deepdanbooru:
+ deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+ else:
+ deepbooru = None
else:
interrogate = None
+ deepbooru = None
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create")
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
-def setup_progressbar(progressbar, preview, id_part):
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+ if textinfo is None:
+ textinfo = gr.HTML(visible=False)
+
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click(
fn=lambda: check_progress_call(id_part),
show_progress=False,
inputs=[],
- outputs=[progressbar, preview, preview],
+ outputs=[progressbar, preview, preview, textinfo],
)
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
@@ -413,14 +510,44 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part),
show_progress=False,
inputs=[],
- outputs=[progressbar, preview, preview],
+ outputs=[progressbar, preview, preview, textinfo],
)
-def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
+def apply_setting(key, value):
+ if value is None:
+ return gr.update()
+
+ if key == "sd_model_checkpoint":
+ ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+ if ckpt_info is not None:
+ value = ckpt_info.title
+ else:
+ return gr.update()
+
+ comp_args = opts.data_labels[key].component_args
+ if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
+ return
+
+ valtype = type(opts.data_labels[key].default)
+ oldval = opts.data[key]
+ opts.data[key] = valtype(value) if valtype != type(None) else value
+ if oldval != value and opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+ return value
+
+
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
+
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1):
@@ -446,11 +573,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options:
- scale_latent = gr.Checkbox(label='Scale latent', value=False)
+ firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
+ firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
- with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ with gr.Row(equal_height=True):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
@@ -475,15 +603,21 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Row():
+ do_make_zip = gr.Checkbox(label="Make Zip when saving?", value=False)
+
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
- fn=txt2img,
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
txt2img_prompt,
@@ -502,8 +636,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
height,
width,
enable_hr,
- scale_latent,
denoising_strength,
+ firstphase_width,
+ firstphase_height,
] + custom_inputs,
outputs=[
txt2img_gallery,
@@ -516,6 +651,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
+ txt_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ txt_prompt_img
+ ],
+ outputs=[
+ txt2img_prompt,
+ txt_prompt_img
+ ]
+ )
+
enable_hr.change(
fn=lambda x: gr_show(x),
inputs=[enable_hr],
@@ -524,13 +670,15 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
save.click(
fn=wrap_gradio_call(save_files),
- _js="(x, y, z) => [x, y, selected_gallery_index()]",
+ _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
inputs=[
generation_info,
txt2img_gallery,
+ do_make_zip,
html_info,
],
outputs=[
+ download_files,
html_info,
html_info,
html_info,
@@ -539,6 +687,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
roll.click(
fn=roll_artist,
+ _js="update_txt2img_tokens",
inputs=[
txt2img_prompt,
],
@@ -565,13 +714,29 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+ (firstphase_width, "First pass size-1"),
+ (firstphase_height, "First pass size-2"),
+ ]
+
+ txt2img_preview_params = [
+ txt2img_prompt,
+ txt2img_negative_prompt,
+ steps,
+ sampler_index,
+ cfg_scale,
+ seed,
+ width,
+ height,
]
- modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
+
+ token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+
with gr.Column(scale=1):
pass
@@ -585,7 +750,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'):
- init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil")
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
@@ -607,7 +772,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.TabItem('Batch img2img', id='batch'):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
- gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.{hidden}</p>")
+ gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
@@ -626,7 +791,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row():
- batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
with gr.Group():
@@ -653,13 +818,30 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
- with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ with gr.Row():
+ do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
+
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+
+ with gr.Group():
+ html_info = gr.HTML()
+ generation_info = gr.Textbox(visible=False)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
+ img2img_prompt_img.change(
+ fn=modules.images.image_data,
+ inputs=[
+ img2img_prompt_img
+ ],
+ outputs=[
+ img2img_prompt,
+ img2img_prompt_img
+ ]
+ )
+
mask_mode.change(
lambda mode, img: {
init_img_with_mask: gr_show(mode == 0),
@@ -675,7 +857,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
)
img2img_args = dict(
- fn=img2img,
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img",
inputs=[
dummy_component,
@@ -726,15 +908,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[img2img_prompt],
)
+ if cmd_opts.deepdanbooru:
+ img2img_deepbooru.click(
+ fn=interrogate_deepbooru,
+ inputs=[init_img],
+ outputs=[img2img_prompt],
+ )
+
save.click(
fn=wrap_gradio_call(save_files),
- _js="(x, y, z) => [x, y, selected_gallery_index()]",
+ _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
inputs=[
generation_info,
img2img_gallery,
- html_info
+ do_make_zip,
+ html_info,
],
outputs=[
+ download_files,
html_info,
html_info,
html_info,
@@ -743,6 +934,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
roll.click(
fn=roll_artist,
+ _js="update_img2img_tokens",
inputs=[
img2img_prompt,
],
@@ -753,6 +945,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
+ style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
button.click(
@@ -764,9 +957,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
)
- for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns):
+ for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click(
fn=apply_styles,
+ _js=js_func,
inputs=[prompt, negative_prompt, style1, style2],
outputs=[prompt, negative_prompt, style1, style2],
)
@@ -788,7 +982,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
]
- modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
+ token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
@@ -800,7 +994,15 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.TabItem('Batch Process'):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
- upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+ with gr.Tabs(elem_id="extras_resize_mode"):
+ with gr.TabItem('Scale by'):
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+ with gr.TabItem('Scale to'):
+ with gr.Group():
+ with gr.Row():
+ upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
+ upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
@@ -827,17 +1029,22 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
+
submit.click(
- fn=run_extras,
+ fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
inputs=[
dummy_component,
+ dummy_component,
extras_image,
image_batch,
gfpgan_visibility,
codeformer_visibility,
codeformer_weight,
upscaling_resize,
+ upscaling_resize_w,
+ upscaling_resize_h,
+ upscaling_crop,
extras_upscaler_1,
extras_upscaler_2,
extras_upscaler_2_visibility,
@@ -848,17 +1055,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
html_info,
]
)
-
+
extras_send_to_img2img.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery_img2img",
inputs=[result_images],
outputs=[init_img],
)
-
+
extras_send_to_inpaint.click(
fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery_img2img",
+ _js="extract_image_from_gallery_inpaint",
inputs=[result_images],
outputs=[init_img_with_mask],
)
@@ -878,29 +1085,220 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change(
- fn=wrap_gradio_call(run_pnginfo),
+ fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
+ #images history
+ images_history_switch_dict = {
+ "fn":modules.generation_parameters_copypaste.connect_paste,
+ "t2i":txt2img_paste_fields,
+ "i2i":img2img_paste_fields
+ }
+
+ images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
-
+
with gr.Row():
- primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
- secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
+ primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
custom_name = gr.Textbox(label="Custom Name (Optional)")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
- interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
- save_as_half = gr.Checkbox(value=False, label="Safe as float16")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
+ interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
+ save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
-
+
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
- def create_setting_component(key):
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+ with gr.Blocks() as train_interface:
+ with gr.Row().style(equal_height=False):
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
+
+ with gr.Row().style(equal_height=False):
+ with gr.Tabs(elem_id="train_tabs"):
+
+ with gr.Tab(label="Create embedding"):
+ new_embedding_name = gr.Textbox(label="Name")
+ initialization_text = gr.Textbox(label="Initialization text", value="*")
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_embedding = gr.Button(value="Create embedding", variant='primary')
+
+ with gr.Tab(label="Create hypernetwork"):
+ new_hypernetwork_name = gr.Textbox(label="Name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
+
+ with gr.Tab(label="Preprocess images"):
+ process_src = gr.Textbox(label='Source directory')
+ process_dst = gr.Textbox(label='Destination directory')
+ process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+
+ with gr.Row():
+ process_flip = gr.Checkbox(label='Create flipped copies')
+ process_split = gr.Checkbox(label='Split oversized images into two')
+ process_caption = gr.Checkbox(label='Use BLIP for caption')
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ run_preprocess = gr.Button(value="Preprocess", variant='primary')
+
+ with gr.Tab(label="Train"):
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
+ train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
+ learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ batch_size = gr.Number(label='Batch size', value=1, precision=0)
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
+ template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
+ training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
+ training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ steps = gr.Number(label='Max steps', value=100000, precision=0)
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
+
+ with gr.Row():
+ interrupt_training = gr.Button(value="Interrupt")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
+ train_embedding = gr.Button(value="Train Embedding", variant='primary')
+
+ with gr.Column():
+ progressbar = gr.HTML(elem_id="ti_progressbar")
+ ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
+
+ ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
+ ti_preview = gr.Image(elem_id='ti_preview', visible=False)
+ ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ ti_outcome = gr.HTML(elem_id="ti_error", value="")
+ setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
+
+ create_embedding.click(
+ fn=modules.textual_inversion.ui.create_embedding,
+ inputs=[
+ new_embedding_name,
+ initialization_text,
+ nvpt,
+ ],
+ outputs=[
+ train_embedding_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ create_hypernetwork.click(
+ fn=modules.hypernetworks.ui.create_hypernetwork,
+ inputs=[
+ new_hypernetwork_name,
+ new_hypernetwork_sizes,
+ ],
+ outputs=[
+ train_hypernetwork_name,
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ run_preprocess.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ process_src,
+ process_dst,
+ process_width,
+ process_height,
+ process_flip,
+ process_split,
+ process_caption,
+ process_caption_deepbooru
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ],
+ )
+
+ train_embedding.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_embedding_name,
+ learn_rate,
+ batch_size,
+ dataset_directory,
+ log_directory,
+ training_width,
+ training_height,
+ steps,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ save_image_with_stored_embedding,
+ preview_from_txt2img,
+ *txt2img_preview_params,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ train_hypernetwork.click(
+ fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ train_hypernetwork_name,
+ learn_rate,
+ batch_size,
+ dataset_directory,
+ log_directory,
+ steps,
+ create_image_every,
+ save_embedding_every,
+ template_file,
+ preview_from_txt2img,
+ *txt2img_preview_params,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ]
+ )
+
+ interrupt_training.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
+
+ def create_setting_component(key, is_quicksettings=False):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
@@ -920,12 +1318,48 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
- return comp(label=info.label, value=fun, **(args or {}))
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key)
+ else:
+ with gr.Row(variant="compact"):
+ res = comp(label=info.label, value=fun, **(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key)
+
+ def refresh():
+ info.refresh()
+ refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
+
+ for k, v in refreshed_args.items():
+ setattr(res, k, v)
+
+ return gr.update(**(refreshed_args or {}))
+
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[res],
+ )
+ else:
+ res = comp(label=info.label, value=fun, **(args or {}))
+
+
+ return res
components = []
component_dict = {}
def open_folder(f):
+ if not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
if not shared.cmd_opts.hide_ui_dir_config:
path = os.path.normpath(f)
if platform.system() == "Windows":
@@ -939,10 +1373,13 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
changed = 0
for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if not opts.same_type(value, opts.data_labels[key].default):
- return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
+ if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
+ return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
for key, value, comp in zip(opts.data_labels.keys(), args, components):
+ if comp == dummy_component:
+ continue
+
comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue
@@ -960,6 +1397,21 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
return f'{changed} settings changed.', opts.dumpjson()
+ def run_settings_single(value, key):
+ if not opts.same_type(value, opts.data_labels[key].default):
+ return gr.update(visible=True), opts.dumpjson()
+
+ oldval = opts.data.get(key, None)
+ opts.data[key] = value
+
+ if oldval != value:
+ if opts.data_labels[key].onchange is not None:
+ opts.data_labels[key].onchange()
+
+ opts.save(shared.config_filename)
+
+ return gr.update(value=value), opts.dumpjson()
+
with gr.Blocks(analytics_enabled=False) as settings_interface:
settings_submit = gr.Button(value="Apply settings", variant='primary')
result = gr.HTML()
@@ -967,6 +1419,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
settings_cols = 3
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+
+ quicksettings_list = []
+
cols_displayed = 0
items_displayed = 0
previous_section = None
@@ -989,12 +1446,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
- component = create_setting_component(k)
- component_dict[k] = component
- components.append(component)
- items_displayed += 1
+ if k in quicksettings_names:
+ quicksettings_list.append((i, k, item))
+ components.append(dummy_component)
+ else:
+ component = create_setting_component(k)
+ component_dict[k] = component
+ components.append(component)
+ items_displayed += 1
+
+ with gr.Row():
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
+ restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
request_notifications.click(
fn=lambda: None,
inputs=[],
@@ -1002,6 +1467,27 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
_js='function(){}'
)
+ def reload_scripts():
+ modules.scripts.reload_script_body_only()
+
+ reload_script_bodies.click(
+ fn=reload_scripts,
+ inputs=[],
+ outputs=[],
+ _js='function(){}'
+ )
+
+ def request_restart():
+ shared.state.interrupt()
+ settings_interface.gradio_ref.do_restart = True
+
+ restart_gradio.click(
+ fn=request_restart,
+ inputs=[],
+ outputs=[],
+ _js='function(){restart_reload()}'
+ )
+
if column is not None:
column.__exit__()
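
restart_gradio only interrupts the current job and sets do_restart on the Blocks instance saved in gradio_ref; actually tearing the server down and rebuilding the UI is left to the launcher, which is not part of this diff. Assuming the launcher polls that flag (a sketch, not the real webui.py), the relaunch loop is roughly:

import time

def serve_forever(build_ui):
    # Illustrative relaunch loop: rebuild and relaunch the UI whenever the
    # Settings tab requested a restart via demo.do_restart.
    while True:
        demo = build_ui()
        demo.launch(prevent_thread_lock=True)
        while not getattr(demo, "do_restart", False):
            time.sleep(0.5)
        demo.close()
        time.sleep(0.5)
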
@@ -1010,7 +1496,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
+ (images_history, "History", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+ (train_interface, "Train", "ti"),
(settings_interface, "Settings", "settings"),
]
@@ -1026,12 +1514,18 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
css += css_hide_progressbar
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
+ with gr.Row(elem_id="quicksettings"):
+ for i, k, item in quicksettings_list:
+ component = create_setting_component(k, is_quicksettings=True)
+ component_dict[k] = component
+
+ settings_interface.gradio_ref = demo
- with gr.Tabs() as tabs:
+ with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
- with gr.TabItem(label, id=ifid):
+ with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render()
-
+
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
@@ -1041,14 +1535,23 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
inputs=components,
outputs=[result, text_settings],
)
-
+
+ for i, k, item in quicksettings_list:
+ component = component_dict[k]
+
+ component.change(
+ fn=lambda value, k=k: run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, text_settings],
+ )
+
def modelmerger(*args):
try:
- results = run_modelmerger(*args)
+ results = modules.extras.run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- modules.sd_models.list_models() #To remove the potentially missing models from the list
+ modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results
@@ -1057,6 +1560,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
inputs=[
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
interp_method,
interp_amount,
save_as_half,
@@ -1066,6 +1570,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
submit_result,
primary_model_name,
secondary_model_name,
+ tertiary_model_name,
component_dict['sd_model_checkpoint'],
]
)
@@ -1132,8 +1637,22 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[extras_image],
)
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img')
- modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img')
+ settings_map = {
+ 'sd_hypernetwork': 'Hypernet',
+ 'CLIP_stop_at_last_layers': 'Clip skip',
+ 'sd_model_checkpoint': 'Model hash',
+ }
+
+ settings_paste_fields = [
+ (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
+ for k, v in settings_map.items()
+ ]
+
+ modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
+ modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
+
+ modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
+ modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
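
settings_paste_fields extends parameter pasting beyond per-image fields: the Hypernet, Clip skip and Model hash entries of an infotext are routed through apply_setting so that pasting can also switch the matching options. apply_setting is defined elsewhere in ui.py and is not shown in this hunk; a hypothetical outline of what such a helper has to do, mirroring run_settings_single above (names and details assumed, not the actual implementation):

def apply_setting_sketch(key, value):
    # Hypothetical outline: skip absent values, write the option, run its
    # onchange hook (which e.g. reloads the checkpoint or hypernetwork),
    # then persist the config. The real helper would also need to resolve
    # a bare model hash to a checkpoint title before assigning it.
    if value is None:
        return gr.update()
    opts.data[key] = value
    if opts.data_labels[key].onchange is not None:
        opts.data_labels[key].onchange()
    opts.save(shared.config_filename)
    return value
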
@@ -1155,10 +1674,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
if getattr(obj,'custom_script_source',None) is not None:
key = 'customscript/' + obj.custom_script_source + '/' + key
-
+
if getattr(obj, 'do_not_save_to_config', False):
return
-
+
saved_value = ui_settings.get(key, None)
if saved_value is None:
ui_settings[key] = getattr(obj, field)
@@ -1182,10 +1701,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
if type(x) == gr.Textbox:
apply_field(x, 'value')
-
+
if type(x) == gr.Number:
apply_field(x, 'value')
-
+
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras")
@@ -1206,12 +1725,13 @@ for filename in sorted(os.listdir(jsdir)):
javascript += f"\n<script>{jsfile.read()}</script>"
-def template_response(*args, **kwargs):
- res = gradio_routes_templates_response(*args, **kwargs)
- res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
- res.init_headers()
- return res
+if 'gradio_routes_templates_response' not in globals():
+ def template_response(*args, **kwargs):
+ res = gradio_routes_templates_response(*args, **kwargs)
+ res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
+ res.init_headers()
+ return res
+ gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
+ gradio.routes.templates.TemplateResponse = template_response
-gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
-gradio.routes.templates.TemplateResponse = template_response
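
The last ui.py hunk wraps Gradio's TemplateResponse so the collected javascript is injected into every served page, and the new globals() check makes the monkey-patch idempotent: without it, re-executing this module (for example after the new "Restart Gradio and Refresh components" button) would wrap the already-wrapped response class and inject the scripts once per reload. A tiny self-contained illustration of that failure mode (generic code, not taken from the repository):

def patch(render, payload):
    # Each unguarded patch wraps whatever callable is currently installed.
    def wrapped(body):
        return render(body).replace("</head>", payload + "</head>")
    return wrapped

render = lambda body: body
for _ in range(3):               # three UI reloads with no idempotency guard
    render = patch(render, "<script/>")

print(render("<head></head>"))   # the payload now appears three times
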
diff --git a/modules/upscaler.py b/modules/upscaler.py
index d9d7c5e2..6ab2fb40 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -36,10 +36,11 @@ class Upscaler:
self.half = not modules.shared.cmd_opts.no_half
self.pre_pad = 0
self.mod_scale = None
- if self.name is not None and create_dirs:
+
+ if self.model_path is None and self.name:
self.model_path = os.path.join(models_path, self.name)
- if not os.path.exists(self.model_path):
- os.makedirs(self.model_path)
+ if self.model_path and create_dirs:
+ os.makedirs(self.model_path, exist_ok=True)
try:
import cv2