-rw-r--r--  .github/workflows/run_tests.yaml                  2
-rw-r--r--  modules/api/api.py                               28
-rw-r--r--  modules/face_restoration_utils.py                39
-rw-r--r--  modules/initialize.py                             3
-rw-r--r--  modules/modelloader.py                           54
-rw-r--r--  modules/realesrgan_model.py                       2
-rw-r--r--  modules/textual_inversion/textual_inversion.py   10
-rw-r--r--  modules/upscaler_utils.py                         6
-rw-r--r--  requirements.txt                                  1
-rw-r--r--  requirements_versions.txt                         1
10 files changed, 63 insertions(+), 83 deletions(-)
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index cd5c3f86..f42e4758 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -57,7 +57,7 @@ jobs:
2>&1 | tee output.txt &
- name: Run tests
run: |
- wait-for-it --service 127.0.0.1:7860 -t 600
+ wait-for-it --service 127.0.0.1:7860 -t 20
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server
if: always()
diff --git a/modules/api/api.py b/modules/api/api.py
index 2e18c6b9..843c59b0 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -251,6 +251,24 @@ class Api:
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
+ txt2img_script_runner = scripts.scripts_txt2img
+ img2img_script_runner = scripts.scripts_img2img
+
+ if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
+ ui.create_ui()
+
+ if not txt2img_script_runner.scripts:
+ txt2img_script_runner.initialize_scripts(False)
+ if not self.default_script_arg_txt2img:
+ self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
+
+ if not img2img_script_runner.scripts:
+ img2img_script_runner.initialize_scripts(True)
+ if not self.default_script_arg_img2img:
+ self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
+
+
+
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@@ -413,15 +431,10 @@ class Api:
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
script_runner = scripts.scripts_txt2img
- if not script_runner.scripts:
- script_runner.initialize_scripts(False)
- ui.create_ui()
infotext_script_args = {}
self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
- if not self.default_script_arg_txt2img:
- self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
@@ -482,15 +495,10 @@ class Api:
mask = decode_base64_to_image(mask)
script_runner = scripts.scripts_img2img
- if not script_runner.scripts:
- script_runner.initialize_scripts(True)
- ui.create_ui()
infotext_script_args = {}
self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
- if not self.default_script_arg_img2img:
- self.default_script_arg_img2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params
diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py
index c65c85ef..1cbac236 100644
--- a/modules/face_restoration_utils.py
+++ b/modules/face_restoration_utils.py
@@ -17,6 +17,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
+ """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
+ assert img.shape[2] == 3, "image must be RGB"
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return torch.from_numpy(img.transpose(2, 0, 1)).float()
+
+
+def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
+ """
+ Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
+ """
+ tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
+ assert tensor.dim() == 3, "tensor must be RGB"
+ img_np = tensor.numpy().transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image, no RGB/BGR required
+ return np.squeeze(img_np, axis=2)
+ return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+
+
def create_face_helper(device) -> FaceRestoreHelper:
from facexlib.detection import retinaface
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@@ -36,14 +58,13 @@ def create_face_helper(device) -> FaceRestoreHelper:
def restore_with_face_helper(
np_image: np.ndarray,
face_helper: FaceRestoreHelper,
- restore_face: Callable[[np.ndarray], np.ndarray],
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
"""
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
`restore_face` should take a cropped face image and return a restored face image.
"""
- from basicsr.utils import img2tensor, tensor2img
from torchvision.transforms.functional import normalize
np_image = np_image[:, :, ::-1]
original_resolution = np_image.shape[0:2]
@@ -56,23 +77,19 @@ def restore_with_face_helper(
face_helper.align_warp_face()
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
for cropped_face in face_helper.cropped_faces:
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
try:
with torch.no_grad():
- restored_face = tensor2img(
- restore_face(cropped_face_t),
- rgb2bgr=True,
- min_max=(-1, 1),
- )
+ cropped_face_t = restore_face(cropped_face_t)
devices.torch_gc()
except Exception:
errors.report('Failed face-restoration inference', exc_info=True)
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
- restored_face = restored_face.astype('uint8')
+ restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
+ restored_face = (restored_face * 255.0).astype('uint8')
face_helper.add_restored_face(restored_face)
logger.debug("Merging restored faces into image")
@@ -126,7 +143,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration):
def restore_with_helper(
self,
np_image: np.ndarray,
- restore_face: Callable[[np.ndarray], np.ndarray],
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
try:
if self.net is None:
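
Note: the two helpers added above reimplement only the narrow slice of basicsr's img2tensor/tensor2img that this module relied on. Below is a minimal standalone sketch of the per-face round-trip they enable; the random input and the missing "restorer" step are hypothetical stand-ins, not code from this patch.

import cv2
import numpy as np
import torch

def to_rgb_tensor(bgr: np.ndarray) -> torch.Tensor:
    # BGR float image in [0..1] -> RGB float32 CHW tensor
    rgb = cv2.cvtColor(bgr.astype(np.float32), cv2.COLOR_BGR2RGB)
    return torch.from_numpy(rgb.transpose(2, 0, 1)).float()

def to_bgr_image(t: torch.Tensor, min_max=(-1.0, 1.0)) -> np.ndarray:
    # CHW (or NCHW) RGB tensor in min_max range -> BGR float image in [0..1]
    t = t.squeeze(0).float().detach().cpu().clamp_(*min_max)
    t = (t - min_max[0]) / (min_max[1] - min_max[0])
    return cv2.cvtColor(t.numpy().transpose(1, 2, 0), cv2.COLOR_RGB2BGR)

face = np.random.rand(512, 512, 3).astype(np.float32)      # stand-in cropped face, BGR in [0..1]
t = to_rgb_tensor(face)
t = (t - 0.5) / 0.5                                         # same effect as normalize((0.5,)*3, (0.5,)*3)
restored = to_bgr_image(t.unsqueeze(0), min_max=(-1, 1))    # the face-restoration network output would go here
out = (restored * 255.0).astype("uint8")                    # what gets pasted back into the original image
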
diff --git a/modules/initialize.py b/modules/initialize.py
index ac95fc6f..4a3cd98c 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -54,9 +54,6 @@ def initialize():
initialize_util.configure_sigint_handler()
initialize_util.configure_opts_onchange()
- from modules import modelloader
- modelloader.cleanup_models()
-
from modules import sd_models
sd_models.setup_model()
startup_timer.record("setup SD model")
diff --git a/modules/modelloader.py b/modules/modelloader.py
index f4182559..0b89d682 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -2,7 +2,6 @@ from __future__ import annotations
import logging
import os
-import shutil
import importlib
from urllib.parse import urlparse
@@ -10,7 +9,6 @@ import torch
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
-from modules.paths import script_path, models_path
logger = logging.getLogger(__name__)
@@ -96,54 +94,6 @@ def friendly_name(file: str):
return model_name
-def cleanup_models():
- # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
- # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
- # somehow auto-register and just do these things...
- root_path = script_path
- src_path = models_path
- dest_path = os.path.join(models_path, "Stable-diffusion")
- move_files(src_path, dest_path, ".ckpt")
- move_files(src_path, dest_path, ".safetensors")
- src_path = os.path.join(root_path, "ESRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(models_path, "BSRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path, ".pth")
- src_path = os.path.join(root_path, "gfpgan")
- dest_path = os.path.join(models_path, "GFPGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "SwinIR")
- dest_path = os.path.join(models_path, "SwinIR")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
- dest_path = os.path.join(models_path, "LDSR")
- move_files(src_path, dest_path)
-
-
-def move_files(src_path: str, dest_path: str, ext_filter: str = None):
- try:
- os.makedirs(dest_path, exist_ok=True)
- if os.path.exists(src_path):
- for file in os.listdir(src_path):
- fullpath = os.path.join(src_path, file)
- if os.path.isfile(fullpath):
- if ext_filter is not None:
- if ext_filter not in file:
- continue
- print(f"Moving {file} from {src_path} to {dest_path}.")
- try:
- shutil.move(fullpath, dest_path)
- except Exception:
- pass
- if len(os.listdir(src_path)) == 0:
- print(f"Removing empty folder: {src_path}")
- shutil.rmtree(src_path, True)
- except Exception:
- pass
-
-
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -196,7 +146,9 @@ def load_spandrel_model(
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model.architecture != expected_architecture:
- raise TypeError(f"Model {path} is not a {expected_architecture} model")
+ logger.warning(
+ f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})",
+ )
if half:
model = model.model.half()
if dtype:
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 2a2be5ad..65f2e880 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -40,7 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
- expected_architecture="RealESRGAN",
+ expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
)
return upscale_with_model(
mod,
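
Note: spandrel detects RealESRGAN checkpoints as the generic "ESRGAN" architecture, which is why the expectation string changes here, and load_spandrel_model now warns on a mismatch instead of raising TypeError. A small sketch of that check, with a hypothetical checkpoint path:

import logging
import spandrel

logger = logging.getLogger(__name__)

def load_checked(path, device=None, expected_architecture=None):
    model = spandrel.ModelLoader(device=device).load_from_file(path)
    if expected_architecture and model.architecture != expected_architecture:
        # mismatch is logged rather than treated as a hard error
        logger.warning(
            f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})",
        )
    return model

# e.g. load_checked("models/RealESRGAN/RealESRGAN_x4plus.pth", expected_architecture="ESRGAN")
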
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 04dda585..c6bcab15 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,6 @@ import safetensors.torch
import numpy as np
from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
@@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
def tensorboard_setup(log_directory):
+ from torch.utils.tensorboard import SummaryWriter
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
return SummaryWriter(
log_dir=os.path.join(log_directory, "tensorboard"),
@@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
+ tensorboard_writer = None
if shared.opts.training_enable_tensorboard:
- tensorboard_writer = tensorboard_setup(log_directory)
+ try:
+ tensorboard_writer = tensorboard_setup(log_directory)
+ except ImportError:
+ errors.report("Error initializing tensorboard", exc_info=True)
pin_memory = shared.opts.pin_memory
@@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ if tensorboard_writer and shared.opts.training_tensorboard_save_images:
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
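
Note: the SummaryWriter import moves inside tensorboard_setup so tensorboard stays an optional dependency, and callers now key off the writer object rather than the option, so a failed setup simply disables tensorboard logging. A self-contained sketch of that pattern, with a hypothetical log directory:

import os

def tensorboard_setup(log_directory):
    from torch.utils.tensorboard import SummaryWriter  # deferred import: optional dependency
    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
    return SummaryWriter(log_dir=os.path.join(log_directory, "tensorboard"))

tensorboard_writer = None
enable_tensorboard = True  # stands in for shared.opts.training_enable_tensorboard
if enable_tensorboard:
    try:
        tensorboard_writer = tensorboard_setup("textual_inversion/logs")
    except ImportError:
        print("tensorboard not available; continuing without it")

if tensorboard_writer:  # guard on the writer, not the option
    tensorboard_writer.add_scalar("loss", 0.0, 0)
    tensorboard_writer.close()
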
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index 8bdda51c..dde5d7ad 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -16,9 +16,13 @@ def upscale_without_tiling(model, img: Image.Image):
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
+
+ model_weight = next(iter(model.model.parameters()))
+ img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype)
+
with torch.no_grad():
output = model(img)
+
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
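
Note: the upscaler input is now moved to whatever device and dtype the loaded model's own weights use (which matters when the model was loaded with half=True), instead of the global devices.device_esrgan. A standalone sketch of that handoff; DummyUpscaler is a stand-in for a spandrel model descriptor exposing a .model attribute:

import torch

class DummyUpscaler(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        return self.model(x)

net = DummyUpscaler()                        # could be .half().cuda() in practice
img = torch.rand(3, 64, 64)                  # CHW image tensor in [0..1]

weight = next(iter(net.model.parameters()))  # any parameter carries the device and dtype
batch = img.unsqueeze(0).to(device=weight.device, dtype=weight.dtype)

with torch.no_grad():
    out = net(batch)
print(out.shape, out.dtype)
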
diff --git a/requirements.txt b/requirements.txt
index b1329c9e..731a1be7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,6 @@ GitPython
Pillow
accelerate
-basicsr
blendmodes
clean-fid
einops
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 7ec7abe2..2a922f28 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,7 +1,6 @@
GitPython==3.1.32
Pillow==9.5.0
accelerate==0.21.0
-basicsr==1.4.2
blendmodes==2022
clean-fid==0.1.35
einops==0.4.1