author     Zac Liu <liuguang@baai.ac.cn>   2022-12-06 09:16:15 +0800
committer  GitHub <noreply@github.com>     2022-12-06 09:16:15 +0800
commit     3ebf977a6e4f478ab918e44506974beee32da276 (patch)
tree       f68456207e5cd78718ec1e9c588ecdc22d568d81 /modules/interrogate.py
parent     231fb72872191ffa8c446af1577c9003b3d19d4f (diff)
parent     44c46f0ed395967cd3830dd481a2db759fda5b3b (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/interrogate.py')
-rw-r--r--  modules/interrogate.py | 16
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9769aa34..0068b81c 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,4 +1,3 @@
-import contextlib
import os
import sys
import traceback
@@ -11,10 +10,9 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths, lowvram
+from modules import devices, paths, lowvram, modelloader
blip_image_eval_size = 384
-blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
@@ -47,7 +45,14 @@ class InterrogateModels:
def load_blip_model(self):
import models.blip
- blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
+ files = modelloader.load_models(
+ model_path=os.path.join(paths.models_path, "BLIP"),
+ model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+ ext_filter=[".pth"],
+ download_name='model_base_caption_capfilt_large.pth',
+ )
+
+ blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval()
return blip_model
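This hunk stops passing the hardcoded checkpoint URL straight to `models.blip.blip_decoder` and instead resolves a local path through `modelloader.load_models`, so the BLIP weights are cached under `models/BLIP` and reused on later runs; `files[0]` is the first path it returns. A minimal sketch of the behavior this call site relies on, assuming the real `modules/modelloader.py` adds progress reporting and error handling not shown here:

```python
import os
import urllib.request


def load_models(model_path, model_url=None, ext_filter=None, download_name=None):
    """Return checkpoint paths under model_path, downloading model_url
    to download_name first if no matching file is cached yet."""
    os.makedirs(model_path, exist_ok=True)

    # Reuse any already-downloaded file that matches the extension filter.
    found = [
        os.path.join(model_path, name)
        for name in sorted(os.listdir(model_path))
        if ext_filter is None or os.path.splitext(name)[1] in ext_filter
    ]
    if found:
        return found

    # Nothing cached: fetch the checkpoint once and return its local path.
    if model_url is not None and download_name is not None:
        target = os.path.join(model_path, download_name)
        urllib.request.urlretrieve(model_url, target)
        return [target]

    return []
```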
@@ -148,8 +153,7 @@ class InterrogateModels:
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
- precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
- with torch.no_grad(), precision_scope("cuda"):
+ with torch.no_grad(), devices.autocast():
image_features = self.clip_model.encode_image(clip_image).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True)
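The second hunk drops the inline `precision_scope` selection (and with it the `contextlib` import removed in the first hunk) in favor of the shared `devices.autocast()` helper, centralizing the autocast-or-no-op decision instead of repeating it at every call site. A plausible minimal reading of that helper, reconstructed from the code it replaces; the actual implementation in `modules/devices.py` may also consult the configured dtype:

```python
import contextlib

import torch

import modules.shared as shared


def autocast(disable=False):
    """Return torch.autocast("cuda") under mixed precision, otherwise a
    no-op context manager, so callers never pick the scope themselves."""
    if disable or shared.cmd_opts.precision == "full":
        return contextlib.nullcontext()
    return torch.autocast("cuda")
```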