author     Leonard Kugis <leonard@kug.is>  2023-03-22 04:22:29 +0100
committer  Leonard Kugis <leonard@kug.is>  2023-03-22 04:22:29 +0100
commit     a747254ece9da9af1517dd7dfdb4d16949b0eabb (patch)
tree       9999a22523c2dc9a4473efb200b823943e3ca155
parent     7c635fed24bba638f10a73fe39134519ecd675bf (diff)
Implemented detail scanner with multithreading
-rw-r--r--  file-tagger.py  60
-rw-r--r--  gui.py  2
-rw-r--r--  util.py  61
3 files changed, 109 insertions, 14 deletions
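Summary of the change: walk() now runs a "raw" full-frame ResNet50 scan in all four 90-degree rotations and, unless --predict-images-skip-detail is set, an additional detail scan that resizes the image so its longer side is --predict-images-detail-factor * 224, tiles it with 224x224 windows at half-window stride, and classifies every window through a new ThreadPool added to util.py; the pooled (name, probability) pairs are sorted by probability and cut to --predict-images-top. A rough, illustrative sketch (not part of the commit) of how many windows the detail scan schedules per image:

    # Illustrative only: windows dispatched per rotation by the detail scan.
    # image_resize() keeps the aspect ratio; walk() pins the longer side to
    # detail_factor * MODEL_DIMENSIONS before tiling at half-window stride.
    MODEL_DIMENSIONS = 224

    def crops_per_rotation(height, width, detail_factor=2):
        scale = (detail_factor * MODEL_DIMENSIONS) / max(height, width)
        h, w = int(height * scale), int(width * scale)
        step = MODEL_DIMENSIONS // 2
        return len(range(0, h, step)) * len(range(0, w, step))

    print(crops_per_rotation(3000, 4000))  # 12 per rotation, 48 over all four rotations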
diff --git a/file-tagger.py b/file-tagger.py
index 82aacf7..92eb023 100644
--- a/file-tagger.py
+++ b/file-tagger.py
@@ -8,6 +8,23 @@ import magic
from tmsu import *
from util import *
+MODEL_DIMENSIONS = 224
+
+def predict_image(model, img, top):
+    from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
+    logger = logging.getLogger(__name__)
+    array = np.expand_dims(img, axis=0)
+    array = preprocess_input(array)
+    predictions = model.predict(array)
+    classes = decode_predictions(predictions, top=top)
+    logger.debug("Predicted image classes: {}".format(classes[0]))
+    return set([(name, prob) for _, name, prob in classes[0]])
+
+def predict_partial(tags, model, img, x, y, top):
+    #cv2.imshow("test", img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)])
+    #cv2.waitKey(0)
+    tags.update(predict_image(model, img[x:(x+MODEL_DIMENSIONS), y:(y+MODEL_DIMENSIONS)], top))
+
'''
Walk over all files for the given base directory and all subdirectories recursively.
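The hunk above introduces predict_image() and predict_partial(): the former runs one ResNet50 forward pass on a single 224x224 crop and returns (class name, probability) pairs, the latter is the thread-pool task that classifies one window of the detail scan. A minimal, illustrative sketch of calling predict_image() on a single crop, assuming the function above is in scope (file-tagger.py is a script, so paste or adapt it rather than import it); the model construction and file path are assumptions, the diff only shows that a ResNet50-style model is passed in:

    import cv2
    from tensorflow.keras.applications.resnet50 import ResNet50

    model = ResNet50(weights="imagenet")          # assumption: ImageNet-weighted ResNet50
    img = cv2.imread("example.jpg")               # hypothetical input file
    crop = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    print(predict_image(model, crop, top=5))      # set of (class_name, probability) pairs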
@@ -50,24 +67,37 @@ def walk(args):
if mime_type.split("/")[0] == "image":
logger.debug("File is image")
img = cv2.imread(file_path)
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- img = cv2.resize(img, dsize=(800, 800), interpolation=cv2.INTER_CUBIC)
+ #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if args["predict_images"]:
logger.info("Predicting image tags ...")
- array_pre = cv2.resize(img, dsize=(224, 224), interpolation=cv2.INTER_CUBIC)
+ tags_predict = set()
for _ in range(4):
- array = np.expand_dims(array_pre, axis=0)
- array = preprocess_input(array)
- predictions = model.predict(array)
- classes = decode_predictions(predictions, top=args["predict_images_top"])
- logger.debug("Predicted image classes: {}".format(classes[0]))
- tags.update([name for _, name, _ in classes[0]])
- array_pre = cv2.rotate(array_pre, cv2.ROTATE_90_CLOCKWISE)
- logger.info("Predicted tags: {}".format(tags))
+ logger.debug("Raw scan")
+ raw = cv2.resize(img.copy(), dsize=(MODEL_DIMENSIONS, MODEL_DIMENSIONS), interpolation=cv2.INTER_CUBIC)
+ tags_predict.update(predict_image(model, raw, args["predict_images_top"]))
+ img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+ if not args["predict_images_skip_detail"]:
+ pool = ThreadPool(max(1, os.cpu_count() - 2), 1000)
+ for _ in range(4):
+ if img.shape[0] > img.shape[1]:
+ detail = image_resize(img.copy(), height=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
+ else:
+ detail = image_resize(img.copy(), width=(args["predict_images_detail_factor"] * MODEL_DIMENSIONS))
+ for x in range(0, detail.shape[0], int(MODEL_DIMENSIONS/2)):
+ for y in range(0, detail.shape[1], int(MODEL_DIMENSIONS/2)):
+ pool.add_task(predict_partial, tags_predict, model, detail.copy(), x, y, args["predict_images_top"])
+ img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+ pool.wait_completion()
+ tags_sorted = [tag[0] for tag in sorted(tags_predict, key=lambda tag: tag[1], reverse=True)]
+ tags_predict = set(list(dict.fromkeys(tags_sorted))[0:args["predict_images_top"]])
+ logger.info("Predicted tags: {}".format(tags_predict))
+ tags.update(tags_predict)
if args["gui_tag"]:
while(True): # For GUI inputs (rotate, ...)
logger.debug("Showing image GUI ...")
- ret = GuiImage(i, file_path, img, tags).loop()
+ img_show = image_resize(img, width=args["gui_image_length"]) if img.shape[1] > img.shape[0] else image_resize(img, height=args["gui_image_length"])
+ img_show = cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB)
+ ret = GuiImage(i, file_path, img_show, tags).loop()
tags = set(ret[1]).difference({''})
if ret[0] == GuiImage.RETURN_ROTATE_90_CLOCKWISE:
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
@@ -102,7 +132,10 @@ if __name__ == "__main__":
parser.add_argument('-g', '--gui', nargs='?', const=1, default=False, type=bool, help='Show main GUI (default: %(default)s)')
parser.add_argument('--predict-images', nargs='?', const=1, default=False, type=bool, help='Use prediction for image tagging (default: %(default)s)')
parser.add_argument('--predict-images-top', nargs='?', const=1, default=10, type=int, help='Defines how many top prediction keywords should be used (default: %(default)s)')
+ parser.add_argument('--predict-images-detail-factor', nargs='?', const=1, default=2, type=int, help='Width factor for detail scan, multiplied by 224 for ResNet50 (default: %(default)s)')
+ parser.add_argument('--predict-images-skip-detail', nargs='?', const=1, default=False, type=bool, help='Skip detail scan in image prediction (default: %(default)s)')
parser.add_argument('--gui-tag', nargs='?', const=1, default=False, type=bool, help='Show GUI for tagging (default: %(default)s)')
+ parser.add_argument('--gui-image-length', nargs='?', const=1, default=800, type=int, help='Length of longest side for preview (default: %(default)s)')
parser.add_argument('--open-system', nargs='?', const=1, default=False, type=bool, help='Open all files with system default (default: %(default)s)')
parser.add_argument('-s', '--skip-prompt', nargs='?', const=1, default=False, type=bool, help='Skip prompt for file tags (default: %(default)s)')
parser.add_argument('-i', '--index', nargs='?', const=1, default=0, type=int, help='Start tagging at the given file index (default: %(default)s)')
@@ -125,7 +158,10 @@ if __name__ == "__main__":
"gui": args.gui,
"predict_images": args.predict_images,
"predict_images_top": args.predict_images_top,
+ "predict_images_detail_factor": args.predict_images_detail_factor,
+ "predict_images_skip_detail": args.predict_images_skip_detail,
"gui_tag": args.gui_tag,
+ "gui_image_length": args.gui_image_length,
"open_system": args.open_system,
"skip_prompt": args.skip_prompt,
"index": args.index,
diff --git a/gui.py b/gui.py
index 8ab062d..78f3926 100644
--- a/gui.py
+++ b/gui.py
@@ -102,7 +102,7 @@ class GuiImage(object):
self.__image = ImageTk.PhotoImage(image=self.__image_pil)
Label(self.__master, text="Index: {}".format(index)).grid(row=0, column=0, columnspan=4)
Label(self.__master, text="File: {}".format(file)).grid(row=1, column=0, columnspan=4)
- self.__label = Label(self.__master, width=800, height=800, image=self.__image)
+ self.__label = Label(self.__master, width=img.shape[1], height=img.shape[0], image=self.__image)
self.__label.grid(row=2, column=0, columnspan=4)
Entry(self.__master, textvariable=self.__tags).grid(row=3, column=0, columnspan=4, sticky="we")
Button(self.__master, text="↺", command=self.__handle_rotate_90_counterclockwise).grid(row=4, column=0)
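In gui.py the preview Label now takes its size from the image that is actually passed in; file-tagger.py scales the longest side to --gui-image-length (default 800) and converts BGR to RGB before constructing GuiImage. A minimal sketch of that sizing rule (the function name is mine; the logic mirrors the hunk in file-tagger.py above):

    # Sketch only: preview sizing as done in walk() before GuiImage is shown.
    # image_resize() is the aspect-ratio-preserving helper from util.py.
    import cv2

    def preview(img, gui_image_length=800):
        if img.shape[1] > img.shape[0]:                    # wider than tall: pin the width
            img = image_resize(img, width=gui_image_length)
        else:                                              # otherwise pin the height
            img = image_resize(img, height=gui_image_length)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)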
diff --git a/util.py b/util.py
index 2bf4ff8..9fca80c 100644
--- a/util.py
+++ b/util.py
@@ -4,6 +4,9 @@ import cv2
import platform
import readline
import os
+import numpy as np
+from queue import Queue
+from threading import Thread, Lock
def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA):
# initialize the dimensions of the image to be resized and
@@ -36,6 +39,11 @@ def image_resize(image, width = None, height = None, inter = cv2.INTER_AREA):
# return the resized image
return resized
+def image_embed(img, dimensions):
+    ret = np.zeros((dimensions[0], dimensions[1], 3), np.uint8)
+    ret[0:img.shape[0], 0:img.shape[1]] = img
+    return ret
+
'''
Fetch input prompt with prefilled text.
@@ -76,4 +84,55 @@ def open_system(file):
elif platform.system() == 'Windows': # Windows
os.startfile(file)
else: # linux variants
- subprocess.call(('xdg-open', file))
\ No newline at end of file
+ subprocess.call(('xdg-open', file))
+
+class Worker(Thread):
+    def __init__(self, tasks):
+        Thread.__init__(self)
+        self.tasks = tasks
+        self.daemon = True
+        self.lock = Lock()
+        self.start()
+
+    def run(self):
+        while True:
+            func, args, kargs = self.tasks.get()
+            try:
+                if func.lower() == "terminate":
+                    break
+            except:
+                try:
+                    with self.lock:
+                        func(*args, **kargs)
+                except Exception as exception:
+                    print(exception)
+            self.tasks.task_done()
+
+class ThreadPool:
+    def __init__(self, num_threads, num_queue=None):
+        if num_queue is None or num_queue < num_threads:
+            num_queue = num_threads
+        self.tasks = Queue(num_queue)
+        self.threads = num_threads
+        for _ in range(num_threads): Worker(self.tasks)
+
+    # Terminate all worker threads of the pool
+    def terminate(self):
+        self.wait_completion()
+        for _ in range(self.threads): self.add_task("terminate")
+        return None
+
+    # Add new work to the queue
+    def add_task(self, func, *args, **kargs):
+        self.tasks.put((func, args, kargs))
+
+    # Block until all workers have finished the pending work
+    def wait_completion(self):
+        self.tasks.join()
+
+    # Return True if there is any pending or running work in the queue, otherwise False
+    def is_alive(self):
+        if self.tasks.unfinished_tasks == 0:
+            return False
+        else:
+            return True
\ No newline at end of file
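
The Worker/ThreadPool pair added to util.py is what walk() uses to parallelize predict_partial(). A hedged usage sketch follows; only the pool API (constructor, add_task, wait_completion, terminate) comes from this diff, while the work() function and the argument values are illustrative:

    import os
    from util import ThreadPool

    def work(name, index):
        # placeholder task; in file-tagger.py the task is predict_partial()
        print(name, index)

    pool = ThreadPool(max(1, os.cpu_count() - 2), 1000)   # same sizing walk() uses
    for i in range(8):
        pool.add_task(work, "crop", i)
    pool.wait_completion()   # block until the queue is drained
    pool.terminate()         # push the "terminate" sentinel to every worker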