aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorDynamic <bradje@naver.com>2022-10-25 18:27:32 +0900
committerGitHub <noreply@github.com>2022-10-25 18:27:32 +0900
commit563fb0aa39faca32187e78c07bec695531f21f39 (patch)
treee8ba5b699b256ce90a07c52c52051e504f601659 /modules
parente595b41c9d8a596b9b29d9505324e9afca2f12b5 (diff)
parent3e15f8e0f5cc87507f77546d92435670644dbd18 (diff)
Merge branch 'AUTOMATIC1111:master' into kr-localization
Diffstat (limited to 'modules')
-rw-r--r--modules/devices.py6
-rw-r--r--modules/esrgan_model.py2
-rw-r--r--modules/images.py7
-rw-r--r--modules/script_callbacks.py67
-rw-r--r--modules/scunet_model.py3
-rw-r--r--modules/sd_samplers.py4
-rw-r--r--modules/shared.py6
-rw-r--r--modules/swinir_model.py12
8 files changed, 80 insertions, 27 deletions
diff --git a/modules/devices.py b/modules/devices.py
index dc1f3cdd..7511e1dc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -45,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
+device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
@@ -81,3 +81,7 @@ def autocast(disable=False):
return contextlib.nullcontext()
return torch.autocast("cuda")
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a49e2258..a13cf6ac 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -190,7 +190,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/images.py b/modules/images.py
index 9a8fe3ed..286de2ae 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -16,7 +16,7 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
-from modules import sd_samplers, shared
+from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -477,8 +477,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if forced_filename is None:
if short_filename or seed is None:
file_decoration = ""
- else:
+ elif opts.save_to_dirs:
file_decoration = opts.samples_filename_pattern or "[seed]"
+ else:
+ file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
add_number = opts.save_images_add_number or file_decoration == ''
@@ -539,6 +541,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
txt_fullfn = None
+ script_callbacks.image_saved_callback(image, p, fullfn, txt_fullfn)
return fullfn, txt_fullfn
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index f46d3d9a..dc520abc 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,37 +1,74 @@
+import sys
+import traceback
+from collections import namedtuple
+import inspect
+
+def report_exception(c, job):
+ print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+
+ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
-
+callbacks_image_saved = []
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
+ callbacks_image_saved.clear()
def model_loaded_callback(sd_model):
- for callback in callbacks_model_loaded:
- callback(sd_model)
+ for c in callbacks_model_loaded:
+ try:
+ c.callback(sd_model)
+ except Exception:
+ report_exception(c, 'model_loaded_callback')
def ui_tabs_callback():
res = []
- for callback in callbacks_ui_tabs:
- res += callback() or []
+ for c in callbacks_ui_tabs:
+ try:
+ res += c.callback() or []
+ except Exception:
+ report_exception(c, 'ui_tabs_callback')
return res
def ui_settings_callback():
- for callback in callbacks_ui_settings:
- callback()
+ for c in callbacks_ui_settings:
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'ui_settings_callback')
+
+
+def image_saved_callback(image, p, fullfn, txt_fullfn):
+ for c in callbacks_image_saved:
+ try:
+ c.callback(image, p, fullfn, txt_fullfn)
+ except Exception:
+ report_exception(c, 'image_saved_callback')
+
+
+def add_callback(callbacks, fun):
+ stack = [x for x in inspect.stack() if x.filename != __file__]
+ filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+
+ callbacks.append(ScriptCallback(filename, fun))
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
- callbacks_model_loaded.append(callback)
+ add_callback(callbacks_model_loaded, callback)
def on_ui_tabs(callback):
@@ -44,10 +81,20 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
- callbacks_ui_tabs.append(callback)
+ add_callback(callbacks_ui_tabs, callback)
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
- callbacks_ui_settings.append(callback)
+ add_callback(callbacks_ui_settings, callback)
+
+
+def on_save_imaged(callback):
+ """register a function to be called after modules.images.save_image is called.
+ The callback is called with three arguments:
+ - p - procesing object (or a dummy object with same fields if the image is saved using save button)
+ - fullfn - image filename
+ - txt_fullfn - text file with parameters; may be None
+ """
+ add_callback(callbacks_image_saved, callback)
diff --git a/modules/scunet_model.py b/modules/scunet_model.py
index 36a996bf..59532274 100644
--- a/modules/scunet_model.py
+++ b/modules/scunet_model.py
@@ -54,9 +54,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), device)
- img = img.to(device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 0b408a70..3670b57d 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -228,7 +228,7 @@ class VanillaStableDiffusionSampler:
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -429,7 +429,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
diff --git a/modules/shared.py b/modules/shared.py
index 76cbb1bd..308fccce 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -58,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
-parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -96,8 +96,8 @@ restricted_opts = [
"outdir_save",
]
-devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
-(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
+devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
+(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index baa02e3d..4253b66d 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -7,8 +7,8 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
-from modules import modelloader
-from modules.shared import cmd_opts, opts, device
+from modules import modelloader, devices
+from modules.shared import cmd_opts, opts
from modules.swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -42,7 +42,7 @@ class UpscalerSwinIR(Upscaler):
model = self.load_model(model_file)
if model is None:
return img
- model = model.to(device)
+ model = model.to(devices.device_swinir)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
@@ -111,7 +111,7 @@ def upscale(
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(device)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
@@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=torch.half, device=device)
+ E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
+ W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list: