aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorw-e-w <40751091+w-e-w@users.noreply.github.com>2023-11-27 17:26:16 +0900
committerw-e-w <40751091+w-e-w@users.noreply.github.com>2023-11-28 12:09:51 +0900
commit03ee297aa22296ea12b965fc1cb11aa46375d372 (patch)
tree83edb374cfb078ea535cacf2919579b274ac0012
parentf0f100e67b78f686dc73cf3c8cad422e45cc9b8a (diff)
fix Auto focal point crop for opencv >= 4.8.x
autocrop.download_and_cache_models in opencv >= 4.8 the face detection model was updated download the base on opencv version returns the model path or raise exception
-rw-r--r--modules/textual_inversion/autocrop.py29
-rw-r--r--modules/textual_inversion/preprocess.py4
2 files changed, 18 insertions, 15 deletions
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index 1675e39a..051be118 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -3,6 +3,8 @@ import requests
import os
import numpy as np
from PIL import ImageDraw
+from modules import paths_internal
+from pkg_resources import parse_version
GREEN = "#0F0"
BLUE = "#00F"
@@ -294,22 +296,23 @@ def is_square(w, h):
return w == h
-def download_and_cache_models(dirname):
- download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- model_file_name = 'face_detection_yunet.onnx'
+model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')
+if parse_version(cv2.__version__) >= parse_version('4.8'):
+ model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')
+ model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
+else:
+ model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx')
+ model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- os.makedirs(dirname, exist_ok=True)
- cache_file = os.path.join(dirname, model_file_name)
- if not os.path.exists(cache_file):
- print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
- response = requests.get(download_url)
- with open(cache_file, "wb") as f:
+def download_and_cache_models():
+ if not os.path.exists(model_file_path):
+ os.makedirs(model_dir_opencv, exist_ok=True)
+ print(f"downloading face detection model from '{model_url}' to '{model_file_path}'")
+ response = requests.get(model_url)
+ with open(model_file_path, "wb") as f:
f.write(response.content)
-
- if os.path.exists(cache_file):
- return cache_file
- return None
+ return model_file_path
class PointOfInterest:
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index dbd856bd..789fa083 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -3,7 +3,7 @@ from PIL import Image, ImageOps
import math
import tqdm
-from modules import paths, shared, images, deepbooru
+from modules import shared, images, deepbooru
from modules.textual_inversion import autocrop
@@ -196,7 +196,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = None
try:
- dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
+ dnn_model_path = autocrop.download_and_cache_models()
except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)