Diffstat (limited to 'modules/interrogate.py')
-rw-r--r--  modules/interrogate.py  |  55
1 file changed, 41 insertions(+), 14 deletions(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 738d8ff7..19938cbb 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -5,12 +5,13 @@ from collections import namedtuple
import re
import torch
+import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader
+from modules import devices, paths, lowvram, modelloader, errors
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
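The only use of the new `torch.hub` import is `torch.hub.download_url_to_file(url, dst)`, which streams a URL to a local path and shows a progress bar by default. A minimal standalone sketch of that call, with an illustrative destination path rather than the module's real one:

    import os
    import torch.hub

    dst = os.path.join("interrogate_tmp", "artists.txt")  # illustrative path
    os.makedirs(os.path.dirname(dst), exist_ok=True)

    # Downloads the file; progress=True (the default) prints a progress bar.
    torch.hub.download_url_to_file(
        "https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt",
        dst,
    )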
@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
+def download_default_clip_interrogate_categories(content_dir):
+ print("Downloading CLIP categories...")
+
+ tmpdir = content_dir + "_tmp"
+ try:
+ os.makedirs(tmpdir)
+
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
+
+ os.rename(tmpdir, content_dir)
+
+ except Exception as e:
+ errors.display(e, "downloading default CLIP interrogate categories")
+ finally:
+ if os.path.exists(tmpdir):
+ os.remove(tmpdir)
+
+
class InterrogateModels:
blip_model = None
clip_model = None
clip_preprocess = None
- categories = None
dtype = None
running_on_cpu = None
def __init__(self, content_dir):
- self.categories = []
+ self.loaded_categories = None
+ self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
- if os.path.exists(content_dir):
- for filename in os.listdir(content_dir):
+ def categories(self):
+ if self.loaded_categories is not None:
+ return self.loaded_categories
+
+ self.loaded_categories = []
+
+ if not os.path.exists(self.content_dir):
+ download_default_clip_interrogate_categories(self.content_dir)
+
+ if os.path.exists(self.content_dir):
+ for filename in os.listdir(self.content_dir):
m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1))
- with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
+ with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
- self.categories.append(Category(name=filename, topn=topn, items=lines))
+ self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+
+ return self.loaded_categories
def load_blip_model(self):
import models.blip
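With this hunk, `categories` changes from a list built eagerly in `__init__` to a method that loads the category files on first use and caches them in `self.loaded_categories`; if the content directory is missing, the default lists are downloaded first. A hedged usage sketch (assuming the webui's modules are importable; the constructor argument is illustrative, and construction itself no longer touches the filesystem):

    from modules.interrogate import InterrogateModels

    interrogator = InterrogateModels("interrogate")  # cheap: nothing loaded yet

    # The first call may download the default category files and parse them;
    # later calls return the cached list.
    for category in interrogator.categories():
        print(category.name, category.topn, len(category.items))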
@@ -139,7 +172,6 @@ class InterrogateModels:
shared.state.begin()
shared.state.job = 'interrogate'
try:
-
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -159,12 +191,7 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True)
- if shared.opts.interrogate_use_builtin_artists:
- artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
-
- res += ", " + artist[0]
-
- for name, topn, items in self.categories:
+ for name, topn, items in self.categories():
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:
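The built-in artist ranking (gated by `interrogate_use_builtin_artists` and backed by `shared.artist_db`) is dropped; artists are now just another category file, ranked per category with its own top-n. A custom list therefore only needs a text file in the content directory. An illustrative sketch (file names and entries are examples; per the `re_topn` pattern above, a `.topN.` infix in the filename sets how many matches the category contributes, defaulting to 1):

    import os

    content_dir = "interrogate"  # illustrative location of the category files
    os.makedirs(content_dir, exist_ok=True)

    # One candidate phrase per line; naming the file "artists.top3.txt" would
    # contribute the best three matches instead of one.
    with open(os.path.join(content_dir, "artists.txt"), "w", encoding="utf8") as f:
        f.write("by Claude Monet\nby Gustav Klimt\nby Hokusai\n")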