aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-08 15:36:50 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-08 15:36:50 +0300
commit62ce77e24568113f9a19836bf90741dba4166db5 (patch)
tree612f26e6feebc21d7a2af20e86074c1ab690270d
parentf5001246e27e78422bb11187160702bcaba7daca (diff)
support for sd-concepts as alternatives for textual inversion #151
-rw-r--r--.gitignore3
-rw-r--r--modules/sd_hijack.py20
2 files changed, 17 insertions, 6 deletions
diff --git a/.gitignore b/.gitignore
index 5381c515..78cf719e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,5 @@ __pycache__
/outputs
/config.json
/log
-webui.settings.bat \ No newline at end of file
+/webui.settings.bat
+/embeddings
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 1084e248..db9952a5 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -73,11 +73,21 @@ class StableDiffusionModelHijack:
name = os.path.splitext(filename)[0]
data = torch.load(path)
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
+
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+
self.word_embeddings[name] = emb.detach()
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'