From 6b5a426e1750a57d5b291fa6e13ef5f36c9a44ad Mon Sep 17 00:00:00 2001
From: Leonard Kugis
Date: Mon, 10 Apr 2023 00:28:40 +0200
Subject: Implemented video prediction

---
 file-tagger.py | 91 ++++++++++++++++++++++++++++++++++++++++++----------------
 gui.py         | 25 ++++++++++------
 2 files changed, 83 insertions(+), 33 deletions(-)

diff --git a/file-tagger.py b/file-tagger.py
index 67e1102..d7d572b 100644
--- a/file-tagger.py
+++ b/file-tagger.py
@@ -30,7 +30,7 @@ def walk(tmsu, args):
         logger.error("Invalid start index. index = {}, number of files = {}".format(args["index"], len(files)))
         return
 
-    if args["predict_images"]:
+    if args["predict_images"] or args["predict_videos"]:
         backend = {
             "torch": Predictor.BackendTorch,
             "tensorflow": Predictor.BackendTensorflow,
@@ -78,33 +78,72 @@ def walk(tmsu, args):
             time_m.strftime("%Hh")})
 
         # Detect MIME-type for file
-        mime_type = mime.from_file(file_path)
+        mime_type = mime.from_file(file_path).split("/")
+
+        tags.update(mime_type)
 
         # Handle images
-        if mime_type.split("/")[0] == "image":
+        if mime_type[0] == "image":
             logger.debug("File is image")
-            img = cv2.imread(file_path)
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            if args["predict_images"]:
-                logger.info("Predicting image tags ...")
-                tags_predict = predictor.predict(img)
-                logger.info("Predicted tags: {}".format(tags_predict))
-                tags.update(tags_predict)
-            if args["gui_tag"]:
-                while(True): # For GUI inputs (rotate, ...)
-                    logger.debug("Showing image GUI ...")
-                    img_show = image_resize(img, width=args["gui_image_length"]) if img.shape[1] > img.shape[0] else image_resize(img, height=args["gui_image_length"])
-                    #img_show = cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB)
-                    ret = GuiImage(i, file_path, img_show, tags).loop()
-                    tags = set(ret[1]).difference({''})
-                    if ret[0] == GuiImage.RETURN_ROTATE_90_CLOCKWISE:
-                        img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
-                    elif ret[0] == GuiImage.RETURN_ROTATE_90_COUNTERCLOCKWISE:
-                        img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
-                    elif ret[0] == GuiImage.RETURN_NEXT:
+            if args["predict_images"] or args["gui_tag"]:
+                img = cv2.imread(file_path)
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                if args["predict_images"]:
+                    logger.info("Predicting image tags ...")
+                    tags_predict = predictor.predict(img)
+                    logger.info("Predicted tags: {}".format(tags_predict))
+                    tags.update(tags_predict)
+                if args["gui_tag"]:
+                    while(True): # For GUI inputs (rotate, ...)
+                        logger.debug("Showing image GUI ...")
+                        img_show = image_resize(img, width=args["gui_image_length"]) if img.shape[1] > img.shape[0] else image_resize(img, height=args["gui_image_length"])
+                        #img_show = cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB)
+                        ret = GuiImage(i, file_path, img_show, tags).loop()
+                        tags = set(ret[1]).difference({''})
+                        if ret[0] == GuiImage.RETURN_ROTATE_90_CLOCKWISE:
+                            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+                        elif ret[0] == GuiImage.RETURN_ROTATE_90_COUNTERCLOCKWISE:
+                            img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                        elif ret[0] == GuiImage.RETURN_NEXT:
+                            break
+                        elif ret[0] == GuiImage.RETURN_ABORT:
+                            return
+        elif mime_type[0] == "video":
+            logger.debug("File is video")
+            if args["predict_videos"] or args["gui_tag"]:
+                cap = cv2.VideoCapture(file_path)
+                n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
+                step = n_frames / args["predict_videos_key_frames"]
+                logger.debug("Key frame step: {}".format(step))
+                preview = None
+                for frame in np.arange(0, n_frames, step):
+                    cap.set(cv2.CAP_PROP_POS_FRAMES, max(-1, round(frame - 1)))
+                    _, f = cap.read()
+                    f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
+                    if frame == 0:
+                        preview = f
+                    if args["predict_videos"]:
+                        logger.info("Predicting video frame {} of {}".format(frame, n_frames))
+                        tags_predict = predictor.predict(f)
+                        logger.info("Predicted tags: {}".format(tags_predict))
+                        tags.update(tags_predict)
+                    else:
                         break
-                    elif ret[0] == GuiImage.RETURN_ABORT:
-                        return
+                if args["gui_tag"]:
+                    while(True): # For GUI inputs (rotate, ...)
+                        logger.debug("Showing image GUI ...")
+                        img_show = image_resize(preview, width=args["gui_image_length"]) if preview.shape[1] > preview.shape[0] else image_resize(preview, height=args["gui_image_length"])
+                        #img_show = cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB)
+                        ret = GuiImage(i, file_path, img_show, tags).loop()
+                        tags = set(ret[1]).difference({''})
+                        if ret[0] == GuiImage.RETURN_ROTATE_90_CLOCKWISE:
+                            preview = cv2.rotate(preview, cv2.ROTATE_90_CLOCKWISE)
+                        elif ret[0] == GuiImage.RETURN_ROTATE_90_COUNTERCLOCKWISE:
+                            preview = cv2.rotate(preview, cv2.ROTATE_90_COUNTERCLOCKWISE)
+                        elif ret[0] == GuiImage.RETURN_NEXT:
+                            break
+                        elif ret[0] == GuiImage.RETURN_ABORT:
+                            return
         else:
             if args["gui_tag"]:
                 while(True):
@@ -135,6 +174,8 @@ if __name__ == "__main__":
     parser.add_argument('--predict-images-top', nargs='?', const=1, default=10, type=int, help='Defines how many top prediction keywords should be used (default: %(default)s)')
     parser.add_argument('--predict-images-detail-factor', nargs='?', const=1, default=2, type=int, help='Width factor for detail scan, multiplied by 224 for ResNet50 (default: %(default)s)')
     parser.add_argument('--predict-images-skip-detail', nargs='?', const=1, default=False, type=bool, help='Skip detail scan in image prediction (default: %(default)s)')
+    parser.add_argument('--predict-videos', nargs='?', const=1, default=False, type=bool, help='Use prediction for video tagging (default: %(default)s)')
+    parser.add_argument('--predict-videos-key-frames', nargs='?', const=1, default=5, type=int, help='Defines how many key frames are used to predict videos (default: %(default)s)')
     parser.add_argument('--gui-tag', nargs='?', const=1, default=False, type=bool, help='Show GUI for tagging (default: %(default)s)')
     parser.add_argument('--gui-image-length', nargs='?', const=1, default=800, type=int, help='Length of longest side for preview (default: %(default)s)')
     parser.add_argument('--open-system', nargs='?', const=1, default=False, type=bool, help='Open all files with system default (default: %(default)s)')
@@ -165,6 +206,8 @@ if __name__ == "__main__":
         "predict_images_top": args.predict_images_top,
         "predict_images_detail_factor": args.predict_images_detail_factor,
         "predict_images_skip_detail": args.predict_images_skip_detail,
+        "predict_videos": args.predict_videos,
+        "predict_videos_key_frames": args.predict_videos_key_frames,
         "gui_tag": args.gui_tag,
         "gui_image_length": args.gui_image_length,
         "open_system": args.open_system,
diff --git a/gui.py b/gui.py
index 8f2ad4c..f99b62a 100644
--- a/gui.py
+++ b/gui.py
@@ -19,6 +19,8 @@ class GuiMain(object):
         self.__predict_images_top = StringVar(self.__master, value=str(args["predict_images_top"]))
         self.__predict_images_skip_detail = BooleanVar(self.__master, value=args["predict_images_skip_detail"])
         self.__predict_images_detail_factor = StringVar(self.__master, value=str(args["predict_images_detail_factor"]))
+        self.__predict_videos = BooleanVar(self.__master, value=args["predict_videos"])
+        self.__predict_videos_key_frames = StringVar(self.__master, value=str(args["predict_videos_key_frames"]))
         self.__gui_tag = BooleanVar(self.__master, value=args["gui_tag"])
         self.__gui_image_length = StringVar(self.__master, value=str(args["gui_image_length"]))
         self.__open_system = BooleanVar(self.__master, value=args["open_system"])
@@ -43,15 +45,18 @@ class GuiMain(object):
         Checkbutton(self.__master, text="Skip detail scan in image prediction", variable=self.__predict_images_skip_detail).grid(row=6, column=0, columnspan=4, sticky=W)
         Label(self.__master, text="Width factor for detail scan:").grid(row=7, column=0)
         Entry(self.__master, textvariable=self.__predict_images_detail_factor, validate='all', validatecommand=(validate_number, '%P')).grid(row=7, column=1, columnspan=1)
-        Checkbutton(self.__master, text="Show GUI for tagging", variable=self.__gui_tag).grid(row=8, column=0, columnspan=4, sticky=W)
-        Label(self.__master, text="Image GUI preview size:").grid(row=9, column=0)
-        Entry(self.__master, textvariable=self.__gui_image_length, validate='all', validatecommand=(validate_number, '%P')).grid(row=9, column=1, columnspan=1)
-        Checkbutton(self.__master, text="Open all files with system default", variable=self.__open_system).grid(row=10, column=0, columnspan=4, sticky=W)
-        Checkbutton(self.__master, text="Skip prompt for file tags", variable=self.__skip_prompt).grid(row=11, column=0, columnspan=4, sticky=W)
-        Checkbutton(self.__master, text="Skip already tagged files", variable=self.__skip_tagged).grid(row=12, column=0, columnspan=4, sticky=W)
-        Label(self.__master, text="Start at index:").grid(row=13, column=0)
-        Entry(self.__master, textvariable=self.__index, validate='all', validatecommand=(validate_number, '%P')).grid(row=13, column=1, columnspan=1)
-        Button(self.__master, text="Start", command=self.__master.destroy).grid(row=14, column=0, columnspan=4)
+        Checkbutton(self.__master, text="Use prediction for video tagging", variable=self.__predict_videos).grid(row=8, column=0, columnspan=4, sticky=W)
+        Label(self.__master, text="Number of key frames:").grid(row=9, column=0)
+        Entry(self.__master, textvariable=self.__predict_videos_key_frames, validate='all', validatecommand=(validate_number, '%P')).grid(row=9, column=1, columnspan=1)
+        Checkbutton(self.__master, text="Show GUI for tagging", variable=self.__gui_tag).grid(row=10, column=0, columnspan=4, sticky=W)
+        Label(self.__master, text="Image GUI preview size:").grid(row=11, column=0)
+        Entry(self.__master, textvariable=self.__gui_image_length, validate='all', validatecommand=(validate_number, '%P')).grid(row=11, column=1, columnspan=1)
+        Checkbutton(self.__master, text="Open all files with system default", variable=self.__open_system).grid(row=12, column=0, columnspan=4, sticky=W)
+        Checkbutton(self.__master, text="Skip prompt for file tags", variable=self.__skip_prompt).grid(row=13, column=0, columnspan=4, sticky=W)
+        Checkbutton(self.__master, text="Skip already tagged files", variable=self.__skip_tagged).grid(row=14, column=0, columnspan=4, sticky=W)
+        Label(self.__master, text="Start at index:").grid(row=15, column=0)
+        Entry(self.__master, textvariable=self.__index, validate='all', validatecommand=(validate_number, '%P')).grid(row=15, column=1, columnspan=1)
+        Button(self.__master, text="Start", command=self.__master.destroy).grid(row=16, column=0, columnspan=4)
 
     def loop(self):
         self.__master.mainloop()
@@ -64,6 +69,8 @@ class GuiMain(object):
         self.__args["predict_images_top"] = int(self.__predict_images_top.get())
         self.__args["predict_images_skip_detail"] = self.__predict_images_skip_detail.get()
         self.__args["predict_images_detail_factor"] = int(self.__predict_images_detail_factor.get())
+        self.__args["predict_videos"] = self.__predict_videos.get()
+        self.__args["predict_videos_key_frames"] = int(self.__predict_videos_key_frames.get())
         self.__args["gui_tag"] = self.__gui_tag.get()
         self.__args["gui_image_length"] = int(self.__gui_image_length.get())
         self.__args["open_system"] = self.__open_system.get()
-- 
cgit v1.2.1
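
Note on the approach: walk() now samples a fixed number of evenly spaced key frames from each video instead of decoding every frame, runs the existing image predictor on each sampled frame, and adds the union of the per-frame tags to the file's tag set (the first sampled frame doubles as the GUI preview). The sketch below restates that sampling logic in isolation. It is a minimal illustration assuming opencv-python and numpy are installed; sample_key_frames, predict_video_tags and predict_fn are hypothetical names, not part of the file-tagger API.

    import cv2
    import numpy as np

    def sample_key_frames(path, n_key_frames=5):
        """Yield up to n_key_frames RGB frames spaced evenly across the video."""
        cap = cv2.VideoCapture(path)
        n_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)  # may be 0 if the container reports no frame count
        if n_frames <= 0:
            cap.release()
            return
        step = n_frames / n_key_frames
        for pos in np.arange(0, n_frames, step):
            # Seek to the key frame position and decode it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(round(pos)))
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV decodes to BGR; the predictor expects RGB.
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        cap.release()

    def predict_video_tags(path, predict_fn, n_key_frames=5):
        """Union of the tags predicted for each sampled key frame."""
        tags = set()
        for frame in sample_key_frames(path, n_key_frames):
            tags.update(predict_fn(frame))
        return tags

Usage would look like predict_video_tags(file_path, predictor.predict, args["predict_videos_key_frames"]), which mirrors what the patched walk() does inline. Seeking via CAP_PROP_POS_FRAMES is convenient but not guaranteed to be frame-accurate for every codec and backend; that is acceptable here, since only roughly spaced key frames are needed.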