aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/face_restoration_utils.py35
-rw-r--r--requirements.txt1
-rw-r--r--requirements_versions.txt1
3 files changed, 26 insertions, 11 deletions
diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py
index 85cb3057..1cbac236 100644
--- a/modules/face_restoration_utils.py
+++ b/modules/face_restoration_utils.py
@@ -17,6 +17,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
+ """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
+ assert img.shape[2] == 3, "image must be RGB"
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return torch.from_numpy(img.transpose(2, 0, 1)).float()
+
+
+def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
+ """
+ Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
+ """
+ tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
+ assert tensor.dim() == 3, "tensor must be RGB"
+ img_np = tensor.numpy().transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image, no RGB/BGR required
+ return np.squeeze(img_np, axis=2)
+ return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+
+
def create_face_helper(device) -> FaceRestoreHelper:
from facexlib.detection import retinaface
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@@ -43,7 +65,6 @@ def restore_with_face_helper(
`restore_face` should take a cropped face image and return a restored face image.
"""
- from basicsr.utils import img2tensor, tensor2img
from torchvision.transforms.functional import normalize
np_image = np_image[:, :, ::-1]
original_resolution = np_image.shape[0:2]
@@ -56,23 +77,19 @@ def restore_with_face_helper(
face_helper.align_warp_face()
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
for cropped_face in face_helper.cropped_faces:
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
try:
with torch.no_grad():
- restored_face = tensor2img(
- restore_face(cropped_face_t),
- rgb2bgr=True,
- min_max=(-1, 1),
- )
+ cropped_face_t = restore_face(cropped_face_t)
devices.torch_gc()
except Exception:
errors.report('Failed face-restoration inference', exc_info=True)
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
- restored_face = restored_face.astype('uint8')
+ restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
+ restored_face = (restored_face * 255.0).astype('uint8')
face_helper.add_restored_face(restored_face)
logger.debug("Merging restored faces into image")
diff --git a/requirements.txt b/requirements.txt
index b1329c9e..731a1be7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,6 @@ GitPython
Pillow
accelerate
-basicsr
blendmodes
clean-fid
einops
diff --git a/requirements_versions.txt b/requirements_versions.txt
index edbb6db9..1e0ccafa 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,7 +1,6 @@
GitPython==3.1.32
Pillow==9.5.0
accelerate==0.21.0
-basicsr==1.4.2
blendmodes==2022
clean-fid==0.1.35
einops==0.4.1