aboutsummaryrefslogtreecommitdiff
path: root/modules/deepbooru.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/deepbooru.py')
-rw-r--r--modules/deepbooru.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
new file mode 100644
index 00000000..958b1c3d
--- /dev/null
+++ b/modules/deepbooru.py
@@ -0,0 +1,60 @@
+import os.path
+from concurrent.futures import ProcessPoolExecutor
+
+import numpy as np
+import deepdanbooru as dd
+import tensorflow as tf
+
+
+def _load_tf_and_return_tags(pil_image, threshold):
+ this_folder = os.path.dirname(__file__)
+ model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28')
+ if not os.path.exists(model_path):
+ return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru"
+
+ tags = dd.project.load_tags_from_project(model_path)
+ model = dd.project.load_model_from_project(
+ model_path, compile_model=True
+ )
+
+ width = model.input_shape[2]
+ height = model.input_shape[1]
+ image = np.array(pil_image)
+ image = tf.image.resize(
+ image,
+ size=(height, width),
+ method=tf.image.ResizeMethod.AREA,
+ preserve_aspect_ratio=True,
+ )
+ image = image.numpy() # EagerTensor to np.array
+ image = dd.image.transform_and_pad_image(image, width, height)
+ image = image / 255.0
+ image_shape = image.shape
+ image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
+
+ y = model.predict(image)[0]
+
+ result_dict = {}
+
+ for i, tag in enumerate(tags):
+ result_dict[tag] = y[i]
+
+
+
+ result_tags_out = []
+ result_tags_print = []
+ for tag in tags:
+ if result_dict[tag] >= threshold:
+ result_tags_out.append(tag)
+ result_tags_print.append(f'{result_dict[tag]} {tag}')
+
+ print('\n'.join(sorted(result_tags_print, reverse=True)))
+
+ return ', '.join(result_tags_out)
+
+
+def get_deepbooru_tags(pil_image, threshold=0.5):
+ with ProcessPoolExecutor() as executor:
+ f = executor.submit(_load_tf_and_return_tags, pil_image, threshold)
+ ret = f.result() # will rethrow any exceptions
+ return ret \ No newline at end of file