aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorLee Bousfield <ljbousfield@gmail.com>2023-01-10 18:40:34 -0700
committerLee Bousfield <ljbousfield@gmail.com>2023-01-10 18:40:34 -0700
commitf9706acf431f77e0ce9e4270e5be7299922ee963 (patch)
tree22ea75014ae75fdfe327d0d136db50ad064f8076 /modules
parent9cfd10cdefc7b2966b8e42fbb0e05735967cf87b (diff)
Support loading textual inversion embeddings from safetensors files
Diffstat (limited to 'modules')
-rw-r--r--modules/textual_inversion/textual_inversion.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 5420903f..3866c154 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -9,6 +9,7 @@ import tqdm
import html
import datetime
import csv
+import safetensors.torch
from PIL import Image, PngImagePlugin
@@ -150,6 +151,8 @@ class EmbeddingDatabase:
name = data.get('name', name)
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ elif ext in ['.SAFETENSORS']:
+ data = safetensors.torch.load_file(path, device="cpu")
else:
return