aboutsummaryrefslogtreecommitdiff
path: root/modules/face_restoration_utils.py
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-12-30 17:41:19 +0200
committerAarni Koskela <akx@iki.fi>2023-12-30 17:41:29 +0200
commitf476649c02cf3547d891fa08c50a92f92c4d73bd (patch)
treeb42f511d0ebe8b1eb93df064bc7c7563608fa23d /modules/face_restoration_utils.py
parentcd12c0e15c4dc1545cac18ba902ca17488812953 (diff)
Correct arg type for restore_face
Diffstat (limited to 'modules/face_restoration_utils.py')
-rw-r--r--modules/face_restoration_utils.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py
index c65c85ef..85cb3057 100644
--- a/modules/face_restoration_utils.py
+++ b/modules/face_restoration_utils.py
@@ -36,7 +36,7 @@ def create_face_helper(device) -> FaceRestoreHelper:
def restore_with_face_helper(
np_image: np.ndarray,
face_helper: FaceRestoreHelper,
- restore_face: Callable[[np.ndarray], np.ndarray],
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
"""
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
@@ -126,7 +126,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration):
def restore_with_helper(
self,
np_image: np.ndarray,
- restore_face: Callable[[np.ndarray], np.ndarray],
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
try:
if self.net is None: