aboutsummaryrefslogtreecommitdiff
path: root/modules/deepbooru.py
diff options
context:
space:
mode:
authorGreendayle <Greendayle>2022-10-05 22:05:24 +0200
committerGreendayle <Greendayle>2022-10-05 22:07:28 +0200
commit17a99baf0c929e5df4dfc4b2a96aa3890a141112 (patch)
treefdebfa0bacecc24904fbd5dfb6f29dd4cf6764d2 /modules/deepbooru.py
parent1506fab29ad54beb9f52236912abc432209c8089 (diff)
better model search
Diffstat (limited to 'modules/deepbooru.py')
-rw-r--r--modules/deepbooru.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 841cb9c5..a64fd9cd 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -9,8 +9,15 @@ import tensorflow as tf
def _load_tf_and_return_tags(pil_image, threshold):
this_folder = os.path.dirname(__file__)
model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28')
- if not os.path.exists(model_path):
- return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru"
+
+ model_good = False
+ for path_candidate in [model_path, os.path.dirname(model_path)]:
+ if os.path.exists(os.path.join(path_candidate, 'project.json')):
+ model_path = path_candidate
+ model_good = True
+ if not model_good:
+ return ("Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/"
+ "deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru")
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(