author     captin411 <captindave@gmail.com>  2022-10-25 13:22:27 -0700
committer  captin411 <captindave@gmail.com>  2022-10-25 13:22:27 -0700
commit     6629446a2f9bb3ade1c271854aae1530ba1a8cc3 (patch)
tree       ad7cfd2b3f0208c24da64c7f08e0550e783228ec /modules
parent     3e6c2420c1177e9e79f2b566a5a7795b7416e34a (diff)
parent     3e15f8e0f5cc87507f77546d92435670644dbd18 (diff)
Merge branch 'master' into focal-point-cropping
Diffstat (limited to 'modules')
-rw-r--r--  modules/api/api.py | 66
-rw-r--r--  modules/api/models.py (renamed from modules/api/processing.py) | 11
-rw-r--r--  modules/bsrgan_model.py | 76
-rw-r--r--  modules/bsrgan_model_arch.py | 102
-rw-r--r--  modules/deepbooru.py | 5
-rw-r--r--  modules/devices.py | 23
-rw-r--r--  modules/esrgan_model.py | 192
-rw-r--r--  modules/esrgan_model_arch.py | 487
-rw-r--r--  modules/extras.py | 19
-rw-r--r--  modules/generation_parameters_copypaste.py | 25
-rw-r--r--  modules/hypernetworks/hypernetwork.py | 217
-rw-r--r--  modules/hypernetworks/ui.py | 28
-rw-r--r--  modules/images.py | 201
-rw-r--r--  modules/images_history.py | 183
-rw-r--r--  modules/img2img.py | 3
-rw-r--r--  modules/interrogate.py | 12
-rw-r--r--  modules/lowvram.py | 9
-rw-r--r--  modules/processing.py | 146
-rw-r--r--  modules/script_callbacks.py | 100
-rw-r--r--  modules/scripts.py | 246
-rw-r--r--  modules/scunet_model.py | 3
-rw-r--r--  modules/sd_hijack.py | 29
-rw-r--r--  modules/sd_hijack_inpainting.py | 331
-rw-r--r--  modules/sd_models.py | 34
-rw-r--r--  modules/sd_samplers.py | 81
-rw-r--r--  modules/shared.py | 45
-rw-r--r--  modules/swinir_model.py | 12
-rw-r--r--  modules/textual_inversion/dataset.py | 4
-rw-r--r--  modules/textual_inversion/image_embedding.py | 5
-rw-r--r--  modules/textual_inversion/preprocess.py | 91
-rw-r--r--  modules/textual_inversion/textual_inversion.py | 6
-rw-r--r--  modules/textual_inversion/ui.py | 4
-rw-r--r--  modules/txt2img.py | 7
-rw-r--r--  modules/ui.py | 201
34 files changed, 2120 insertions(+), 884 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 5b0c934e..a860a964 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,5 +1,5 @@
-from modules.api.processing import StableDiffusionProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, process_images
+from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, Json
import json
import io
import base64
+from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
@@ -18,6 +19,11 @@ class TextToImageResponse(BaseModel):
parameters: Json
info: Json
+class ImageToImageResponse(BaseModel):
+ images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+ parameters: Json
+ info: Json
+
class Api:
def __init__(self, app, queue_lock):
@@ -25,8 +31,17 @@ class Api:
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+ self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
- def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
+ def __base64_to_image(self, base64_string):
+ # if has a comma, deal with prefix
+ if "," in base64_string:
+ base64_string = base64_string.split(",")[1]
+ imgdata = base64.b64decode(base64_string)
+ # convert base64 to PIL image
+ return Image.open(io.BytesIO(imgdata))
+
+ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
@@ -54,8 +69,49 @@ class Api:
- def img2imgapi(self):
- raise NotImplementedError
+ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ sampler_index = sampler_to_index(img2imgreq.sampler_index)
+
+ if sampler_index is None:
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
+
+ init_images = img2imgreq.init_images
+ if init_images is None:
+ raise HTTPException(status_code=404, detail="Init image not found")
+
+ mask = img2imgreq.mask
+ if mask:
+ mask = self.__base64_to_image(mask)
+
+
+ populate = img2imgreq.copy(update={ # Override __init__ params
+ "sd_model": shared.sd_model,
+ "sampler_index": sampler_index[0],
+ "do_not_save_samples": True,
+ "do_not_save_grid": True,
+ "mask": mask
+ }
+ )
+ p = StableDiffusionProcessingImg2Img(**vars(populate))
+
+ imgs = []
+ for img in init_images:
+ img = self.__base64_to_image(img)
+ imgs = [img] * p.batch_size
+
+ p.init_images = imgs
+ # Override object param
+ with self.queue_lock:
+ processed = process_images(p)
+
+ b64images = []
+ for i in processed.images:
+ buffer = io.BytesIO()
+ i.save(buffer, format="png")
+ b64images.append(base64.b64encode(buffer.getvalue()))
+
+ return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
def extrasapi(self):
raise NotImplementedError
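
For reference, a minimal client-side sketch of calling the new /sdapi/v1/img2img endpoint added above. It assumes the web UI is running with the API enabled on the default 127.0.0.1:7860 address; the file name, prompt, and parameter values are placeholders, and the field names (init_images, mask, denoising_strength, sampler_index) come from the request model defined in this diff.

    import base64
    import requests  # third-party HTTP client, not part of the webui

    with open("input.png", "rb") as f:
        init_image = base64.b64encode(f.read()).decode()

    payload = {
        "init_images": [init_image],      # list of base64-encoded source images
        "denoising_strength": 0.6,        # the model default above is 0.75
        "sampler_index": "Euler",
        "prompt": "a photo of a cat",     # illustrative prompt
        # "mask": "<base64 mask>",        # optional, see the mask handling above
    }
    r = requests.post("http://127.0.0.1:7860/sdapi/v1/img2img", json=payload)
    result = r.json()["images"]           # base64-encoded results, per ImageToImageResponse
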
diff --git a/modules/api/processing.py b/modules/api/models.py
index 4c541241..f551fa35 100644
--- a/modules/api/processing.py
+++ b/modules/api/models.py
@@ -1,7 +1,8 @@
+from array import array
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import inspect
@@ -92,8 +93,14 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_mutation = True
return DynamicModel
-StableDiffusionProcessingAPI = PydanticModelGenerator(
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+ "StableDiffusionProcessingImg2Img",
+ StableDiffusionProcessingImg2Img,
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
).generate_model()
\ No newline at end of file
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
deleted file mode 100644
index 737e1a76..00000000
--- a/modules/bsrgan_model.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import os.path
-import sys
-import traceback
-
-import PIL.Image
-import numpy as np
-import torch
-from basicsr.utils.download_util import load_file_from_url
-
-import modules.upscaler
-from modules import devices, modelloader
-from modules.bsrgan_model_arch import RRDBNet
-
-
-class UpscalerBSRGAN(modules.upscaler.Upscaler):
- def __init__(self, dirname):
- self.name = "BSRGAN"
- self.model_name = "BSRGAN 4x"
- self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
- self.user_path = dirname
- super().__init__()
- model_paths = self.find_models(ext_filter=[".pt", ".pth"])
- scalers = []
- if len(model_paths) == 0:
- scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
- scalers.append(scaler_data)
- for file in model_paths:
- if "http" in file:
- name = self.model_name
- else:
- name = modelloader.friendly_name(file)
- try:
- scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
- scalers.append(scaler_data)
- except Exception:
- print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- self.scalers = scalers
-
- def do_upscale(self, img: PIL.Image, selected_file):
- torch.cuda.empty_cache()
- model = self.load_model(selected_file)
- if model is None:
- return img
- model.to(devices.device_bsrgan)
- torch.cuda.empty_cache()
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_bsrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- torch.cuda.empty_cache()
- return PIL.Image.fromarray(output, 'RGB')
-
- def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
- progress=True)
- else:
- filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
- return None
- model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
- model.load_state_dict(torch.load(filename), strict=True)
- model.eval()
- for k, v in model.named_parameters():
- v.requires_grad = False
- return model
-
diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py
deleted file mode 100644
index cb4d1c13..00000000
--- a/modules/bsrgan_model_arch.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import functools
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.nn.init as init
-
-
-def initialize_weights(net_l, scale=1):
- if not isinstance(net_l, list):
- net_l = [net_l]
- for net in net_l:
- for m in net.modules():
- if isinstance(m, nn.Conv2d):
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
- m.weight.data *= scale # for residual block
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- init.kaiming_normal_(m.weight, a=0, mode='fan_in')
- m.weight.data *= scale
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- init.constant_(m.weight, 1)
- init.constant_(m.bias.data, 0.0)
-
-
-def make_layer(block, n_layers):
- layers = []
- for _ in range(n_layers):
- layers.append(block())
- return nn.Sequential(*layers)
-
-
-class ResidualDenseBlock_5C(nn.Module):
- def __init__(self, nf=64, gc=32, bias=True):
- super(ResidualDenseBlock_5C, self).__init__()
- # gc: growth channel, i.e. intermediate channels
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- # initialization
- initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
-
- def forward(self, x):
- x1 = self.lrelu(self.conv1(x))
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
-
-
-class RRDB(nn.Module):
- '''Residual in Residual Dense Block'''
-
- def __init__(self, nf, gc=32):
- super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(nf, gc)
- self.RDB2 = ResidualDenseBlock_5C(nf, gc)
- self.RDB3 = ResidualDenseBlock_5C(nf, gc)
-
- def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- return out * 0.2 + x
-
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
- super(RRDBNet, self).__init__()
- RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
- self.sf = sf
-
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
- self.RRDB_trunk = make_layer(RRDB_block_f, nb)
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- #### upsampling
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- if self.sf==4:
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
-
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-
- def forward(self, x):
- fea = self.conv_first(x)
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
- fea = fea + trunk
-
- fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
- if self.sf==4:
- fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
-
- return out
\ No newline at end of file
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 8914662d..8bbc90a4 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -50,11 +50,12 @@ def create_deepbooru_process(threshold, deepbooru_opts):
the tags.
"""
from modules import shared # prevents circular reference
- shared.deepbooru_process_manager = multiprocessing.Manager()
+ context = multiprocessing.get_context("spawn")
+ shared.deepbooru_process_manager = context.Manager()
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
- shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
+ shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
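
A side note on the change above: a minimal, standalone illustration of the spawn-context pattern (not project code). A "spawn" context starts the child from a fresh Python interpreter instead of forking the parent process, which avoids inheriting CUDA and other non-fork-safe state into the deepbooru worker.

    import multiprocessing

    def worker(queue):
        queue.put("hello from a spawned process")

    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")   # same context type used above
        q = ctx.Queue()
        p = ctx.Process(target=worker, args=(q,))
        p.start()
        print(q.get())
        p.join()
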
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..7511e1dc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,7 +1,6 @@
+import sys, os, shlex
import contextlib
-
import torch
-
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
+def extract_device_id(args, name):
+ for x in range(len(args)):
+ if name in args[x]: return args[x+1]
+ return None
def get_optimal_device():
if torch.cuda.is_available():
- return torch.device("cuda")
+ from modules import shared
+
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ return torch.device(cuda_device)
+ else:
+ return torch.device("cuda")
if has_mps:
return torch.device("mps")
@@ -34,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
@@ -70,3 +81,7 @@ def autocast(disable=False):
return contextlib.nullcontext()
return torch.autocast("cuda")
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
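
For orientation, a self-contained sketch of what the two helpers added above do, with the functions copied here in simplified form; the "--device-id" argument list is only an example input.

    import torch

    def extract_device_id(args, name):
        # returns the value following the first argument that contains `name`
        for x in range(len(args)):
            if name in args[x]:
                return args[x + 1]
        return None

    def mps_contiguous_to(t, device):
        # MPS needs contiguous input tensors (pytorch issue 79383); other devices pass through
        return (t.contiguous() if device.type == 'mps' else t).to(device)

    print(extract_device_id(["--device-id", "1"], "--device-id"))     # "1"
    x = mps_contiguous_to(torch.rand(1, 3, 8, 8), torch.device("cpu"))
    print(x.device)                                                   # cpu
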
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 46ad0da3..a13cf6ac 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -11,62 +11,109 @@ from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-def fix_model_layers(crt_model, pretrained_net):
- # this code is adapted from https://github.com/xinntao/ESRGAN
- if 'conv_first.weight' in pretrained_net:
- return pretrained_net
-
- if 'model.0.weight' not in pretrained_net:
- is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
- if is_realesrgan:
- raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
- else:
- raise Exception("The file is not a ESRGAN model.")
- crt_net = crt_model.state_dict()
- load_net_clean = {}
- for k, v in pretrained_net.items():
- if k.startswith('module.'):
- load_net_clean[k[7:]] = v
- else:
- load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- tbd = []
- for k, v in crt_net.items():
- tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
- if k in pretrained_net and pretrained_net[k].size() == v.size():
- crt_net[k] = pretrained_net[k]
- tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[k] = pretrained_net[ori_k]
- tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- return crt_net
+def mod2normal(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if 'conv_first.weight' in state_dict:
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
+ crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
+ crt_net['model.3.weight'] = state_dict['upconv1.weight']
+ crt_net['model.3.bias'] = state_dict['upconv1.bias']
+ crt_net['model.6.weight'] = state_dict['upconv2.weight']
+ crt_net['model.6.bias'] = state_dict['upconv2.bias']
+ crt_net['model.8.weight'] = state_dict['HRconv.weight']
+ crt_net['model.8.bias'] = state_dict['HRconv.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def resrgan2normal(state_dict, nb=23):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ crt_net = {}
+ items = []
+ for k, v in state_dict.items():
+ items.append(k)
+
+ crt_net['model.0.weight'] = state_dict['conv_first.weight']
+ crt_net['model.0.bias'] = state_dict['conv_first.bias']
+
+ for k in items.copy():
+ if "rdb" in k:
+ ori_k = k.replace('body.', 'model.1.sub.')
+ ori_k = ori_k.replace('.rdb', '.RDB')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[ori_k] = state_dict[k]
+ items.remove(k)
+
+ crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
+ crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
+ crt_net['model.3.weight'] = state_dict['conv_up1.weight']
+ crt_net['model.3.bias'] = state_dict['conv_up1.bias']
+ crt_net['model.6.weight'] = state_dict['conv_up2.weight']
+ crt_net['model.6.bias'] = state_dict['conv_up2.bias']
+ crt_net['model.8.weight'] = state_dict['conv_hr.weight']
+ crt_net['model.8.bias'] = state_dict['conv_hr.bias']
+ crt_net['model.10.weight'] = state_dict['conv_last.weight']
+ crt_net['model.10.bias'] = state_dict['conv_last.bias']
+ state_dict = crt_net
+ return state_dict
+
+
+def infer_params(state_dict):
+ # this code is copied from https://github.com/victorca25/iNNfer
+ scale2x = 0
+ scalemin = 6
+ n_uplayer = 0
+ plus = False
+
+ for block in list(state_dict):
+ parts = block.split(".")
+ n_parts = len(parts)
+ if n_parts == 5 and parts[2] == "sub":
+ nb = int(parts[3])
+ elif n_parts == 3:
+ part_num = int(parts[1])
+ if (part_num > scalemin
+ and parts[0] == "model"
+ and parts[2] == "weight"):
+ scale2x += 1
+ if part_num > n_uplayer:
+ n_uplayer = part_num
+ out_nc = state_dict[block].shape[0]
+ if not plus and "conv1x1" in block:
+ plus = True
+
+ nf = state_dict["model.0.weight"].shape[0]
+ in_nc = state_dict["model.0.weight"].shape[1]
+ out_nc = out_nc
+ scale = 2 ** scale2x
+
+ return in_nc, out_nc, nf, nb, plus, scale
+
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
@@ -109,22 +156,41 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+ state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
+
+ if "params_ema" in state_dict:
+ state_dict = state_dict["params_ema"]
+ elif "params" in state_dict:
+ state_dict = state_dict["params"]
+ num_conv = 16 if "realesr-animevideov3" in filename else 32
+ model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
+ model.load_state_dict(state_dict)
+ model.eval()
+ return model
+
+ if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+ nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
+ state_dict = resrgan2normal(state_dict, nb)
+ elif "conv_first.weight" in state_dict:
+ state_dict = mod2normal(state_dict)
+ elif "model.0.weight" not in state_dict:
+ raise Exception("The file is not a recognized ESRGAN model.")
+
+ in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
- pretrained_net = fix_model_layers(crt_model, pretrained_net)
- crt_model.load_state_dict(pretrained_net)
- crt_model.eval()
+ model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+ model.load_state_dict(state_dict)
+ model.eval()
- return crt_model
+ return model
def upscale_without_tiling(model, img):
img = np.array(img)
img = img[:, :, ::-1]
- img = np.moveaxis(img, 2, 0) / 255
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index e413d36e..bc9ceb2a 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -1,80 +1,463 @@
-# this file is taken from https://github.com/xinntao/ESRGAN
+# this file is adapted from https://github.com/victorca25/iNNfer
+import math
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
-def make_layer(block, n_layers):
- layers = []
- for _ in range(n_layers):
- layers.append(block())
- return nn.Sequential(*layers)
+####################
+# RRDBNet Generator
+####################
+class RRDBNet(nn.Module):
+ def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
+ act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
+ finalact=None, gaussian_noise=False, plus=False):
+ super(RRDBNet, self).__init__()
+ n_upscale = int(math.log(upscale, 2))
+ if upscale == 3:
+ n_upscale = 1
-class ResidualDenseBlock_5C(nn.Module):
- def __init__(self, nf=64, gc=32, bias=True):
- super(ResidualDenseBlock_5C, self).__init__()
- # gc: growth channel, i.e. intermediate channels
- self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
- self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
- self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
- self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
- self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ self.resrgan_scale = 0
+ if in_nc % 16 == 0:
+ self.resrgan_scale = 1
+ elif in_nc != 4 and in_nc % 4 == 0:
+ self.resrgan_scale = 2
- # initialization
- # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+ fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+ rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
+ LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
- def forward(self, x):
- x1 = self.lrelu(self.conv1(x))
- x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
- x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
- x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- return x5 * 0.2 + x
+ if upsample_mode == 'upconv':
+ upsample_block = upconv_block
+ elif upsample_mode == 'pixelshuffle':
+ upsample_block = pixelshuffle_block
+ else:
+ raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+ if upscale == 3:
+ upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+ else:
+ upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+ HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+ HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+ outact = act(finalact) if finalact else None
+
+ self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+ *upsampler, HR_conv0, HR_conv1, outact)
+
+ def forward(self, x, outm=None):
+ if self.resrgan_scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ elif self.resrgan_scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ else:
+ feat = x
+
+ return self.model(feat)
class RRDB(nn.Module):
- '''Residual in Residual Dense Block'''
+ """
+ Residual in Residual Dense Block
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+ """
- def __init__(self, nf, gc=32):
+ def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
super(RRDB, self).__init__()
- self.RDB1 = ResidualDenseBlock_5C(nf, gc)
- self.RDB2 = ResidualDenseBlock_5C(nf, gc)
- self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+ # This is for backwards compatibility with existing models
+ if nr == 3:
+ self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus)
+ else:
+ RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
+ norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
+ gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
+ self.RDBs = nn.Sequential(*RDB_list)
def forward(self, x):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
+ if hasattr(self, 'RDB1'):
+ out = self.RDB1(x)
+ out = self.RDB2(out)
+ out = self.RDB3(out)
+ else:
+ out = self.RDBs(x)
return out * 0.2 + x
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, gc=32):
- super(RRDBNet, self).__init__()
- RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+class ResidualDenseBlock_5C(nn.Module):
+ """
+ Residual Dense Block
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+ Modified options that can be used:
+ - "Partial Convolution based Padding" arXiv:1811.11718
+ - "Spectral normalization" arXiv:1802.05957
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+ {Rakotonirina} and A. {Rasoanaivo}
+ """
+
+ def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
+ norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False, gaussian_noise=False, plus=False):
+ super(ResidualDenseBlock_5C, self).__init__()
- self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
- self.RRDB_trunk = make_layer(RRDB_block_f, nb)
- self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- #### upsampling
- self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
- self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+ self.noise = GaussianNoise() if gaussian_noise else None
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
+ if mode == 'CNA':
+ last_act = None
+ else:
+ last_act = act_type
+ self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
+ norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
+ spectral_norm=spectral_norm)
def forward(self, x):
- fea = self.conv_first(x)
- trunk = self.trunk_conv(self.RRDB_trunk(fea))
- fea = fea + trunk
+ x1 = self.conv1(x)
+ x2 = self.conv2(torch.cat((x, x1), 1))
+ if self.conv1x1:
+ x2 = x2 + self.conv1x1(x)
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+ if self.conv1x1:
+ x4 = x4 + x2
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ if self.noise:
+ return self.noise(x5.mul(0.2) + x)
+ else:
+ return x5 * 0.2 + x
+
- fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
- fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
- out = self.conv_last(self.lrelu(self.HRconv(fea)))
+####################
+# ESRGANplus
+####################
+class GaussianNoise(nn.Module):
+ def __init__(self, sigma=0.1, is_relative_detach=False):
+ super().__init__()
+ self.sigma = sigma
+ self.is_relative_detach = is_relative_detach
+ self.noise = torch.tensor(0, dtype=torch.float)
+
+ def forward(self, x):
+ if self.training and self.sigma != 0:
+ self.noise = self.noise.to(x.device)
+ scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+ sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+ x = x + sampled_noise
+ return x
+
+def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+####################
+# SRVGGNetCompact
+####################
+
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+ This class is copied from https://github.com/xinntao/Real-ESRGAN
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
return out
+
+
+####################
+# Upsampler
+####################
+
+class Upsample(nn.Module):
+ r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+ The input data is assumed to be of the form
+ `minibatch x channels x [optional depth] x [optional height] x width`.
+ """
+
+ def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+ super(Upsample, self).__init__()
+ if isinstance(scale_factor, tuple):
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
+ else:
+ self.scale_factor = float(scale_factor) if scale_factor else None
+ self.mode = mode
+ self.size = size
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+
+ def extra_repr(self):
+ if self.scale_factor is not None:
+ info = 'scale_factor=' + str(self.scale_factor)
+ else:
+ info = 'size=' + str(self.size)
+ info += ', mode=' + self.mode
+ return info
+
+
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
+ """
+ Pixel shuffle layer
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+ Neural Network, CVPR17)
+ """
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+ n = norm(norm_type, out_nc) if norm_type else None
+ a = act(act_type) if act_type else None
+ return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
+ """ Upconv layer """
+ upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
+ upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
+ return sequential(upsample, conv)
+
+
+
+
+
+
+
+
+####################
+# Basic blocks
+####################
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+ Args:
+ basic_block (nn.module): nn.module class for basic block. (block)
+ num_basic_block (int): number of blocks. (n_layers)
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+ """ activation helper """
+ act_type = act_type.lower()
+ if act_type == 'relu':
+ layer = nn.ReLU(inplace)
+ elif act_type in ('leakyrelu', 'lrelu'):
+ layer = nn.LeakyReLU(neg_slope, inplace)
+ elif act_type == 'prelu':
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+ elif act_type == 'tanh': # [-1, 1] range output
+ layer = nn.Tanh()
+ elif act_type == 'sigmoid': # [0, 1] range output
+ layer = nn.Sigmoid()
+ else:
+ raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+ return layer
+
+
+class Identity(nn.Module):
+ def __init__(self, *kwargs):
+ super(Identity, self).__init__()
+
+ def forward(self, x, *kwargs):
+ return x
+
+
+def norm(norm_type, nc):
+ """ Return a normalization layer """
+ norm_type = norm_type.lower()
+ if norm_type == 'batch':
+ layer = nn.BatchNorm2d(nc, affine=True)
+ elif norm_type == 'instance':
+ layer = nn.InstanceNorm2d(nc, affine=False)
+ elif norm_type == 'none':
+ def norm_layer(x): return Identity()
+ else:
+ raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+ return layer
+
+
+def pad(pad_type, padding):
+ """ padding layer helper """
+ pad_type = pad_type.lower()
+ if padding == 0:
+ return None
+ if pad_type == 'reflect':
+ layer = nn.ReflectionPad2d(padding)
+ elif pad_type == 'replicate':
+ layer = nn.ReplicationPad2d(padding)
+ elif pad_type == 'zero':
+ layer = nn.ZeroPad2d(padding)
+ else:
+ raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+ return layer
+
+
+def get_valid_padding(kernel_size, dilation):
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ padding = (kernel_size - 1) // 2
+ return padding
+
+
+class ShortcutBlock(nn.Module):
+ """ Elementwise sum the output of a submodule to its input """
+ def __init__(self, submodule):
+ super(ShortcutBlock, self).__init__()
+ self.sub = submodule
+
+ def forward(self, x):
+ output = x + self.sub(x)
+ return output
+
+ def __repr__(self):
+ return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
+
+
+def sequential(*args):
+ """ Flatten Sequential. It unwraps nn.Sequential. """
+ if len(args) == 1:
+ if isinstance(args[0], OrderedDict):
+ raise NotImplementedError('sequential does not support OrderedDict input.')
+ return args[0] # No sequential is needed.
+ modules = []
+ for module in args:
+ if isinstance(module, nn.Sequential):
+ for submodule in module.children():
+ modules.append(submodule)
+ elif isinstance(module, nn.Module):
+ modules.append(module)
+ return nn.Sequential(*modules)
+
+
+def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+ pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
+ spectral_norm=False):
+ """ Conv layer with padding, normalization, activation """
+ assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+ padding = get_valid_padding(kernel_size, dilation)
+ p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+ padding = padding if pad_type == 'zero' else 0
+
+ if convtype=='PartialConv2D':
+ c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='DeformConv2D':
+ c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ elif convtype=='Conv3D':
+ c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+ else:
+ c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+ dilation=dilation, bias=bias, groups=groups)
+
+ if spectral_norm:
+ c = nn.utils.spectral_norm(c)
+
+ a = act(act_type) if act_type else None
+ if 'CNA' in mode:
+ n = norm(norm_type, out_nc) if norm_type else None
+ return sequential(p, c, n, a)
+ elif mode == 'NAC':
+ if norm_type is None and act_type is not None:
+ a = act(act_type, inplace=False)
+ n = norm(norm_type, in_nc) if norm_type else None
+ return sequential(n, a, p, c)
diff --git a/modules/extras.py b/modules/extras.py
index b853fa5b..22c5a1c1 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
- image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+ image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
for img in image_list:
- image = Image.open(img)
+ try:
+ image = Image.open(img)
+ except Exception:
+ continue
imageArr.append(image)
imageNameArr.append(img)
else:
@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
-
- images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
- forced_filename=image_name if opts.use_original_name_batch else None)
+
+ if opts.use_original_name_batch and image_name != None:
+ basename = os.path.splitext(os.path.basename(image_name))[0]
+ else:
+ basename = ''
+
+ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index c27826b6..f73647da 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -4,13 +4,22 @@ import gradio as gr
from modules.shared import script_path
from modules import shared
-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
+def quote(text):
+ if ',' not in str(text):
+ return text
+
+ text = str(text)
+ text = text.replace('\\', '\\\\')
+ text = text.replace('"', '\\"')
+ return f'"{text}"'
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -45,11 +54,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
else:
prompt += ("" if prompt == "" else "\n") + line
- if len(prompt) > 0:
- res["Prompt"] = prompt
-
- if len(negative_prompt) > 0:
- res["Negative prompt"] = negative_prompt
+ res["Prompt"] = prompt
+ res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
m = re_imagesize.match(v)
@@ -86,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else:
try:
valtype = type(output.value)
- val = valtype(v)
+
+ if valtype == bool and v == "False":
+ val = False
+ else:
+ val = valtype(v)
+
res.append(gr.update(value=val))
except Exception:
res.append(gr.update())
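
To make the intent of the new regex and quote() concrete, a small standalone sketch (illustrative values only): parameter values containing commas can now be wrapped in escaped quotes via quote(), and the updated re_param pattern reads such a quoted value back as a single field instead of splitting it at the comma.

    import re

    re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)')

    def quote(text):
        if ',' not in str(text):
            return text
        text = str(text)
        text = text.replace('\\', '\\\\')
        text = text.replace('"', '\\"')
        return f'"{text}"'

    line = f'Steps: 20, Hypernet: {quote("some, name")}, Size: 512x512'
    print(re_param.findall(line))
    # [('Steps', '20'), ('Hypernet', '"some, name"'), ('Size', '512x512')]
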
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b8695fc1..d647ea55 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,46 +1,100 @@
+import csv
import datetime
import glob
import html
import os
import sys
import traceback
-import tqdm
-import csv
+import modules.textual_inversion.dataset
import torch
-
-from ldm.util import default
-from modules import devices, shared, processing, sd_models
-import torch
-from torch import einsum
+import tqdm
from einops import rearrange, repeat
-import modules.textual_inversion.dataset
+from ldm.util import default
+from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from torch import einsum
+from collections import defaultdict, deque
+from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
-
- def __init__(self, dim, state_dict=None):
+ activation_dict = {
+ "relu": torch.nn.ReLU,
+ "leakyrelu": torch.nn.LeakyReLU,
+ "elu": torch.nn.ELU,
+ "swish": torch.nn.Hardswish,
+ }
+
+ def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
super().__init__()
- self.linear1 = torch.nn.Linear(dim, dim * 2)
- self.linear2 = torch.nn.Linear(dim * 2, dim)
+ assert layer_structure is not None, "layer_structure must not be None"
+ assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
+ assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
+
+ linears = []
+ for i in range(len(layer_structure) - 1):
+
+ # Add a fully-connected layer
+ linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
+ # Add an activation func
+ if activation_func == "linear" or activation_func is None:
+ pass
+ elif activation_func in self.activation_dict:
+ linears.append(self.activation_dict[activation_func]())
+ else:
+ raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
+ # Add layer normalization
+ if add_layer_norm:
+ linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
+
+ # Add dropout expect last layer
+ if use_dropout and i < len(layer_structure) - 3:
+ linears.append(torch.nn.Dropout(p=0.3))
+
+ self.linear = torch.nn.Sequential(*linears)
if state_dict is not None:
- self.load_state_dict(state_dict, strict=True)
+ self.fix_old_state_dict(state_dict)
+ self.load_state_dict(state_dict)
else:
-
- self.linear1.weight.data.normal_(mean=0.0, std=0.01)
- self.linear1.bias.data.zero_()
- self.linear2.weight.data.normal_(mean=0.0, std=0.01)
- self.linear2.bias.data.zero_()
+ for layer in self.linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer.weight.data.normal_(mean=0.0, std=0.01)
+ layer.bias.data.zero_()
self.to(devices.device)
+ def fix_old_state_dict(self, state_dict):
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+
+ for fr, to in changes.items():
+ x = state_dict.get(fr, None)
+ if x is None:
+ continue
+
+ del state_dict[fr]
+ state_dict[to] = x
+
def forward(self, x):
- return x + (self.linear2(self.linear1(x))) * self.multiplier
+ return x + self.linear(x) * self.multiplier
+
+ def trainables(self):
+ layer_structure = []
+ for layer in self.linear:
+ if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
+ layer_structure += [layer.weight, layer.bias]
+ return layer_structure
def apply_strength(value=None):
@@ -51,16 +105,23 @@ class Hypernetwork:
filename = None
name = None
- def __init__(self, name=None, enable_sizes=None):
+ def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
self.step = 0
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.layer_structure = layer_structure
+ self.activation_func = activation_func
+ self.add_layer_norm = add_layer_norm
+ self.use_dropout = use_dropout
for size in enable_sizes or []:
- self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
+ self.layers[size] = (
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ )
def weights(self):
res = []
@@ -68,7 +129,7 @@ class Hypernetwork:
for k, layers in self.layers.items():
for layer in layers:
layer.train()
- res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]
+ res += layer.trainables()
return res
@@ -80,6 +141,10 @@ class Hypernetwork:
state_dict['step'] = self.step
state_dict['name'] = self.name
+ state_dict['layer_structure'] = self.layer_structure
+ state_dict['activation_func'] = self.activation_func
+ state_dict['is_layer_norm'] = self.add_layer_norm
+ state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -92,9 +157,17 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
+ self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
+ self.activation_func = state_dict.get('activation_func', None)
+ self.add_layer_norm = state_dict.get('is_layer_norm', False)
+ self.use_dropout = state_dict.get('use_dropout', False)
+
for size, sd in state_dict.items():
if type(size) == int:
- self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+ self.layers[size] = (
+ HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
+ )
self.name = state_dict.get('name', self.name)
self.step = state_dict.get('step', 0)
@@ -196,7 +269,39 @@ def stack_conds(conds):
return torch.stack(conds)
+
+def statistics(data):
+ if len(data) < 2:
+ std = 0
+ else:
+ std = stdev(data)
+ total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
+ recent_data = data[-32:]
+ if len(recent_data) < 2:
+ std = 0
+ else:
+ std = stdev(recent_data)
+ recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
+ return total_information, recent_information
+
+
+def report_statistics(loss_info:dict):
+ keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
+ for key in keys:
+ try:
+ print("Loss statistics for file " + key)
+ info, recent = statistics(list(loss_info[key]))
+ print(info)
+ print(recent)
+ except Exception as e:
+ print(e)
+
+
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
+ from modules import images
+
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
@@ -226,7 +331,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
-
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@@ -236,22 +340,34 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
for weight in weights:
weight.requires_grad = True
- losses = torch.zeros((32,))
+ size = len(ds.indexes)
+ loss_dict = defaultdict(lambda : deque(maxlen = 1024))
+ losses = torch.zeros((size,))
+ previous_mean_losses = [0]
+ previous_mean_loss = 0
+ print("Mean loss of {} elements".format(size))
last_saved_file = "<none>"
last_saved_image = "<none>"
+ forced_filename = "<none>"
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
+ steps_without_grad = 0
+
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
-
+ if len(loss_dict) > 0:
+ previous_mean_losses = [i[-1] for i in loss_dict.values()]
+ previous_mean_loss = mean(previous_mean_losses)
+
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
@@ -261,33 +377,52 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
-# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
+ # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x
del c
losses[hypernetwork.step % losses.shape[0]] = loss.item()
-
+ for entry in entries:
+ loss_dict[entry.filename].append(loss.item())
+
optimizer.zero_grad()
+ weights[0].grad = None
loss.backward()
+
+ if weights[0].grad is None:
+ steps_without_grad += 1
+ else:
+ steps_without_grad = 0
+ assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
optimizer.step()
- mean_loss = losses.mean()
- if torch.isnan(mean_loss):
+
+ if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
- pbar.set_description(f"loss: {mean_loss:.7f}")
+
+ if len(previous_mean_losses) > 1:
+ std = stdev(previous_mean_losses)
+ else:
+ std = 0
+ dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
+ pbar.set_description(dataset_loss_info)
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
+ # Before saving, change name to match current checkpoint.
+ hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{mean_loss:.7f}",
+ "loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
+ forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
@@ -323,27 +458,29 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
- image.save(last_saved_image)
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
shared.state.textinfo = f"""
<p>
-Loss: {mean_loss:.7f}<br/>
+Loss: {previous_mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
-Last saved embedding: {html.escape(last_saved_file)}<br/>
+Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
-
+
+ report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
+ hypernetwork.name = hypernetwork_name
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(filename)
return hypernetwork, filename
-
-
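
The progress-bar readout above replaces the running mean over the losses ring buffer with a per-dataset-image statistic: the mean of the last loss recorded for each image, plus its standard error. A worked example with illustrative values:

# previous_mean_losses = [0.10, 0.14, 0.12]   (last loss seen for each of 3 dataset images)
# mean = 0.120, sample stdev = 0.020, standard error = 0.020 / sqrt(3) ≈ 0.012
# pbar description becomes: "dataset loss:0.120±(0.012)"
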
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index dfa599af..2b472d87 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -1,19 +1,33 @@
import html
import os
+import re
import gradio as gr
-
-import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
-def create_hypernetwork(name, enable_sizes):
- fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+ # Remove illegal characters from name.
+    name = "".join(x for x in name if x.isalnum() or x in "._- ")
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ if type(layer_structure) == str:
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ activation_func=activation_func,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
+ )
hypernet.save(fn)
shared.reload_hypernetworks()
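
create_hypernetwork now takes the layer structure as a comma-separated string and forwards the new architecture options to the Hypernetwork constructor. A minimal sketch of a call with the new signature; the values are illustrative, not defaults:

# hypothetical invocation from the Train tab
create_hypernetwork(
    name="my-hypernet",
    enable_sizes=["320", "640", "768", "1280"],
    overwrite_old=False,
    layer_structure="1, 2, 1",      # parsed into [1.0, 2.0, 1.0] by the string branch above
    activation_func="relu",
    add_layer_norm=False,
    use_dropout=False,
)
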
diff --git a/modules/images.py b/modules/images.py
index b9589563..286de2ae 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,4 +1,8 @@
import datetime
+import sys
+import traceback
+
+import pytz
import io
import math
import os
@@ -12,7 +16,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
-from modules import sd_samplers, shared
+from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -273,10 +277,15 @@ invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
+re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)")
+re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
max_filename_part_length = 128
def sanitize_filename_part(text, replace_spaces=True):
+ if text is None:
+ return None
+
if replace_spaces:
text = text.replace(' ', '_')
@@ -286,49 +295,106 @@ def sanitize_filename_part(text, replace_spaces=True):
return text
-def apply_filename_pattern(x, p, seed, prompt):
- max_prompt_words = opts.directories_max_prompt_words
-
- if seed is not None:
- x = x.replace("[seed]", str(seed))
-
- if p is not None:
- x = x.replace("[steps]", str(p.steps))
- x = x.replace("[cfg]", str(p.cfg_scale))
- x = x.replace("[width]", str(p.width))
- x = x.replace("[height]", str(p.height))
- x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
- x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
-
- x = x.replace("[model_hash]", getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash))
- x = x.replace("[date]", datetime.date.today().isoformat())
- x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
- x = x.replace("[job_timestamp]", getattr(p, "job_timestamp", shared.state.job_timestamp))
+class FilenameGenerator:
+ replacements = {
+ 'seed': lambda self: self.seed if self.seed is not None else '',
+ 'steps': lambda self: self.p and self.p.steps,
+ 'cfg': lambda self: self.p and self.p.cfg_scale,
+ 'width': lambda self: self.p and self.p.width,
+ 'height': lambda self: self.p and self.p.height,
+ 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
+ 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
+ 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
+ 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
+ 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
+ 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+ 'prompt': lambda self: sanitize_filename_part(self.prompt),
+ 'prompt_no_styles': lambda self: self.prompt_no_style(),
+ 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
+ 'prompt_words': lambda self: self.prompt_words(),
+ }
+ default_time_format = '%Y%m%d%H%M%S'
+
+ def __init__(self, p, seed, prompt):
+ self.p = p
+ self.seed = seed
+ self.prompt = prompt
+
+ def prompt_no_style(self):
+ if self.p is None or self.prompt is None:
+ return None
+
+ prompt_no_style = self.prompt
+ for style in shared.prompt_styles.get_style_prompts(self.p.styles):
+ if len(style) > 0:
+ for part in style.split("{prompt}"):
+ prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
+
+ prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
+
+ return sanitize_filename_part(prompt_no_style, replace_spaces=False)
+
+ def prompt_words(self):
+ words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
+ if len(words) == 0:
+ words = ["empty"]
+ return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
+
+ def datetime(self, *args):
+ time_datetime = datetime.datetime.now()
+
+ time_format = args[0] if len(args) > 0 else self.default_time_format
+ try:
+ time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
+ except pytz.exceptions.UnknownTimeZoneError as _:
+ time_zone = None
+
+ time_zone_time = time_datetime.astimezone(time_zone)
+ try:
+ formatted_time = time_zone_time.strftime(time_format)
+ except (ValueError, TypeError) as _:
+ formatted_time = time_zone_time.strftime(self.default_time_format)
+
+ return sanitize_filename_part(formatted_time, replace_spaces=False)
+
+ def apply(self, x):
+ res = ''
+
+ for m in re_pattern.finditer(x):
+ text, pattern = m.groups()
+
+ if pattern is None:
+ res += text
+ continue
- # Apply [prompt] at last. Because it may contain any replacement word.^M
- if prompt is not None:
- x = x.replace("[prompt]", sanitize_filename_part(prompt))
- if "[prompt_no_styles]" in x:
- prompt_no_style = prompt
- for style in shared.prompt_styles.get_style_prompts(p.styles):
- if len(style) > 0:
- style_parts = [y for y in style.split("{prompt}")]
- for part in style_parts:
- prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
- prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
- x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
+ pattern_args = []
+ while True:
+ m = re_pattern_arg.match(pattern)
+ if m is None:
+ break
+
+ pattern, arg = m.groups()
+ pattern_args.insert(0, arg)
+
+ fun = self.replacements.get(pattern.lower())
+ if fun is not None:
+ try:
+ replacement = fun(self, *pattern_args)
+ except Exception:
+ replacement = None
+ print(f"Error adding [{pattern}] to filename", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ if replacement is None:
+ res += f'[{pattern}]'
+ else:
+ res += str(replacement)
- x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
- if "[prompt_words]" in x:
- words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
- if len(words) == 0:
- words = ["empty"]
- x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
+ continue
- if cmd_opts.hide_ui_dir_config:
- x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
+ res += f'[{pattern}]'
- return x
+ return res
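
FilenameGenerator.apply scans the pattern with re_pattern, peels optional <arg> suffixes with re_pattern_arg, and looks the tag up in the replacements table; unknown tags are left in the output as [tag]. A brief usage sketch (the datetime output depends on the current time, so the results shown are illustrative):

namegen = FilenameGenerator(p=None, seed=1234, prompt="a photo of a cat")
namegen.apply("[seed]-[prompt_spaces]")             # -> "1234-a photo of a cat"
namegen.apply("[datetime<%Y-%m-%d><Asia/Tokyo>]")   # current date formatted in the given time zone
namegen.apply("[not_a_tag]")                        # unknown tags pass through: "[not_a_tag]"
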
def get_next_sequence_number(path, basename):
@@ -354,7 +420,7 @@ def get_next_sequence_number(path, basename):
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
- '''Save an image.
+ """Save an image.
Args:
image (`PIL.Image`):
@@ -385,18 +451,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
The full path of the saved imaged.
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
- '''
- if short_filename or prompt is None or seed is None:
- file_decoration = ""
- elif opts.save_to_dirs:
- file_decoration = opts.samples_filename_pattern or "[seed]"
- else:
- file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
-
- if file_decoration != "":
- file_decoration = "-" + file_decoration.lower()
-
- file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix
+ """
+ namegen = FilenameGenerator(p, seed, prompt)
if extension == 'png' and opts.enable_pnginfo and info is not None:
pnginfo = PngImagePlugin.PngInfo()
@@ -413,21 +469,39 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs:
- dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
+ dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
path = os.path.join(path, dirname)
os.makedirs(path, exist_ok=True)
if forced_filename is None:
- basecount = get_next_sequence_number(path, basename)
- fullfn = "a.png"
- fullfn_without_extension = "a"
- for i in range(500):
- fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
- fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
- fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
- if not os.path.exists(fullfn):
- break
+ if short_filename or seed is None:
+ file_decoration = ""
+ elif opts.save_to_dirs:
+ file_decoration = opts.samples_filename_pattern or "[seed]"
+ else:
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
+
+ add_number = opts.save_images_add_number or file_decoration == ''
+
+ if file_decoration != "" and add_number:
+ file_decoration = "-" + file_decoration
+
+ file_decoration = namegen.apply(file_decoration) + suffix
+
+ if add_number:
+ basecount = get_next_sequence_number(path, basename)
+ fullfn = None
+ fullfn_without_extension = None
+ for i in range(500):
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
+ fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
+ if not os.path.exists(fullfn):
+ break
+ else:
+ fullfn = os.path.join(path, f"{file_decoration}.{extension}")
+ fullfn_without_extension = os.path.join(path, file_decoration)
else:
fullfn = os.path.join(path, f"{forced_filename}.{extension}")
fullfn_without_extension = os.path.join(path, forced_filename)
@@ -467,6 +541,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
txt_fullfn = None
+ script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn)
return fullfn, txt_fullfn
diff --git a/modules/images_history.py b/modules/images_history.py
deleted file mode 100644
index 46b23e56..00000000
--- a/modules/images_history.py
+++ /dev/null
@@ -1,183 +0,0 @@
-import os
-import shutil
-import sys
-
-def traverse_all_files(output_dir, image_list, curr_dir=None):
- curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
- try:
- f_list = os.listdir(curr_path)
- except:
- if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
- image_list.append(curr_dir)
- return image_list
- for file in f_list:
- file = file if curr_dir is None else os.path.join(curr_dir, file)
- file_path = os.path.join(curr_path, file)
- if file[-4:] == ".txt":
- pass
- elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
- image_list.append(file)
- else:
- image_list = traverse_all_files(output_dir, image_list, file)
- return image_list
-
-
-def get_recent_images(dir_name, page_index, step, image_index, tabname):
- page_index = int(page_index)
- image_list = []
- if not os.path.exists(dir_name):
- pass
- elif os.path.isdir(dir_name):
- image_list = traverse_all_files(dir_name, image_list)
- image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
- else:
- print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
- num = 48 if tabname != "extras" else 12
- max_page_index = len(image_list) // num + 1
- page_index = max_page_index if page_index == -1 else page_index + step
- page_index = 1 if page_index < 1 else page_index
- page_index = max_page_index if page_index > max_page_index else page_index
- idx_frm = (page_index - 1) * num
- image_list = image_list[idx_frm:idx_frm + num]
- image_index = int(image_index)
- if image_index < 0 or image_index > len(image_list) - 1:
- current_file = None
- hidden = None
- else:
- current_file = image_list[int(image_index)]
- hidden = os.path.join(dir_name, current_file)
- return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
-
-
-def first_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, 1, 0, image_index, tabname)
-
-
-def end_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, -1, 0, image_index, tabname)
-
-
-def prev_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, -1, image_index, tabname)
-
-
-def next_page_click(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, 1, image_index, tabname)
-
-
-def page_index_change(dir_name, page_index, image_index, tabname):
- return get_recent_images(dir_name, page_index, 0, image_index, tabname)
-
-
-def show_image_info(num, image_path, filenames):
- # print(f"select image {num}")
- file = filenames[int(num)]
- return file, num, os.path.join(image_path, file)
-
-
-def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
- if name == "":
- return filenames, delete_num
- else:
- delete_num = int(delete_num)
- index = list(filenames).index(name)
- i = 0
- new_file_list = []
- for name in filenames:
- if i >= index and i < index + delete_num:
- path = os.path.join(dir_name, name)
- if os.path.exists(path):
- print(f"Delete file {path}")
- os.remove(path)
- txt_file = os.path.splitext(path)[0] + ".txt"
- if os.path.exists(txt_file):
- os.remove(txt_file)
- else:
- print(f"Not exists file {path}")
- else:
- new_file_list.append(name)
- i += 1
- return new_file_list, 1
-
-
-def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
- if opts.outdir_samples != "":
- dir_name = opts.outdir_samples
- elif tabname == "txt2img":
- dir_name = opts.outdir_txt2img_samples
- elif tabname == "img2img":
- dir_name = opts.outdir_img2img_samples
- elif tabname == "extras":
- dir_name = opts.outdir_extras_samples
- else:
- return
- with gr.Row():
- renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
- first_page = gr.Button('First Page')
- prev_page = gr.Button('Prev Page')
- page_index = gr.Number(value=1, label="Page Index")
- next_page = gr.Button('Next Page')
- end_page = gr.Button('End Page')
- with gr.Row(elem_id=tabname + "_images_history"):
- with gr.Row():
- with gr.Column(scale=2):
- history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
- with gr.Row():
- delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
- delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
- with gr.Column():
- with gr.Row():
- pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
- pnginfo_send_to_img2img = gr.Button('Send to img2img')
- with gr.Row():
- with gr.Column():
- img_file_info = gr.Textbox(label="Generate Info", interactive=False)
- img_file_name = gr.Textbox(label="File Name", interactive=False)
- with gr.Row():
- # hiden items
-
- img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
- tabname_box = gr.Textbox(tabname, visible=False)
- image_index = gr.Textbox(value=-1, visible=False)
- set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
- filenames = gr.State()
- hidden = gr.Image(type="pil", visible=False)
- info1 = gr.Textbox(visible=False)
- info2 = gr.Textbox(visible=False)
-
- # turn pages
- gallery_inputs = [img_path, page_index, image_index, tabname_box]
- gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
-
- first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
- # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
-
- # other funcitons
- set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
- img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
- delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
- hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
-
- # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
- switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
- switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
-
-
-def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
- with gr.Blocks(analytics_enabled=False) as images_history:
- with gr.Tabs() as tabs:
- with gr.Tab("txt2img history"):
- with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
- show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
- with gr.Tab("img2img history"):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
- show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
- with gr.Tab("extras history"):
- with gr.Blocks(analytics_enabled=False) as images_history_img2img:
- show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
- return images_history
diff --git a/modules/img2img.py b/modules/img2img.py
index 24126774..8d9f7cf9 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -109,6 +109,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
+    p.scripts = modules.scripts.scripts_img2img
+ p.script_args = args
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 64b91eb4..65b05d34 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -28,9 +28,11 @@ class InterrogateModels:
clip_preprocess = None
categories = None
dtype = None
+ running_on_cpu = None
def __init__(self, content_dir):
self.categories = []
+ self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
@@ -53,7 +55,11 @@ class InterrogateModels:
def load_clip_model(self):
import clip
- model, preprocess = clip.load(clip_model_name)
+ if self.running_on_cpu:
+ model, preprocess = clip.load(clip_model_name, device="cpu")
+ else:
+ model, preprocess = clip.load(clip_model_name)
+
model.eval()
model = model.to(devices.device_interrogate)
@@ -62,14 +68,14 @@ class InterrogateModels:
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
- if not shared.cmd_opts.no_half:
+ if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(devices.device_interrogate)
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 7eba1349..f327c3df 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,9 +1,8 @@
import torch
-from modules.devices import get_optimal_device
+from modules import devices
module_in_gpu = None
cpu = torch.device("cpu")
-device = gpu = get_optimal_device()
def send_everything_to_cpu():
@@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
if module_in_gpu is not None:
module_in_gpu.to(cpu)
- module.to(gpu)
+ module.to(devices.device)
module_in_gpu = module
# see below for register_forward_pre_hook;
@@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
- sd_model.to(device)
+ sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
@@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
- sd_model.model.to(device)
+ sd_model.model.to(devices.device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff --git a/modules/processing.py b/modules/processing.py
index ea926fc3..c61bbfbd 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -12,7 +12,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -47,6 +47,25 @@ def apply_color_correction(correction, image):
return image
+def apply_overlay(image, paste_loc, index, overlays):
+ if overlays is None or index >= len(overlays):
+ return image
+
+ overlay = overlays[index]
+
+ if paste_loc is not None:
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ image = image.convert('RGBA')
+ image.alpha_composite(overlay)
+ image = image.convert('RGB')
+
+ return image
+
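
apply_overlay factors the overlay/paste logic out of the generation loop so the copy saved before color correction gets the same treatment as the final image. A hypothetical call, pasting an inpaint-at-full-resolution crop back at (x=128, y=64) with size 512x512:

image = apply_overlay(image, paste_loc=(128, 64, 512, 512), index=i, overlays=p.overlay_images)
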
def get_correct_sampler(p):
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
return sd_samplers.samplers
@@ -104,6 +123,12 @@ class StableDiffusionProcessing():
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
+ self.scripts = None
+ self.script_args = None
+ self.all_prompts = None
+ self.all_seeds = None
+ self.all_subseeds = None
+
def init(self, all_prompts, all_seeds, all_subseeds):
pass
@@ -304,7 +329,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
+ "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -318,7 +343,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
- generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
@@ -350,32 +375,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
- all_prompts = p.prompt
+ p.all_prompts = p.prompt
else:
- all_prompts = p.batch_size * p.n_iter * [p.prompt]
+ p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(seed) == list:
- all_seeds = seed
+ p.all_seeds = seed
else:
- all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
+ p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
if type(subseed) == list:
- all_subseeds = subseed
+ p.all_subseeds = subseed
else:
- all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
+ p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ if p.scripts is not None:
+ p.scripts.run_alwayson_scripts(p)
+
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
- p.init(all_prompts, all_seeds, all_subseeds)
+ p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
@@ -387,15 +415,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
- prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0):
break
- #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- #c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
@@ -442,22 +468,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
- if p.overlay_images is not None and i < len(p.overlay_images):
- overlay = p.overlay_images[i]
-
- if p.paste_to is not None:
- x, y, w, h = p.paste_to
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
-
- image = image.convert('RGBA')
- image.alpha_composite(overlay)
- image = image.convert('RGB')
+ image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
@@ -490,10 +505,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+ return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@@ -515,6 +530,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
state.job_count = state.job_count * 2
+ self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+
if self.firstphase_width == 0 or self.firstphase_height == 0:
desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height
@@ -536,21 +553,40 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
+ def create_dummy_mask(self, x, width=None, height=None):
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ height = height or self.height
+ width = width or self.width
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
+
+                # Add the fake full-1s mask as an extra (first) channel.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ else:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+                # Pretty sure we can just make this a 1x1 image since it's not going to be used except for its batch size.
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ return image_conditioning
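
create_dummy_mask keeps the sampler interface uniform: inpainting (hybrid/concat) checkpoints get an all-ones mask channel stacked on the latent of an all-zeros image, while other checkpoints get a placeholder that only carries the batch size. Assumed shapes for a 512x512 generation with batch size 2 and 8x latent downsampling:

# inpainting model:  x is [2, 4, 64, 64],  image_conditioning is [2, 5, 64, 64]
#                    (1 mask channel of ones + 4 channels of the encoded zero image)
# standard model:    image_conditioning is [2, 5, 1, 1] zeros; only the batch dimension matters
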
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
@@ -587,7 +623,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples
@@ -595,7 +631,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
+ def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -613,6 +649,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None
self.nmask = None
+ self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
@@ -685,6 +722,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
+ self.color_corrections = self.color_corrections * self.batch_size
+
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
@@ -714,10 +755,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
+ if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+ if self.image_mask is not None:
+ conditioning_mask = np.array(self.image_mask.convert("L"))
+ conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
+ conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
+
+ # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+ else:
+ conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
+
+ # Create another latent image, this time with a masked version of the original input.
+ conditioning_mask = conditioning_mask.to(image.device)
+ conditioning_image = image * (1.0 - conditioning_mask)
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+
+ # Create the concatenated conditioning tensor to be fed to `c_concat`
+ conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
+ conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
+ self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
+ self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
+ else:
+ self.image_conditioning = torch.zeros(
+ self.init_latent.shape[0], 5, 1, 1,
+ dtype=self.init_latent.dtype,
+ device=self.init_latent.device
+ )
+
+
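
For an actual img2img/inpaint run the conditioning is built analogously, but from the real mask and source image. A sketch of the channel layout of self.image_conditioning for the inpainting model, as constructed above:

# channel 0     : image_mask downsampled to latent size and rounded to {0, 1}
#                 (1 in the region marked for inpainting, following the webui mask convention)
# channels 1..4 : first-stage latent of image * (1 - mask), i.e. the source image with the
#                 masked region zeroed out
# the whole tensor is later passed to the sampler as image_conditioning (c_concat)
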
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
new file mode 100644
index 00000000..dc520abc
--- /dev/null
+++ b/modules/script_callbacks.py
@@ -0,0 +1,100 @@
+import sys
+import traceback
+from collections import namedtuple
+import inspect
+
+
+def report_exception(c, job):
+ print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callbacks_model_loaded = []
+callbacks_ui_tabs = []
+callbacks_ui_settings = []
+callbacks_image_saved = []
+
+def clear_callbacks():
+ callbacks_model_loaded.clear()
+ callbacks_ui_tabs.clear()
+ callbacks_image_saved.clear()
+
+
+def model_loaded_callback(sd_model):
+ for c in callbacks_model_loaded:
+ try:
+ c.callback(sd_model)
+ except Exception:
+ report_exception(c, 'model_loaded_callback')
+
+
+def ui_tabs_callback():
+ res = []
+
+ for c in callbacks_ui_tabs:
+ try:
+ res += c.callback() or []
+ except Exception:
+ report_exception(c, 'ui_tabs_callback')
+
+ return res
+
+
+def ui_settings_callback():
+ for c in callbacks_ui_settings:
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'ui_settings_callback')
+
+
+def image_saved_callback(image, p, fullfn, txt_fullfn):
+ for c in callbacks_image_saved:
+ try:
+ c.callback(image, p, fullfn, txt_fullfn)
+ except Exception:
+ report_exception(c, 'image_saved_callback')
+
+
+def add_callback(callbacks, fun):
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+
+ callbacks.append(ScriptCallback(filename, fun))
+
+
+
+def on_model_loaded(callback):
+ """register a function to be called when the stable diffusion model is created; the model is
+ passed as an argument"""
+ add_callback(callbacks_model_loaded, callback)
+
+
+def on_ui_tabs(callback):
+ """register a function to be called when the UI is creating new tabs.
+	The function must either return None, which means no new tabs are to be added, or a list, where
+ each element is a tuple:
+ (gradio_component, title, elem_id)
+
+ gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
+ title is tab text displayed to user in the UI
+ elem_id is HTML id for the tab
+ """
+ add_callback(callbacks_ui_tabs, callback)
+
+
+def on_ui_settings(callback):
+ """register a function to be called before UI settings are populated; add your settings
+ by using shared.opts.add_option(shared.OptionInfo(...)) """
+ add_callback(callbacks_ui_settings, callback)
+
+
+def on_save_imaged(callback):
+ """register a function to be called after modules.images.save_image is called.
+	The callback is called with four arguments:
+	- image - the PIL image that was saved
+	- p - processing object (or a dummy object with same fields if the image is saved using the save button)
+	- fullfn - image filename
+	- txt_fullfn - text file with parameters; may be None
+ """
+ add_callback(callbacks_image_saved, callback)
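
These callbacks give extensions a stable registration surface; registrations are cleared and re-created on every load_scripts run via clear_callbacks. A minimal sketch of an extension hooking the image-saved callback; the file name and message are illustrative:

# extensions/my-extension/scripts/log_saves.py (hypothetical)
from modules import script_callbacks

def log_saved_image(image, p, fullfn, txt_fullfn):
    # called after modules.images.save_image writes the file
    print(f"saved {fullfn} ({image.width}x{image.height})")

script_callbacks.on_save_imaged(log_saved_image)
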
diff --git a/modules/scripts.py b/modules/scripts.py
index 1039fa9c..9323af3e 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,86 +1,175 @@
import os
import sys
import traceback
+from collections import namedtuple
import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared
+from modules import shared, paths, script_callbacks
+
+AlwaysVisible = object()
+
class Script:
filename = None
args_from = None
args_to = None
+ alwayson = False
+
+ infotext_fields = None
+ """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
+ parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
+ """
- # The title of the script. This is what will be displayed in the dropdown menu.
def title(self):
+ """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
+
raise NotImplementedError()
- # How the script is displayed in the UI. See https://gradio.app/docs/#components
- # for the different UI components you can use and how to create them.
- # Most UI components can return a value, such as a boolean for a checkbox.
- # The returned values are passed to the run method as parameters.
def ui(self, is_img2img):
+ """this function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be an array of all components that are used in processing.
+        Values of those returned components will be passed to the run() and process() functions.
+ """
+
pass
- # Determines when the script should be shown in the dropdown menu via the
- # returned value. As an example:
- # is_img2img is True if the current tab is img2img, and False if it is txt2img.
- # Thus, return is_img2img to only show the script on the img2img tab.
def show(self, is_img2img):
+ """
+        is_img2img is True if this function is called for the img2img interface, and False otherwise
+
+ This function should return:
+ - False if the script should not be shown in UI at all
+         - True if the script should be shown in UI if it is selected in the scripts dropdown
+ - script.AlwaysVisible if the script should be shown in UI at all times
+ """
+
return True
- # This is where the additional processing is implemented. The parameters include
- # self, the model object "p" (a StableDiffusionProcessing class, see
- # processing.py), and the parameters returned by the ui method.
- # Custom functions can be defined here, and additional libraries can be imported
- # to be used in processing. The return value should be a Processed object, which is
- # what is returned by the process_images method.
- def run(self, *args):
+ def run(self, p, *args):
+ """
+ This function is called if the script has been selected in the script dropdown.
+ It must do all processing and return the Processed object with results, same as
+ one returned by processing.process_images.
+
+ Usually the processing is done by calling the processing.process_images function.
+
+ args contains all values returned by components from ui()
+ """
+
raise NotImplementedError()
- # The description method is currently unused.
- # To add a description that appears when hovering over the title, amend the "titles"
- # dict in script.js to include the script title (returned by title) as a key, and
- # your description as the value.
+ def process(self, p, *args):
+ """
+        This function is called before processing begins for AlwaysVisible scripts.
+        You can modify the processing object (p) here, inject hooks, etc.
+ """
+
+ pass
+
def describe(self):
+ """unused"""
return ""
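
With show() able to return AlwaysVisible, an extension can ship a script whose process() hook runs for every generation without being chosen from the dropdown. A minimal sketch under these conventions (the checkbox and the extra infotext key are illustrative):

import gradio as gr
from modules import scripts

class ExampleAlwaysOn(scripts.Script):
    def title(self):
        return "Example always-on script"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        enabled = gr.Checkbox(label="Tag generations from this script", value=False)
        return [enabled]

    def process(self, p, enabled):
        # process() runs before sampling, so changes made to p here affect the whole batch
        if enabled:
            p.extra_generation_params["Example always-on"] = "enabled"
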
+current_basedir = paths.script_path
+
+
+def basedir():
+ """returns the base directory for the current script. For scripts in the main scripts directory,
+ this is the main directory (where webui.py resides), and for scripts in extensions directory
+    (i.e. extensions/aesthetic/scripts/aesthetic.py), this is the extension's directory (extensions/aesthetic)
+ """
+ return current_basedir
+
+
scripts_data = []
+ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
-def load_scripts(basedir):
- if not os.path.exists(basedir):
- return
+def list_scripts(scriptdirname, extension):
+ scripts_list = []
- for filename in sorted(os.listdir(basedir)):
- path = os.path.join(basedir, filename)
+ basedir = os.path.join(paths.script_path, scriptdirname)
+ if os.path.exists(basedir):
+ for filename in sorted(os.listdir(basedir)):
+ scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
- if os.path.splitext(path)[1].lower() != '.py':
- continue
+ extdir = os.path.join(paths.script_path, "extensions")
+ if os.path.exists(extdir):
+ for dirname in sorted(os.listdir(extdir)):
+ dirpath = os.path.join(extdir, dirname)
+ scriptdirpath = os.path.join(dirpath, scriptdirname)
+
+ if not os.path.isdir(scriptdirpath):
+ continue
+
+ for filename in sorted(os.listdir(scriptdirpath)):
+ scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
+
+ scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
+
+ return scripts_list
+
+
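
list_scripts now collects scripts both from the main scripts directory and from each extension's own scripts subdirectory, recording the basedir so extension-relative imports resolve. An assumed on-disk layout (names are illustrative):

# <webui root>/scripts/xy_grid.py                          -> ScriptFile(basedir=<webui root>, ...)
# <webui root>/extensions/aesthetic/scripts/aesthetic.py   -> ScriptFile(basedir=.../extensions/aesthetic, ...)
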
+def list_files_with_name(filename):
+ res = []
+
+ dirs = [paths.script_path]
+
+ extdir = os.path.join(paths.script_path, "extensions")
+ if os.path.exists(extdir):
+ dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
- if not os.path.isfile(path):
+ for dirpath in dirs:
+ if not os.path.isdir(dirpath):
continue
+ path = os.path.join(dirpath, filename)
+        if os.path.isfile(path):
+ res.append(path)
+
+ return res
+
+
+def load_scripts():
+ global current_basedir
+ scripts_data.clear()
+ script_callbacks.clear_callbacks()
+
+ scripts_list = list_scripts("scripts", ".py")
+
+ syspath = sys.path
+
+ for scriptfile in sorted(scripts_list):
try:
- with open(path, "r", encoding="utf8") as file:
+ if scriptfile.basedir != paths.script_path:
+ sys.path = [scriptfile.basedir] + sys.path
+ current_basedir = scriptfile.basedir
+
+ with open(scriptfile.path, "r", encoding="utf8") as file:
text = file.read()
from types import ModuleType
- compiled = compile(text, path, 'exec')
- module = ModuleType(filename)
+ compiled = compile(text, scriptfile.path, 'exec')
+ module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append((script_class, path))
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
except Exception:
- print(f"Error loading script: {filename}", file=sys.stderr)
+ print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ finally:
+ sys.path = syspath
+ current_basedir = paths.script_path
+
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
@@ -96,56 +185,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner:
def __init__(self):
self.scripts = []
+ self.selectable_scripts = []
+ self.alwayson_scripts = []
self.titles = []
+ self.infotext_fields = []
def setup_ui(self, is_img2img):
- for script_class, path in scripts_data:
+ for script_class, path, basedir in scripts_data:
script = script_class()
script.filename = path
- if not script.show(is_img2img):
- continue
+ visibility = script.show(is_img2img)
- self.scripts.append(script)
+ if visibility == AlwaysVisible:
+ self.scripts.append(script)
+ self.alwayson_scripts.append(script)
+ script.alwayson = True
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+ elif visibility:
+ self.scripts.append(script)
+ self.selectable_scripts.append(script)
- dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
- dropdown.save_to_config = True
- inputs = [dropdown]
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+
+ inputs = [None]
+ inputs_alwayson = [True]
- for script in self.scripts:
+ def create_script_ui(script, inputs, inputs_alwayson):
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
if controls is None:
- continue
+ return
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
- control.visible = False
+ if not script.alwayson:
+ control.visible = False
+
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
inputs += controls
+ inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs)
+ for script in self.alwayson_scripts:
+ with gr.Group():
+ create_script_ui(script, inputs, inputs_alwayson)
+
+ dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
+ dropdown.save_to_config = True
+ inputs[0] = dropdown
+
+ for script in self.selectable_scripts:
+ create_script_ui(script, inputs, inputs_alwayson)
+
def select_script(script_index):
- if 0 < script_index <= len(self.scripts):
- script = self.scripts[script_index-1]
+ if 0 < script_index <= len(self.selectable_scripts):
+ script = self.selectable_scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
+ return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
def init_field(title):
if title == 'None':
return
script_index = self.titles.index(title)
- script = self.scripts[script_index]
+ script = self.selectable_scripts[script_index]
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
@@ -164,7 +277,7 @@ class ScriptRunner:
if script_index == 0:
return None
- script = self.scripts[script_index-1]
+ script = self.selectable_scripts[script_index-1]
if script is None:
return None
@@ -176,7 +289,16 @@ class ScriptRunner:
return processed
- def reload_sources(self):
+ def run_alwayson_scripts(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.process(p, *script_args)
+ except Exception:
+ print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
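
run_alwayson_scripts relies on the args_from/args_to indices recorded by setup_ui, so one flat script_args list serves both the dropdown script and every always-on script. An assumed layout with one always-on script (2 controls) and one selectable script (3 controls):

# p.script_args[0]    -> script dropdown selection
# p.script_args[1:3]  -> always-on script   (args_from=1, args_to=3) -> passed to its process()
# p.script_args[3:6]  -> selectable script  (args_from=3, args_to=6) -> passed to its run()
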
+ def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
@@ -186,9 +308,12 @@ class ScriptRunner:
from types import ModuleType
- compiled = compile(text, filename, 'exec')
- module = ModuleType(script.filename)
- exec(compiled, module.__dict__)
+ module = cache.get(filename, None)
+ if module is None:
+ compiled = compile(text, filename, 'exec')
+ module = ModuleType(script.filename)
+ exec(compiled, module.__dict__)
+ cache[filename] = module
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
@@ -197,19 +322,22 @@ class ScriptRunner:
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
+
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
def reload_script_body_only():
- scripts_txt2img.reload_sources()
- scripts_img2img.reload_sources()
+ cache = {}
+ scripts_txt2img.reload_sources(cache)
+ scripts_img2img.reload_sources(cache)
-def reload_scripts(basedir):
+def reload_scripts():
global scripts_txt2img, scripts_img2img
- scripts_data.clear()
- load_scripts(basedir)
+ load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
index 36a996bf..59532274 100644
--- a/modules/scunet_model.py
+++ b/modules/scunet_model.py
@@ -54,9 +54,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), device)
- img = img.to(device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 984b35c4..0f10828e 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
def apply_optimizations():
undo_optimizations()
@@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
-
+
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
-
+
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
+
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
@@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
+
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
+
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
@@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
-
+
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
@@ -333,19 +333,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
+
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
-
+
return z
-
-
+
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
-
+
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@@ -385,8 +384,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
- emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
- tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+ emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+ tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)
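Whitespace cleanup aside, the functional line in the EmbeddingsWithFixes hunk is the splice of a textual-inversion embedding into the per-token embedding tensor. Below is a self-contained sketch of that slice arithmetic with made-up shapes; it is an illustration, not the module's code.

import torch

tensor = torch.zeros(77, 768)     # per-token embeddings for one prompt (placeholder values)
emb = torch.ones(4, 768)          # a 4-vector embedding to splice in
offset = 5                        # index of the placeholder token

# Clip the embedding so it never runs past the end of the 77-token sequence,
# then rebuild the tensor around it.
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

assert tensor.shape == (77, 768)  # the sequence length is preserved
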
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
new file mode 100644
index 00000000..fd92a335
--- /dev/null
+++ b/modules/sd_hijack_inpainting.py
@@ -0,0 +1,331 @@
+import torch
+
+from einops import repeat
+from omegaconf import ListConfig
+
+import ldm.models.diffusion.ddpm
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+
+# =================================================================================================
+# Monkey patch DDIMSampler methods from RunwayML repo directly.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
+# =================================================================================================
+@torch.no_grad()
+def sample_ddim(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+@torch.no_grad()
+def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+
+# =================================================================================================
+# Monkey patch PLMSSampler methods.
+# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
+# Adapted from:
+# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
+# =================================================================================================
+@torch.no_grad()
+def sample_plms(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+
+@torch.no_grad()
+def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [
+ torch.cat([unconditional_conditioning[k][i], c[k][i]])
+ for i in range(len(c[k]))
+ ]
+ else:
+ c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
+
+# =================================================================================================
+# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
+# Adapted from:
+# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
+# =================================================================================================
+
+@torch.no_grad()
+def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ # todo: get null label from cond_stage_model
+ raise NotImplementedError()
+ c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
+ return c
+
+
+class LatentInpaintDiffusion(LatentDiffusion):
+ def __init__(
+ self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+ self.concat_keys = concat_keys
+
+
+def should_hijack_inpainting(checkpoint_info):
+ return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+
+
+def do_inpainting_hijack():
+ ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
+ ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
+
+ ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
+ ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+
+ ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+ ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms \ No newline at end of file
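One detail of the new samplers worth calling out: when conditioning arrives as a dict (the inpainting case), the classifier-free-guidance branch concatenates the unconditional and conditional tensors key by key so a single UNet call covers both halves. The snippet below reproduces just that batching step with placeholder tensors; the shapes are illustrative, not taken from the model.

import torch

c = {"c_crossattn": [torch.randn(1, 77, 768)], "c_concat": [torch.randn(1, 5, 64, 64)]}
uc = {"c_crossattn": [torch.randn(1, 77, 768)], "c_concat": [c["c_concat"][0]]}

# Mirror of the per-key loop in p_sample_ddim / p_sample_plms above.
c_in = {k: [torch.cat([uc[k][i], c[k][i]]) for i in range(len(c[k]))] for k in c}

assert c_in["c_crossattn"][0].shape[0] == 2   # uncond and cond stacked on the batch axis
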
diff --git a/modules/sd_models.py b/modules/sd_models.py
index eae22e87..e697bb72 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,8 +7,9 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices
+from modules import shared, modelloader, devices, script_callbacks
from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
- from transformers import logging
+ from transformers import logging, CLIPModel
logging.set_verbosity_error()
except Exception:
@@ -154,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
@@ -185,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
@@ -203,14 +207,26 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info
-def load_model():
+def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
- checkpoint_info = select_checkpoint()
+ checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
+
+ if should_hijack_inpainting(checkpoint_info):
+ # Hardcoded config for now...
+ sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+ sd_config.model.params.use_ema = False
+ sd_config.model.params.conditioning_key = "hybrid"
+ sd_config.model.params.unet_config.params.in_channels = 9
+
+ # Create a "fake" config with a different name so that we know to unload it when switching models.
+ checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+
+ do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -222,6 +238,9 @@ def load_model():
sd_hijack.model_hijack.hijack(sd_model)
sd_model.eval()
+ shared.sd_model = sd_model
+
+ script_callbacks.model_loaded_callback(sd_model)
print(f"Model loaded.")
return sd_model
@@ -234,9 +253,9 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+ if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
- shared.sd_model = load_model()
+ load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -249,6 +268,7 @@ def reload_model_weights(sd_model, info=None):
load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
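The detection rule wired into load_model() and reload_model_weights() above is small enough to show in isolation: a checkpoint whose filename ends in inpainting.ckpt is hijacked unless its config already ends in inpainting.yaml, and the config name is then rewritten so a later model switch unloads it. The sketch below uses a stand-in namedtuple; the real CheckpointInfo is defined elsewhere in sd_models.py and is not shown in this diff.

from collections import namedtuple

# Stand-in for the real CheckpointInfo; only the two fields used here are modelled.
CheckpointInfo = namedtuple("CheckpointInfo", ["filename", "config"])

def should_hijack_inpainting(checkpoint_info):
    return (str(checkpoint_info.filename).endswith("inpainting.ckpt")
            and not checkpoint_info.config.endswith("inpainting.yaml"))

info = CheckpointInfo("sd-v1-5-inpainting.ckpt", "configs/v1-inference.yaml")
print(should_hijack_inpainting(info))    # True: filename matches, config does not
print(info._replace(config=info.config.replace(".yaml", "-inpainting.yaml")).config)
# configs/v1-inference-inpainting.yaml
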
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index b58e810b..3670b57d 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing
+from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -71,6 +71,7 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
@@ -82,14 +83,22 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-def sample_to_image(samples):
- x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
+def single_sample_to_image(sample):
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
+def sample_to_image(samples):
+ return single_sample_to_image(samples[0])
+
+
+def samples_to_image_grid(samples):
+ return images.image_grid([single_sample_to_image(sample) for sample in samples])
+
+
def store_latent(decoded):
state.current_latent = decoded
@@ -117,6 +126,8 @@ class VanillaStableDiffusionSampler:
self.config = None
self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
def number_of_needed_noises(self, p):
return 0
@@ -136,6 +147,12 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
+ image_conditioning = None
+ if isinstance(cond, dict):
+ image_conditioning = cond["c_concat"][0]
+ cond = cond["c_crossattn"][0]
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
@@ -157,6 +174,12 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+ # Note that they need to be lists because it just concatenates them later.
+ if image_conditioning is not None:
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
@@ -182,7 +205,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p)
@@ -196,20 +219,33 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
+ self.last_latent = x
self.step = 0
- samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
+ self.last_latent = x
self.step = 0
steps = steps or p.steps
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
# existing code fails with certain step counts, like 9
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
@@ -228,7 +264,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
- def forward(self, x, sigma, uncond, cond, cond_scale):
+ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@@ -239,28 +275,29 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
+ x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
@@ -306,6 +343,8 @@ class KDiffusionSampler:
self.config = None
self.last_latent = None
+ self.conditioning_key = sd_model.model.conditioning_key
+
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
@@ -361,7 +400,7 @@ class KDiffusionSampler:
return extra_params_kwargs
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override:
@@ -388,12 +427,18 @@ class KDiffusionSampler:
extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
@@ -414,7 +459,13 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
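The new single_sample_to_image() / samples_to_image_grid() pair decodes one latent at a time; the conversion from a decoded sample in [-1, 1] to an 8-bit PIL image is the same few lines as before, shown here standalone with a random tensor standing in for the decoder output (illustration only).

import numpy as np
import torch
from PIL import Image

x_sample = torch.rand(3, 64, 64) * 2.0 - 1.0                        # placeholder for decode_first_stage output
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)    # [-1, 1] -> [0, 1]
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)         # CHW -> HWC, scale to [0, 255]
image = Image.fromarray(x_sample.astype(np.uint8))

assert image.size == (64, 64)
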
diff --git a/modules/shared.py b/modules/shared.py
index f7d66870..308fccce 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -3,6 +3,7 @@ import datetime
import json
import os
import sys
+from collections import OrderedDict
import gradio as gr
import tqdm
@@ -57,12 +58,13 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -78,10 +80,13 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
cmd_opts = parser.parse_args()
restricted_opts = [
"samples_filename_pattern",
+ "directories_filename_pattern",
"outdir_samples",
"outdir_txt2img_samples",
"outdir_img2img_samples",
@@ -91,8 +96,8 @@ restricted_opts = [
"outdir_save",
]
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
@@ -137,7 +142,7 @@ class State:
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
-
+
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
@@ -162,13 +167,13 @@ def realesrgan_models_names():
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
- self.section = None
+ self.section = section
self.refresh = refresh
@@ -186,7 +191,8 @@ options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
- "samples_filename_pattern": OptionInfo("", "Images filename pattern"),
+ "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
+ "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
@@ -221,8 +227,8 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
- "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
+ "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
+ "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
@@ -249,7 +255,7 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
- "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+ "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -291,6 +297,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+ "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
@@ -343,7 +350,7 @@ class Options:
def save(self, filename):
with open(filename, "w", encoding="utf8") as file:
- json.dump(self.data, file)
+ json.dump(self.data, file, indent=4)
def same_type(self, x, y):
if x is None or y is None:
@@ -378,6 +385,20 @@ class Options:
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d)
+ def add_option(self, key, info):
+ self.data_labels[key] = info
+
+ def reorder(self):
+ """reorder settings so that all items related to section always go together"""
+
+ section_ids = {}
+ settings_items = self.data_labels.items()
+ for k, item in settings_items:
+ if item.section not in section_ids:
+ section_ids[item.section] = len(section_ids)
+
+ self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+
opts = Options()
if os.path.exists(config_filename):
@@ -387,6 +408,8 @@ sd_upscalers = []
sd_model = None
+clip_model = None
+
progress_print_out = sys.stdout
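Options.reorder() above groups settings so that everything from one section sits together, ordering sections by first appearance. A tiny standalone model of that sort follows, with illustrative data rather than the real settings registry.

settings = {
    "samples_save":     ("saving-images",),
    "show_progressbar": ("ui",),
    "grid_save":        ("saving-images",),
}

# Assign each section an id in order of first appearance, then stable-sort by it.
section_ids = {}
for key, (section,) in settings.items():
    if section not in section_ids:
        section_ids[section] = len(section_ids)

reordered = dict(sorted(settings.items(), key=lambda item: section_ids[item[1][0]]))
print(list(reordered))   # ['samples_save', 'grid_save', 'show_progressbar']
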
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index baa02e3d..4253b66d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -7,8 +7,8 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
-from modules import modelloader
-from modules.shared import cmd_opts, opts, device
+from modules import modelloader, devices
+from modules.shared import cmd_opts, opts
from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -42,7 +42,7 @@ class UpscalerSwinIR(Upscaler):
model = self.load_model(model_file)
if model is None:
return img
- model = model.to(device)
+ model = model.to(devices.device_swinir)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
@@ -111,7 +111,7 @@ def upscale(
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
@@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=torch.half, device=device)
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
+ W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
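Both the SwinIR and ScuNET hunks route tensors through devices.mps_contiguous_to() instead of a plain .to(device). The helper itself lives in modules/devices.py (changed in this merge but not shown in this section); a plausible reading, sketched below under that assumption, is that it forces a contiguous copy before moving to Apple's MPS backend.

import torch

def mps_contiguous_to(input_tensor, device):
    # Assumed behaviour, not the repository's definition: make the tensor contiguous
    # before transferring when the target is MPS, which sidesteps strided-copy issues
    # on that backend; any other device just gets a normal .to().
    if getattr(device, "type", None) == "mps":
        return input_tensor.contiguous().to(device)
    return input_tensor.to(device)
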
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 23bb4b6a..5b1c5002 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
self.dataset.append(entry)
- assert len(self.dataset) > 1, "No images have been found in the dataset."
+ assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
def shuffle(self):
- self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+ self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
def create_text(self, filename_text):
text = random.choice(self.lines)
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 898ce3b3..ea653806 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -5,6 +5,7 @@ import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
+from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
from math import cos
image = srcimage.copy()
-
+ fontsize = 32
if textfont is None:
try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image)
- fontsize = 32
+
font = ImageFont.truetype(textfont, fontsize)
padding = 10
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 0c79f012..a8c17c6f 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,5 +1,6 @@
import os
from PIL import Image, ImageOps
+import math
import platform
import sys
import tqdm
@@ -12,7 +13,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False):
try:
if process_caption:
shared.interrogator.load()
@@ -22,7 +23,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
- preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru, process_entropy_focus)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_entropy_focus)
finally:
@@ -34,11 +35,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False, process_entropy_focus=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_entropy_focus=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
+ split_threshold = max(0.0, min(1.0, split_threshold))
+ overlap_ratio = max(0.0, min(0.9, overlap_ratio))
assert src != dst, 'same directory specified as source and destination'
@@ -49,7 +52,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- def save_pic_with_caption(image, index):
+ def save_pic_with_caption(image, index, existing_caption=None):
caption = ""
if process_caption:
@@ -67,17 +70,49 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
+ if preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
- def save_pic(image, index):
- save_pic_with_caption(image, index)
+ def save_pic(image, index, existing_caption=None):
+ save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index)
+ save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
+
+ def split_pic(image, inverse_xy):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
for index, imagefile in enumerate(tqdm.tqdm(files)):
@@ -88,34 +123,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
except Exception:
continue
+ existing_caption = None
+ existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+ if os.path.exists(existing_caption_filename):
+ with open(existing_caption_filename, 'r', encoding="utf8") as file:
+ existing_caption = file.read()
+
if shared.state.interrupted:
break
- ratio = img.height / img.width
- is_tall = ratio > 1.35
- is_wide = ratio < 1 / 1.35
+ if img.height > img.width:
+ ratio = (img.width * height) / (img.height * width)
+ inverse_xy = False
+ else:
+ ratio = (img.height * width) / (img.width * height)
+ inverse_xy = True
processing_option_ran = False
- if process_split and is_tall:
- img = img.resize((width, height * img.height // img.width))
-
- top = img.crop((0, 0, width, height))
- save_pic(top, index)
-
- bot = img.crop((0, img.height - height, width, img.height))
- save_pic(bot, index)
-
- processing_option_ran = True
- elif process_split and is_wide:
- img = img.resize((width * img.width // img.height, height))
-
- left = img.crop((0, 0, width, height))
- save_pic(left, index)
-
- right = img.crop((img.width - width, 0, img.width, height))
- save_pic(right, index)
-
+ if process_split and ratio < 1.0 and ratio <= split_threshold:
+ for splitted in split_pic(img, inverse_xy):
+ save_pic(splitted, index, existing_caption=existing_caption)
processing_option_ran = True
if process_entropy_focus and img.height != img.width:
@@ -128,12 +156,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
annotate_image = False
)
focal = autocrop.crop_image(img, autocrop_settings)
- save_pic(focal, index)
-
+ save_pic(focal, index, existing_caption=existing_caption)
processing_option_ran = True
if not processing_option_ran:
img = images.resize_image(1, img, width, height)
- save_pic(img, index)
+ save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob() \ No newline at end of file
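The rewritten splitting logic replaces the old top/bottom (or left/right) two-crop behaviour with an evenly spaced series of overlapping crops. The arithmetic is easiest to see with concrete numbers; the example below walks a 512x1024 source through the new split_pic() formulas at the default 0.2 overlap (values chosen purely for illustration).

import math

width, height = 512, 512        # target crop size
from_w, from_h = 512, 1024      # tall source image, so inverse_xy is False
overlap_ratio = 0.2

to_w, to_h = width, height
h = from_h * to_w // from_w     # 1024: source height after the width is scaled to 512

split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)

print(split_count, y_step)      # 3 crops, stepped 256.0 px apart
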
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 3be69562..529ed3e2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
@@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
optimizer.step()
+
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 36881e7a..e712284d 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
- filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 2381347f..c9d5a090 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,5 +1,6 @@
import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+ StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
import modules.processing as processing
@@ -35,6 +36,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None,
)
+ p.scripts = modules.scripts.scripts_txt2img
+ p.script_args = args
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
@@ -53,4 +57,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info)
-
diff --git a/modules/ui.py b/modules/ui.py
index b6be713b..028eb4e5 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,47 +5,56 @@ import json
import math
import mimetypes
import os
+import platform
import random
+import subprocess as sp
import sys
import tempfile
import time
import traceback
-import platform
-import subprocess as sp
from functools import partial, reduce
+import gradio as gr
+import gradio.routes
+import gradio.utils
import numpy as np
+import piexif
import torch
from PIL import Image, PngImagePlugin
-import piexif
import gradio as gr
import gradio.utils
import gradio.routes
-from modules import sd_hijack, sd_models, localization
+from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
+
from modules.shared import opts, cmd_opts, restricted_opts
+
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
-import modules.shared as shared
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.sd_hijack import model_hijack
+
+import modules.codeformer_model
+import modules.generation_parameters_copypaste
+import modules.gfpgan_model
+import modules.hypernetworks.ui
import modules.ldsr_model
import modules.scripts
-import modules.gfpgan_model
-import modules.codeformer_model
+import modules.shared as shared
import modules.styles
-import modules.generation_parameters_copypaste
+import modules.textual_inversion.ui
from modules import prompt_parser
from modules.images import save_image
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
import modules.hypernetworks.ui
-import modules.images_history as img_his
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
+txt2img_paste_fields = []
+img2img_paste_fields = []
if not cmd_opts.share and not cmd_opts.listen:
@@ -268,8 +277,13 @@ def calc_time_left(progress, threshold, label, force_display):
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if (eta_relative > threshold and progress > 0.02) or force_display:
- return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ if (eta_relative > threshold and progress > 0.02) or force_display:
+ if eta_relative > 3600:
+ return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+ elif eta_relative > 60:
+ return label + time.strftime('%M:%S', time.gmtime(eta_relative))
+ else:
+ return label + time.strftime('%Ss', time.gmtime(eta_relative))
else:
return ""
@@ -285,7 +299,7 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
- time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
+ time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
if time_left != "":
shared.state.time_left_force_display = True
@@ -293,7 +307,7 @@ def check_progress_call(id_part):
progressbar = ""
if opts.show_progressbar:
- progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:hidden;width:{progress * 100}%">{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}</div></div>"""
+ progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
image = gr_show(False)
preview_visibility = gr_show(False)
@@ -302,7 +316,10 @@ def check_progress_call(id_part):
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
+ if opts.show_progress_grid:
+ shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
+ else:
+ shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image
@@ -477,14 +494,14 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
)
@@ -561,6 +578,9 @@ def apply_setting(key, value):
if value is None:
return gr.update()
+ if shared.cmd_opts.freeze_settings:
+ return gr.update()
+
# dont allow model to be swapped when model hash exists in prompt
if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
return gr.update()
@@ -587,27 +607,29 @@ def apply_setting(key, value):
return value
-def create_ui(wrap_gradio_gpu_call):
- import modules.img2img
- import modules.txt2img
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
- def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
- for k, v in args.items():
- setattr(refresh_component, k, v)
+ return gr.update(**(args or {}))
- return gr.update(**(args or {}))
+ refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
+
+def create_ui(wrap_gradio_gpu_call):
+ import modules.img2img
+ import modules.txt2img
- refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn = refresh,
- inputs = [],
- outputs = [refresh_component]
- )
- return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
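Illustration (not part of the patch): a hypothetical usage of the relocated create_refresh_button helper, wiring a Gradio Dropdown to a registry that can be reloaded. model_registry, its reload() method, and the elem_id are made-up names for the example; the helper itself is the one moved to module level in the hunk above.

import gradio as gr

def wire_model_refresh(model_registry):
    with gr.Row():
        model_dropdown = gr.Dropdown(label="Model", choices=sorted(model_registry.keys()))
        create_refresh_button(
            model_dropdown,
            model_registry.reload,                               # re-scan available entries (hypothetical)
            lambda: {"choices": sorted(model_registry.keys())},  # fresh kwargs for gr.update
            "refresh_example_model",
        )
    return model_dropdown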
@@ -705,6 +727,7 @@ def create_ui(wrap_gradio_gpu_call):
firstphase_width,
firstphase_height,
] + custom_inputs,
+
outputs=[
txt2img_gallery,
generation_info,
@@ -761,6 +784,7 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ global txt2img_paste_fields
txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
@@ -781,6 +805,7 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
+ *modules.scripts.scripts_txt2img.infotext_fields
]
txt2img_preview_params = [
@@ -848,8 +873,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@@ -1030,6 +1055,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[prompt, negative_prompt, style1, style2],
)
+ global img2img_paste_fields
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
(img2img_negative_prompt, "Negative prompt"),
@@ -1046,6 +1072,7 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
+ *modules.scripts.scripts_img2img.infotext_fields
]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1077,9 +1104,9 @@ def create_ui(wrap_gradio_gpu_call):
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
-
+
with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
+ extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
@@ -1166,15 +1193,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[image],
outputs=[html, generation_info, html2],
)
- #images history
- images_history_switch_dict = {
- "fn":modules.generation_parameters_copypaste.connect_paste,
- "t2i":txt2img_paste_fields,
- "i2i":img2img_paste_fields
- }
-
- images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
-
+
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -1206,6 +1225,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row():
with gr.Column(scale=3):
@@ -1217,6 +1237,11 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
+ new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
with gr.Row():
with gr.Column(scale=3):
@@ -1230,14 +1255,19 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images into two')
+ process_split = gr.Checkbox(label='Split oversized images')
process_entropy_focus = gr.Checkbox(label='Create auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
+ with gr.Row(visible=False) as process_split_extra_row:
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1245,15 +1275,24 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ process_split.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_split],
+ outputs=[process_split_extra_row],
+ )
+
with gr.Tab(label="Train"):
- gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
- learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
+ with gr.Row():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
+
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
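Illustration (not part of the patch): a minimal sketch of the show/hide wiring added earlier in this hunk, where ticking "Split oversized images" reveals the extra slider row. gr_show is assumed to return a visibility update (as it does in this UI module); the Blocks layout below is illustrative.

import gradio as gr

def gr_show(visible=True):
    # Assumed behaviour of the project's gr_show helper: toggle component visibility.
    return gr.update(visible=visible)

with gr.Blocks() as demo:
    split = gr.Checkbox(label="Split oversized images")
    with gr.Row(visible=False) as extra_row:
        gr.Slider(label="Split image threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.05)
    split.change(fn=gr_show, inputs=[split], outputs=[extra_row])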
@@ -1287,6 +1326,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name,
initialization_text,
nvpt,
+ overwrite_old_embedding,
],
outputs=[
train_embedding_name,
@@ -1300,6 +1340,11 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
+ overwrite_old_hypernetwork,
+ new_hypernetwork_layer_structure,
+ new_hypernetwork_activation_func,
+ new_hypernetwork_add_layer_norm,
+ new_hypernetwork_use_dropout
],
outputs=[
train_hypernetwork_name,
@@ -1316,11 +1361,14 @@ def create_ui(wrap_gradio_gpu_call):
process_dst,
process_width,
process_height,
+ preprocess_txt_action,
process_flip,
process_split,
process_caption,
process_caption_deepbooru,
- process_entropy_focus
+ process_split_threshold,
+ process_overlap_ratio,
+ process_entropy_focus,
],
outputs=[
ti_output,
@@ -1333,7 +1381,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
- learn_rate,
+ embedding_learn_rate,
batch_size,
dataset_directory,
log_directory,
@@ -1358,7 +1406,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
- learn_rate,
+ hypernetwork_learn_rate,
batch_size,
dataset_directory,
log_directory,
@@ -1422,6 +1470,9 @@ def create_ui(wrap_gradio_gpu_call):
components = []
component_dict = {}
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
def open_folder(f):
if not os.path.exists(f):
print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
@@ -1447,6 +1498,8 @@ Requested path was: {f}
def run_settings(*args):
changed = 0
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default):
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
@@ -1476,13 +1529,15 @@ Requested path was: {f}
return f'{changed} settings changed.', opts.dumpjson()
def run_settings_single(value, key):
+ assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
+
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
+ oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
return gr.update(value=oldval), opts.dumpjson()
- oldval = opts.data.get(key, None)
opts.data[key] = value
if oldval != value:
@@ -1525,9 +1580,10 @@ Requested path was: {f}
previous_section = item.section
- gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
+ elem_id, text = item.section
+ gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
- if k in quicksettings_names:
+ if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
else:
@@ -1560,7 +1616,7 @@ Requested path was: {f}
def reload_scripts():
modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
+ reload_javascript() # need to refresh the html page
reload_script_bodies.click(
fn=reload_scripts,
@@ -1588,19 +1644,26 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
- (images_history, "History", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
- (settings_interface, "Settings", "settings"),
]
- with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
- css = file.read()
+ interfaces += script_callbacks.ui_tabs_callback()
+
+ interfaces += [(settings_interface, "Settings", "settings")]
+
+ css = ""
+
+ for cssfile in modules.scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
+ with open(cssfile, "r", encoding="utf8") as file:
+ css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
- usercss = file.read()
- css += usercss
+ css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
@@ -1823,9 +1886,10 @@ def load_javascript(raw_response):
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>'
- jsdir = os.path.join(script_path, "javascript")
- for filename in sorted(os.listdir(jsdir)):
- with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+ scripts_list = modules.scripts.list_scripts("javascript", ".js")
+
+ for basedir, filename, path in scripts_list:
+ with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
if cmd_opts.theme is not None:
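Illustration (not part of the patch): a standalone sketch of the JavaScript-inlining loop this hunk rewrites. It reads each .js file and wraps it in a script tag the way the removed lines did; the real code now takes its file list from modules.scripts.list_scripts("javascript", ".js") instead of a single directory, and inline_javascript is a made-up name.

import os

def inline_javascript(js_dir):
    # Read every .js file under js_dir and append it as an inline <script> block.
    javascript = ""
    for filename in sorted(os.listdir(js_dir)):
        if not filename.endswith(".js"):
            continue
        with open(os.path.join(js_dir, filename), "r", encoding="utf8") as jsfile:
            javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
    return javascript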
@@ -1843,6 +1907,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
-reload_javascript = partial(load_javascript,
- gradio.routes.templates.TemplateResponse)
+reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()