aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/run_tests.yaml8
-rw-r--r--.gitignore1
-rw-r--r--modules/codeformer_model.py158
-rw-r--r--modules/face_restoration_utils.py163
-rw-r--r--modules/gfpgan_model.py166
-rw-r--r--requirements.txt1
-rw-r--r--requirements_versions.txt1
-rw-r--r--test/conftest.py15
-rw-r--r--test/test_face_restorers.py29
-rw-r--r--test/test_files/two-faces.jpgbin0 -> 14768 bytes
-rw-r--r--test/test_outputs/.gitkeep0
11 files changed, 308 insertions, 234 deletions
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index 3dafaf8d..cd5c3f86 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -20,6 +20,12 @@ jobs:
cache-dependency-path: |
**/requirements*txt
launch.py
+ - name: Cache models
+ id: cache-models
+ uses: actions/cache@v3
+ with:
+ path: models
+ key: "2023-12-30"
- name: Install test dependencies
run: pip install wait-for-it -r requirements-test.txt
env:
@@ -33,6 +39,8 @@ jobs:
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
PYTHONUNBUFFERED: "1"
+ - name: Print installed packages
+ run: pip freeze
- name: Start test server
run: >
python -m coverage run
diff --git a/.gitignore b/.gitignore
index 09734267..6790e9ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -37,3 +37,4 @@ notification.mp3
/node_modules
/package-lock.json
/.coverage*
+/test/test_outputs
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 517eadfd..ceda4bab 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,140 +1,62 @@
-import os
+from __future__ import annotations
+
+import logging
-import cv2
import torch
-import modules.face_restoration
-import modules.shared
-from modules import shared, devices, modelloader, errors
-from modules.paths import models_path
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
+
+logger = logging.getLogger(__name__)
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_download_name = 'codeformer-v0.1.0.pth'
-codeformer = None
+# used by e.g. postprocessing_codeformer.py
+codeformer: face_restoration.FaceRestoration | None = None
-class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
+class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
def name(self):
return "CodeFormer"
- def __init__(self, dirname):
- self.net = None
- self.face_helper = None
- self.cmd_dir = dirname
-
- def create_models(self):
- from facexlib.detection import retinaface
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-
- if self.net is not None and self.face_helper is not None:
- self.net.to(devices.device_codeformer)
- return self.net, self.face_helper
- model_paths = modelloader.load_models(
- model_path,
- model_url,
- self.cmd_dir,
- download_name='codeformer-v0.1.0.pth',
+ def load_net(self) -> torch.Module:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
ext_filter=['.pth'],
- )
-
- if len(model_paths) != 0:
- ckpt_path = model_paths[0]
- else:
- print("Unable to load codeformer model.")
- return None, None
- net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer)
-
- if hasattr(retinaface, 'device'):
- retinaface.device = devices.device_codeformer
-
- face_helper = FaceRestoreHelper(
- upscale_factor=1,
- face_size=512,
- crop_ratio=(1, 1),
- det_model='retinaface_resnet50',
- save_ext='png',
- use_parse=True,
- device=devices.device_codeformer,
- )
-
- self.net = net
- self.face_helper = face_helper
-
- def send_model_to(self, device):
- self.net.to(device)
- self.face_helper.face_det.to(device)
- self.face_helper.face_parse.to(device)
-
- def restore(self, np_image, w=None):
- from torchvision.transforms.functional import normalize
- from basicsr.utils import img2tensor, tensor2img
- np_image = np_image[:, :, ::-1]
-
- original_resolution = np_image.shape[0:2]
-
- self.create_models()
- if self.net is None or self.face_helper is None:
- return np_image
-
- self.send_model_to(devices.device_codeformer)
-
- self.face_helper.clean_all()
- self.face_helper.read_image(np_image)
- self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- self.face_helper.align_warp_face()
-
- for cropped_face in self.face_helper.cropped_faces:
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
-
- try:
- with torch.no_grad():
- res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)
- if isinstance(res, tuple):
- output = res[0]
- else:
- output = res
- if not isinstance(res, torch.Tensor):
- raise TypeError(f"Expected torch.Tensor, got {type(res)}")
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- devices.torch_gc()
- except Exception:
- errors.report('Failed inference for CodeFormer', exc_info=True)
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
-
- restored_face = restored_face.astype('uint8')
- self.face_helper.add_restored_face(restored_face)
-
- self.face_helper.get_inverse_affine(None)
-
- restored_img = self.face_helper.paste_faces_to_input_image()
- restored_img = restored_img[:, :, ::-1]
+ ):
+ return modelloader.load_spandrel_model(
+ model_path,
+ device=devices.device_codeformer,
+ ).model
+ raise ValueError("No codeformer model found")
- if original_resolution != restored_img.shape[0:2]:
- restored_img = cv2.resize(
- restored_img,
- (0, 0),
- fx=original_resolution[1]/restored_img.shape[1],
- fy=original_resolution[0]/restored_img.shape[0],
- interpolation=cv2.INTER_LINEAR,
- )
+ def get_device(self):
+ return devices.device_codeformer
- self.face_helper.clean_all()
+ def restore(self, np_image, w: float | None = None):
+ if w is None:
+ w = getattr(shared.opts, "code_former_weight", 0.5)
- if shared.opts.face_restoration_unload:
- self.send_model_to(devices.cpu)
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, w=w, adain=True)[0]
- return restored_img
+ return self.restore_with_helper(np_image, restore_face)
-def setup_model(dirname):
- os.makedirs(model_path, exist_ok=True)
+def setup_model(dirname: str) -> None:
+ global codeformer
try:
- global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception:
diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py
new file mode 100644
index 00000000..c65c85ef
--- /dev/null
+++ b/modules/face_restoration_utils.py
@@ -0,0 +1,163 @@
+from __future__ import annotations
+
+import logging
+import os
+from functools import cached_property
+from typing import TYPE_CHECKING, Callable
+
+import cv2
+import numpy as np
+import torch
+
+from modules import devices, errors, face_restoration, shared
+
+if TYPE_CHECKING:
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
+logger = logging.getLogger(__name__)
+
+
+def create_face_helper(device) -> FaceRestoreHelper:
+ from facexlib.detection import retinaface
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ if hasattr(retinaface, 'device'):
+ retinaface.device = device
+ return FaceRestoreHelper(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ use_parse=True,
+ device=device,
+ )
+
+
+def restore_with_face_helper(
+ np_image: np.ndarray,
+ face_helper: FaceRestoreHelper,
+ restore_face: Callable[[np.ndarray], np.ndarray],
+) -> np.ndarray:
+ """
+ Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
+
+ `restore_face` should take a cropped face image and return a restored face image.
+ """
+ from basicsr.utils import img2tensor, tensor2img
+ from torchvision.transforms.functional import normalize
+ np_image = np_image[:, :, ::-1]
+ original_resolution = np_image.shape[0:2]
+
+ try:
+ logger.debug("Detecting faces...")
+ face_helper.clean_all()
+ face_helper.read_image(np_image)
+ face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ face_helper.align_warp_face()
+ logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
+ for cropped_face in face_helper.cropped_faces:
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+
+ try:
+ with torch.no_grad():
+ restored_face = tensor2img(
+ restore_face(cropped_face_t),
+ rgb2bgr=True,
+ min_max=(-1, 1),
+ )
+ devices.torch_gc()
+ except Exception:
+ errors.report('Failed face-restoration inference', exc_info=True)
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+ restored_face = restored_face.astype('uint8')
+ face_helper.add_restored_face(restored_face)
+
+ logger.debug("Merging restored faces into image")
+ face_helper.get_inverse_affine(None)
+ img = face_helper.paste_faces_to_input_image()
+ img = img[:, :, ::-1]
+ if original_resolution != img.shape[0:2]:
+ img = cv2.resize(
+ img,
+ (0, 0),
+ fx=original_resolution[1] / img.shape[1],
+ fy=original_resolution[0] / img.shape[0],
+ interpolation=cv2.INTER_LINEAR,
+ )
+ logger.debug("Face restoration complete")
+ finally:
+ face_helper.clean_all()
+ return img
+
+
+class CommonFaceRestoration(face_restoration.FaceRestoration):
+ net: torch.Module | None
+ model_url: str
+ model_download_name: str
+
+ def __init__(self, model_path: str):
+ super().__init__()
+ self.net = None
+ self.model_path = model_path
+ os.makedirs(model_path, exist_ok=True)
+
+ @cached_property
+ def face_helper(self) -> FaceRestoreHelper:
+ return create_face_helper(self.get_device())
+
+ def send_model_to(self, device):
+ if self.net:
+ logger.debug("Sending %s to %s", self.net, device)
+ self.net.to(device)
+ if self.face_helper:
+ logger.debug("Sending face helper to %s", device)
+ self.face_helper.face_det.to(device)
+ self.face_helper.face_parse.to(device)
+
+ def get_device(self):
+ raise NotImplementedError("get_device must be implemented by subclasses")
+
+ def load_net(self) -> torch.Module:
+ raise NotImplementedError("load_net must be implemented by subclasses")
+
+ def restore_with_helper(
+ self,
+ np_image: np.ndarray,
+ restore_face: Callable[[np.ndarray], np.ndarray],
+ ) -> np.ndarray:
+ try:
+ if self.net is None:
+ self.net = self.load_net()
+ except Exception:
+ logger.warning("Unable to load face-restoration model", exc_info=True)
+ return np_image
+
+ try:
+ self.send_model_to(self.get_device())
+ return restore_with_face_helper(np_image, self.face_helper, restore_face)
+ finally:
+ if shared.opts.face_restoration_unload:
+ self.send_model_to(devices.cpu)
+
+
+def patch_facexlib(dirname: str) -> None:
+ import facexlib.detection
+ import facexlib.parsing
+
+ det_facex_load_file_from_url = facexlib.detection.load_file_from_url
+ par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
+
+ def update_kwargs(kwargs):
+ return dict(kwargs, save_dir=dirname, model_dir=None)
+
+ def facex_load_file_from_url(**kwargs):
+ return det_facex_load_file_from_url(**update_kwargs(kwargs))
+
+ def facex_load_file_from_url2(**kwargs):
+ return par_facex_load_file_from_url(**update_kwargs(kwargs))
+
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 6b6f17c4..a356b56f 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,126 +1,68 @@
+from __future__ import annotations
+
+import logging
import os
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_file_path = None
+logger = logging.getLogger(__name__)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
- global loaded_gfpgan_model
- global model_path
- global model_file_path
- if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
- return loaded_gfpgan_model
-
- if gfpgan_constructor is None:
- return None
-
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
-
- if len(models) == 1 and models[0].startswith("http"):
- model_file = models[0]
- elif len(models) != 0:
- gfp_models = []
- for item in models:
- if 'GFPGAN' in os.path.basename(item):
- gfp_models.append(item)
- latest_file = max(gfp_models, key=os.path.getctime)
- model_file = latest_file
- else:
- print("Unable to load gfpgan model!")
- return None
-
- import facexlib.detection.retinaface
-
- if hasattr(facexlib.detection.retinaface, 'device'):
- facexlib.detection.retinaface.device = devices.device_gfpgan
- model_file_path = model_file
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
- loaded_gfpgan_model = model
-
- return model
-
-
-def send_model_to(model, device):
- model.gfpgan.to(device)
- model.face_helper.face_det.to(device)
- model.face_helper.face_parse.to(device)
+model_download_name = "GFPGANv1.4.pth"
+gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
+
+
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+ def name(self):
+ return "GFPGAN"
+
+ def get_device(self):
+ return devices.device_gfpgan
+
+ def load_net(self) -> None:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
+ ext_filter=['.pth'],
+ ):
+ if 'GFPGAN' in os.path.basename(model_path):
+ net = modelloader.load_spandrel_model(
+ model_path,
+ device=self.get_device(),
+ ).model
+ net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
+ return net
+ raise ValueError("No GFPGAN model found")
+
+ def restore(self, np_image):
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, return_rgb=False)[0]
+
+ return self.restore_with_helper(np_image, restore_face)
def gfpgan_fix_faces(np_image):
- model = gfpgann()
- if model is None:
- return np_image
-
- send_model_to(model, devices.device_gfpgan)
-
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- model.face_helper.clean_all()
-
- if shared.opts.face_restoration_unload:
- send_model_to(model, devices.cpu)
-
+ if gfpgan_face_restorer:
+ return gfpgan_face_restorer.restore(np_image)
+ logger.warning("GFPGAN face restorer not set up")
return np_image
-gfpgan_constructor = None
+def setup_model(dirname: str) -> None:
+ global gfpgan_face_restorer
-
-def setup_model(dirname):
try:
- os.makedirs(model_path, exist_ok=True)
- import gfpgan
- import facexlib.detection
- import facexlib.parsing
-
- global user_path
- global have_gfpgan
- global gfpgan_constructor
- global model_file_path
-
- facexlib_path = model_path
-
- if dirname is not None:
- facexlib_path = dirname
-
- load_file_from_url_orig = gfpgan.utils.load_file_from_url
- facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
- facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
- def my_load_file_from_url(**kwargs):
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
-
- def facex_load_file_from_url(**kwargs):
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
- def facex_load_file_from_url2(**kwargs):
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
- gfpgan.utils.load_file_from_url = my_load_file_from_url
- facexlib.detection.load_file_from_url = facex_load_file_from_url
- facexlib.parsing.load_file_from_url = facex_load_file_from_url2
- user_path = dirname
- have_gfpgan = True
- gfpgan_constructor = gfpgan.GFPGANer
-
- class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
- def name(self):
- return "GFPGAN"
-
- def restore(self, np_image):
- return gfpgan_fix_faces(np_image)
-
- shared.face_restorers.append(FaceRestorerGFPGAN())
+ face_restoration_utils.patch_facexlib(dirname)
+ gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+ shared.face_restorers.append(gfpgan_face_restorer)
except Exception:
errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/requirements.txt b/requirements.txt
index 36f5674a..b1329c9e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,7 +8,6 @@ clean-fid
einops
facexlib
fastapi>=0.90.1
-gfpgan
gradio==3.41.2
inflection
jsonmerge
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 042fa708..edbb6db9 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -7,7 +7,6 @@ clean-fid==0.1.35
einops==0.4.1
facexlib==0.3.0
fastapi==0.94.0
-gfpgan==1.3.8
gradio==3.41.2
httpcore==0.15
inflection==0.5.1
diff --git a/test/conftest.py b/test/conftest.py
index 31a5d9ea..e4fc5678 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,10 +1,16 @@
+import base64
import os
import pytest
-import base64
-
test_files_path = os.path.dirname(__file__) + "/test_files"
+test_outputs_path = os.path.dirname(__file__) + "/test_outputs"
+
+
+def pytest_configure(config):
+ # We don't want to fail on Py.test command line arguments being
+ # parsed by webui:
+ os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
def file_to_base64(filename):
@@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str:
@pytest.fixture(scope="session") # session so we don't read this over and over
def mask_basic_image_base64() -> str:
return file_to_base64(os.path.join(test_files_path, "mask_basic.png"))
+
+
+@pytest.fixture(scope="session")
+def initialize() -> None:
+ import webui # noqa: F401
diff --git a/test/test_face_restorers.py b/test/test_face_restorers.py
new file mode 100644
index 00000000..7760d51b
--- /dev/null
+++ b/test/test_face_restorers.py
@@ -0,0 +1,29 @@
+import os
+from test.conftest import test_files_path, test_outputs_path
+
+import numpy as np
+import pytest
+from PIL import Image
+
+
+@pytest.mark.usefixtures("initialize")
+@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"])
+def test_face_restorers(restorer_name):
+ from modules import shared
+
+ if restorer_name == "gfpgan":
+ from modules import gfpgan_model
+ gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)
+ restorer = gfpgan_model.gfpgan_fix_faces
+ elif restorer_name == "codeformer":
+ from modules import codeformer_model
+ codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
+ restorer = codeformer_model.codeformer.restore
+ else:
+ raise NotImplementedError("...")
+ img = Image.open(os.path.join(test_files_path, "two-faces.jpg"))
+ np_img = np.array(img, dtype=np.uint8)
+ fixed_image = restorer(np_img)
+ assert fixed_image.shape == np_img.shape
+ assert not np.allclose(fixed_image, np_img) # should have visibly changed
+ Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png"))
diff --git a/test/test_files/two-faces.jpg b/test/test_files/two-faces.jpg
new file mode 100644
index 00000000..c9d1b010
--- /dev/null
+++ b/test/test_files/two-faces.jpg
Binary files differ
diff --git a/test/test_outputs/.gitkeep b/test/test_outputs/.gitkeep
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/test_outputs/.gitkeep