From 8d026f6a9eba0bc3905b945543f0e88a19c5f5cc Mon Sep 17 00:00:00 2001
From: Leonard Kugis
Date: Fri, 31 Mar 2023 02:50:27 +0200
Subject: Predictor: Moved out of main script

---
 file-tagger.py |  54 +++---------------------------
 predictor.py   | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 109 insertions(+), 49 deletions(-)
 create mode 100644 predictor.py

diff --git a/file-tagger.py b/file-tagger.py
index ca4b8f5..70909c8 100644
--- a/file-tagger.py
+++ b/file-tagger.py
@@ -7,30 +7,8 @@ import logging
 import magic
 from tmsu import *
 from util import *
-
-MODEL_DIMENSIONS = 224
-
-def predict_image(model, img, top):
-    from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
-    logger = logging.getLogger(__name__)
-    #cv2.imshow("test", img)
-    #cv2.waitKey(0)
-    #cv2.destroyAllWindows()
-    array = np.expand_dims(img, axis=0)
-    array = preprocess_input(array)
-    predictions = model.predict(array)
-    classes = decode_predictions(predictions, top=top)
-    logger.debug("Predicted image classes: {}".format(classes[0]))
-    return set([(name, prob) for _, name, prob in classes[0]])
-
-def predict_partial(tags, model, img, x, y, rot, top):
-    #cv2.imshow("test", img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)])
-    #cv2.waitKey(0)
-    if rot is None:
-        tmp = img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)]
-    else:
-        tmp = cv2.rotate(img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], rot)
-    tags.update(predict_image(model, tmp, top))
+from predictor import *
+from PIL import Image
 
 '''
 Walk over all files for the given base directory and all subdirectories recursively.
@@ -52,10 +30,8 @@ def walk(args):
         return
 
     if args["predict_images"]:
-        from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
-        from tensorflow.keras.preprocessing import image
-        from tensorflow.keras.models import Model
-        model = ResNet50(weights="imagenet")
+        #predictor = Predictor(Predictor.BackendTorch(top=args["predict_images_top"]))
+        predictor = Predictor(Predictor.BackendTensorflow(top=args["predict_images_top"], detail=(not args["predict_images_skip_detail"]), detail_factor=args["predict_images_detail_factor"]))
 
     for i in range(args["index"], len(files)):
         file_path = files[i]
@@ -77,27 +53,7 @@ def walk(args):
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
             if args["predict_images"]:
                 logger.info("Predicting image tags ...")
-                tags_predict = set()
-                for _ in range(4):
-                    logger.debug("Raw scan")
-                    raw = cv2.resize(img.copy(), dsize=(MODEL_DIMENSIONS, MODEL_DIMENSIONS), interpolation=cv2.INTER_CUBIC)
-                    tags_predict.update(predict_image(model, raw, args["predict_images_top"]))
-                    img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
-                if not args["predict_images_skip_detail"]:
-                    pool = ThreadPool(max(1, os.cpu_count() - 2), 10000)
-                    if img.shape[0] > img.shape[1]:
-                        detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
-                    else:
-                        detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
-                    for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)):
-                        for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)):
-                            pool.add_task(predict_partial, tags_predict, model, detail, x, y, None, args["predict_images_top"])
-                            pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_90_CLOCKWISE, args["predict_images_top"])
-                            pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_180, args["predict_images_top"])
-                            pool.add_task(predict_partial, tags_predict, model, detail, x, y, cv2.ROTATE_90_COUNTERCLOCKWISE, args["predict_images_top"])
-                    pool.wait_completion()
-                tags_sorted = [tag[0] for tag in sorted(tags_predict, key=lambda tag: tag[1], reverse=True)]
-                tags_predict = set(list(dict.fromkeys(tags_sorted))[0:args["predict_images_top"]])
+                tags_predict = predictor.predict(img)
                 logger.info("Predicted tags: {}".format(tags_predict))
                 tags.update(tags_predict)
             if args["gui_tag"]:
diff --git a/predictor.py b/predictor.py
new file mode 100644
index 0000000..8a886a7
--- /dev/null
+++ b/predictor.py
@@ -0,0 +1,104 @@
+import logging
+import os
+import cv2
+import numpy as np
+from util import *
+
+class Predictor(object):
+
+    class Backend(object):
+
+        def __init__(self):
+            raise NotImplementedError()
+
+        def predict(self, img, top=10):
+            raise NotImplementedError()
+
+    class BackendTensorflow(Backend):
+
+        MODEL_DIMENSIONS = 224
+
+        def __init__(self, top=10, detail=True, detail_factor=4):
+            logger = logging.getLogger(__name__)
+            logger.debug("Initializing Tensorflow/Keras backend ...")
+            from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
+            from tensorflow.keras.preprocessing import image
+            from tensorflow.keras.models import Model
+            self.__model = ResNet50(weights="imagenet")
+            self.__top = top
+            self.__detail = detail
+            self.__detail_factor = detail_factor
+
+        def __predict(self, img):
+            logger = logging.getLogger(__name__)
+            logger.debug("Predicting image part ...")
+            from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
+            array = np.expand_dims(img, axis=0)
+            array = preprocess_input(array)
+            predictions = self.__model.predict(array)
+            classes = decode_predictions(predictions, top=self.__top)
+            logger.debug("Predicted raw image classes: {}".format(classes[0]))
+            return set([(name, prob) for _, name, prob in classes[0]])
+
+        def __predict_partial(self, tags, img, x, y, rot):
+            logger = logging.getLogger(__name__)
+            logger.debug("Predicting detail image at x={}, y={}, rot={}".format(x, y, rot))
+            if rot is None:
+                tmp = img[x:(x+self.MODEL_DIMENSIONS), y:(y+self.MODEL_DIMENSIONS)]
+            else:
+                tmp = cv2.rotate(img[x:(x+self.MODEL_DIMENSIONS), y:(y+self.MODEL_DIMENSIONS)], rot)
+            tags.update(self.__predict(tmp))
+
+        def predict(self, img):
+            logger = logging.getLogger(__name__)
+            logger.debug("Predicting raw image ...")
+            ret = self.__predict(cv2.resize(img.copy(), dsize=(self.MODEL_DIMENSIONS, self.MODEL_DIMENSIONS), interpolation=cv2.INTER_AREA))
+
+            if self.__detail:
+                logger.debug("Predicting detail image ...")
+                tmp = set()
+                pool = ThreadPool(max(1, os.cpu_count() - 2), 10000)
+                if img.shape[0] > img.shape[1]:
+                    detail = image_resize(img.copy(), height=(self.__detail_factor * self.MODEL_DIMENSIONS))
+                else:
+                    detail = image_resize(img.copy(), width=(self.__detail_factor * self.MODEL_DIMENSIONS))
+                for x in range(0, detail.shape[0], int(self.MODEL_DIMENSIONS/2)):
+                    for y in range(0, detail.shape[1], int(self.MODEL_DIMENSIONS/2)):
+                        pool.add_task(self.__predict_partial, ret, detail, x, y, None)
+                        pool.add_task(self.__predict_partial, ret, detail, x, y, cv2.ROTATE_90_CLOCKWISE)
+                        pool.add_task(self.__predict_partial, ret, detail, x, y, cv2.ROTATE_180)
+                        pool.add_task(self.__predict_partial, ret, detail, x, y, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                pool.wait_completion()
+
+            ret = [tag[0] for tag in sorted(ret, key=lambda tag: tag[1], reverse=True)]
+            ret = set(list(dict.fromkeys(ret))[0:self.__top])
+            return ret
+
+    class BackendTorch(Backend):
+
+        def __init__(self, top=10):
+            logger = logging.getLogger(__name__)
+            logger.debug("Initializing Torch backend ...")
+            import torch
+            from torchvision.models import resnet50, ResNet50_Weights
+            self.__weights = ResNet50_Weights.DEFAULT
+            self.__model = resnet50(weights=self.__weights)
+            self.__model.eval()
+            self.__preprocess = self.__weights.transforms()
+            self.__top = top
+
+        def predict(self, img):
+            import torch
+            from PIL import Image
+            batch = self.__preprocess(Image.fromarray(img)).unsqueeze(0)
+            prediction = self.__model(batch).squeeze(0).softmax(0)
+            classes = torch.topk(prediction.flatten(), self.__top).indices
+            #return set([(weights.meta["categories"][clazz], prediction[clazz].item()) for clazz in classes])
+            return set([self.__weights.meta["categories"][clazz] for clazz in classes])
+
+    def __init__(self, backend):
+        self.__backend = backend
+
+    def predict(self, img):
+        return self.__backend.predict(img)
+
-- 
cgit v1.2.1