aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/deepbooru.py73
-rw-r--r--modules/ui.py24
2 files changed, 92 insertions, 5 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
new file mode 100644
index 00000000..781b2249
--- /dev/null
+++ b/modules/deepbooru.py
@@ -0,0 +1,73 @@
+import os.path
+from concurrent.futures import ProcessPoolExecutor
+from multiprocessing import get_context
+
+
+def _load_tf_and_return_tags(pil_image, threshold):
+ import deepdanbooru as dd
+ import tensorflow as tf
+ import numpy as np
+
+ this_folder = os.path.dirname(__file__)
+ model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28')
+
+ model_good = False
+ for path_candidate in [model_path, os.path.dirname(model_path)]:
+ if os.path.exists(os.path.join(path_candidate, 'project.json')):
+ model_path = path_candidate
+ model_good = True
+ if not model_good:
+ return ("Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/"
+ "deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru")
+
+ tags = dd.project.load_tags_from_project(model_path)
+ model = dd.project.load_model_from_project(
+ model_path, compile_model=True
+ )
+
+ width = model.input_shape[2]
+ height = model.input_shape[1]
+ image = np.array(pil_image)
+ image = tf.image.resize(
+ image,
+ size=(height, width),
+ method=tf.image.ResizeMethod.AREA,
+ preserve_aspect_ratio=True,
+ )
+ image = image.numpy() # EagerTensor to np.array
+ image = dd.image.transform_and_pad_image(image, width, height)
+ image = image / 255.0
+ image_shape = image.shape
+ image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))
+
+ y = model.predict(image)[0]
+
+ result_dict = {}
+
+ for i, tag in enumerate(tags):
+ result_dict[tag] = y[i]
+ result_tags_out = []
+ result_tags_print = []
+ for tag in tags:
+ if result_dict[tag] >= threshold:
+ if tag.startswith("rating:"):
+ continue
+ result_tags_out.append(tag)
+ result_tags_print.append(f'{result_dict[tag]} {tag}')
+
+ print('\n'.join(sorted(result_tags_print, reverse=True)))
+
+ return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
+
+
+def subprocess_init_no_cuda():
+ import os
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+
+def get_deepbooru_tags(pil_image, threshold=0.5):
+ context = get_context('spawn')
+ with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
+ f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
+ ret = f.result() # will rethrow any exceptions
+ return ret \ No newline at end of file
diff --git a/modules/ui.py b/modules/ui.py
index ffd75f6a..30583fe9 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -23,6 +23,7 @@ import gradio.utils
import gradio.routes
from modules import sd_hijack
+from modules.deepbooru import get_deepbooru_tags
from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
@@ -292,6 +293,11 @@ def interrogate(image):
return gr_show(True) if prompt is None else prompt
+def interrogate_deepbooru(image):
+ prompt = get_deepbooru_tags(image)
+ return gr_show(True) if prompt is None else prompt
+
+
def create_seed_inputs():
with gr.Row():
with gr.Box():
@@ -428,15 +434,17 @@ def create_toprow(is_img2img):
outputs=[],
)
- with gr.Row():
+ with gr.Row(scale=1):
if is_img2img:
- interrogate = gr.Button('Interrogate', elem_id="interrogate")
+ interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
else:
interrogate = None
+ deepbooru = None
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create")
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -465,7 +473,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
with gr.Row(elem_id='txt2img_progress_row'):
@@ -617,7 +625,7 @@ def create_ui(wrap_gradio_gpu_call):
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
with gr.Column(scale=1):
@@ -774,6 +782,12 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[img2img_prompt],
)
+ img2img_deepbooru.click(
+ fn=interrogate_deepbooru,
+ inputs=[init_img],
+ outputs=[img2img_prompt],
+ )
+
save.click(
fn=wrap_gradio_call(save_files),
_js="(x, y, z) => [x, y, selected_gallery_index()]",