aboutsummaryrefslogtreecommitdiff
path: root/modules/interrogate.py
diff options
context:
space:
mode:
authorVladimir Mandic <mandic00@live.com>2023-01-23 12:29:23 -0500
committerGitHub <noreply@github.com>2023-01-23 12:29:23 -0500
commit04a561c11c9bf9a00d7f9b50ca3f7962aa59ba6e (patch)
treee43ba382c7237c0a034b6e052a0d261a4e46b1ed /modules/interrogate.py
parentefa7287be0a018dcb92e362460cbe19d42d70b03 (diff)
add option to skip interrogate categories
Diffstat (limited to 'modules/interrogate.py')
-rw-r--r--modules/interrogate.py32
1 files changed, 18 insertions, 14 deletions
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 1d1ac572..c252b148 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
from collections import namedtuple
+from pathlib import Path
import re
import torch
@@ -20,12 +21,16 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
-category_types = ["artists", "flavors", "mediums", "movements"]
+def category_types():
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
+
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp"
+ category_types = ["artists", "flavors", "mediums", "movements"]
+
try:
os.makedirs(tmpdir)
for category_type in category_types:
@@ -48,33 +53,32 @@ class InterrogateModels:
def __init__(self, content_dir):
self.loaded_categories = None
- self.selected_categories = []
+ self.skip_categories = []
self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
def categories(self):
- if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories:
+ if not os.path.exists(self.content_dir):
+ download_default_clip_interrogate_categories(self.content_dir)
+
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
return self.loaded_categories
self.loaded_categories = []
- if not os.path.exists(self.content_dir):
- download_default_clip_interrogate_categories(self.content_dir)
-
if os.path.exists(self.content_dir):
- self.selected_categories = shared.opts.interrogate_clip_categories
- for category_type in category_types:
- if 'all' not in self.selected_categories and category_type not in self.selected_categories:
- continue
- filename = os.path.join(self.content_dir, f"{category_type}.txt")
- if not os.path.isfile(filename):
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
+ category_types = []
+ for filename in Path(self.content_dir).glob('*.txt'):
+ category_types.append(filename.stem)
+ if filename.stem in self.skip_categories:
continue
- m = re_topn.search(filename)
+ m = re_topn.search(filename.stem)
topn = 1 if m is None else int(m.group(1))
with open(filename, "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
- self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines))
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
return self.loaded_categories