aboutsummaryrefslogtreecommitdiff
path: root/modules/interrogate.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-12-03 18:17:56 +0300
committerGitHub <noreply@github.com>2022-12-03 18:17:56 +0300
commit2a649154ec994063b27a6723afa40e52be219771 (patch)
treee42bd12d4b46c86a4da04394342fba9594181b06 /modules/interrogate.py
parent0d21624ceef52b843c731ddc7fdcd7b8d108a42e (diff)
parenta2ae5a655518b150a34b95d7afecc87a43280406 (diff)
Merge pull request #4956 from TiagoSantos81/offline_BLIP
[CLIP interrogator] use local file, if available
Diffstat (limited to 'modules/interrogate.py')
-rw-r--r--modules/interrogate.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 40c6b082..3a09b366 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -14,6 +14,8 @@ import modules.shared as shared
from modules import devices, paths, lowvram
blip_image_eval_size = 384
+blip_local_dir = os.path.join('models', 'Interrogator')
+blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth')
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'
@@ -47,7 +49,16 @@ class InterrogateModels:
def load_blip_model(self):
import models.blip
- blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
+ if not os.path.isfile(blip_local_file):
+ if not os.path.isdir(blip_local_dir):
+ os.mkdir(blip_local_dir)
+
+ print("Downloading BLIP...")
+ from requests import get as reqget
+ open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content)
+ print("BLIP downloaded to", blip_local_file + '.')
+
+ blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval()
return blip_model