aboutsummaryrefslogtreecommitdiff
path: root/modules/deepbooru.py
diff options
context:
space:
mode:
authorGreendayle <Greendayle>2022-10-08 18:02:56 +0200
committerGreendayle <Greendayle>2022-10-08 18:02:56 +0200
commit01f8cb44474e454903c11718e6a4f33dbde34bb8 (patch)
tree3f0b30e2f356733b5d610d1fb4c4913c305c3af4 /modules/deepbooru.py
parent5329d0aba0296f2fde4b5e6256dd27d46028a429 (diff)
made deepdanbooru optional, added to readme, automatic download of deepbooru model
Diffstat (limited to 'modules/deepbooru.py')
-rw-r--r--modules/deepbooru.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 781b2249..7e3c0618 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -9,16 +9,16 @@ def _load_tf_and_return_tags(pil_image, threshold):
import numpy as np
this_folder = os.path.dirname(__file__)
- model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28')
-
- model_good = False
- for path_candidate in [model_path, os.path.dirname(model_path)]:
- if os.path.exists(os.path.join(path_candidate, 'project.json')):
- model_path = path_candidate
- model_good = True
- if not model_good:
- return ("Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/"
- "deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru")
+ model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
+ if not os.path.exists(os.path.join(model_path, 'project.json')):
+ # there is no point importing these every time
+ import zipfile
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
+ model_path)
+ with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
+ zip_ref.extractall(model_path)
+ os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project(