From c9a0c6dc4981e1e6eb933b32dc2a66244ea77384 Mon Sep 17 00:00:00 2001 From: Leonard Kugis Date: Wed, 22 Mar 2023 05:13:28 +0100 Subject: Optimized partial prediction --- file-tagger.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/file-tagger.py b/file-tagger.py index a8bf3ef..ca4b8f5 100644 --- a/file-tagger.py +++ b/file-tagger.py @@ -13,6 +13,9 @@ MODEL_DIMENSIONS = 224 def predict_image(model, img, top): from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions logger = logging.getLogger(__name__) + #cv2.imshow("test", img) + #cv2.waitKey(0) + #cv2.destroyAllWindows() array = np.expand_dims(img, axis=0) array = preprocess_input(array) predictions = model.predict(array) @@ -20,10 +23,14 @@ def predict_image(model, img, top): logger.debug("Predicted image classes: {}".format(classes[0])) return set([(name, prob) for _, name, prob in classes[0]]) -def predict_partial(tags, model, img, x, y, top): +def predict_partial(tags, model, img, x, y, rot, top): #cv2.imshow("test", img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)]) #cv2.waitKey(0) - tags.update(predict_image(model, img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], top)) + if rot is None: + tmp = img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)] + else: + tmp = cv2.rotate(img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], rot) + tags.update(predict_image(model, tmp, top)) ''' Walk over all files for the given base directory and all subdirectories recursively. @@ -78,15 +85,16 @@ def walk(args): img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) if not args["predict_images_skip_detail"]: pool = ThreadPool(max(1, os.cpu_count() - 2), 10000) - for _ in range(4): - if img.shape[0] > img.shape[1]: - detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) - else: - detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) - for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)): - for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)): - pool.add_task(predict_partial, tags_predict, model, detail, x, y, args["predict_images_top"]) - img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + if img.shape[0] > img.shape[1]: + detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) + else: + detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS)) + for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)): + for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)): + pool.add_task(predict_partial, tags_predict, model, detail, x, y, None, args["predict_images_top"]) + pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_90_CLOCKWISE, args["predict_images_top"]) + pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_180, args["predict_images_top"]) + pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_90_COUNTERCLOCKWISE, args["predict_images_top"]) pool.wait_completion() tags_sorted = [tag[0] for tag in sorted(tags_predict, key=lambda tag: tag[1], reverse=True)] tags_predict = set(list(dict.fromkeys(tags_sorted))[0:args["predict_images_top"]]) -- cgit v1.2.1