about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- modules/deepbooru.py | 14
-rw-r--r-- modules/interrogate.py | 7
-rw-r--r-- modules/shared.py | 1
-rw-r--r-- modules/ui.py | 5
4 files changed, 17 insertions, 10 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 7e3c0618..32d741e2 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -3,7 +3,7 @@ from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context
-def _load_tf_and_return_tags(pil_image, threshold):
+def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
@@ -52,12 +52,16 @@ def _load_tf_and_return_tags(pil_image, threshold):
if result_dict[tag] >= threshold:
if tag.startswith("rating:"):
continue
- result_tags_out.append(tag)
+ tag_formatted = tag.replace('_', ' ').replace(':', ' ')
+ if include_ranks:
+ result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})')
+ else:
+ result_tags_out.append(tag_formatted)
result_tags_print.append(f'{result_dict[tag]} {tag}')
print('\n'.join(sorted(result_tags_print, reverse=True)))
- return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')
+ return ', '.join(result_tags_out)
def subprocess_init_no_cuda():
@@ -65,9 +69,9 @@ def subprocess_init_no_cuda():
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
-def get_deepbooru_tags(pil_image, threshold=0.5):
+def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False):
context = get_context('spawn')
with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
- f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
+ f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks)
ret = f.result() # will rethrow any exceptions
return ret
\ No newline at end of file
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 635e266e..af858cc0 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -123,7 +123,7 @@ class InterrogateModels:
return caption[0]
- def interrogate(self, pil_image):
+ def interrogate(self, pil_image, include_ranks=False):
res = None
try:
@@ -156,7 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
- res += ", " + match
+ if include_ranks:  # ranks requested: emit "(match:score)", mirroring the deepbooru branch
+ res += f", ({match}:{score})"
+ else:  # default (include_ranks=False): plain match text, same output as before this change
+ res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
diff --git a/modules/shared.py b/modules/shared.py
index c1092ff7..3e0bfd72 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -251,6 +251,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
+ "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
diff --git a/modules/ui.py b/modules/ui.py
index 1204eef7..f4dbe247 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -311,13 +311,12 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
def interrogate(image):
- prompt = shared.interrogator.interrogate(image)
-
+ prompt = shared.interrogator.interrogate(image, include_ranks=opts.interrogate_return_ranks)
return gr_show(True) if prompt is None else prompt
def interrogate_deepbooru(image):
- prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold)
+ prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold, opts.interrogate_return_ranks)
return gr_show(True) if prompt is None else prompt