aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorDepFA <35278260+dfaker@users.noreply.github.com>2022-10-09 22:05:09 +0100
committerGitHub <noreply@github.com>2022-10-09 22:05:09 +0100
commit5d12ec82d3e13f5ff4c55db2930e4e10aed7015a (patch)
tree5704c568b650078442dfa2a4273ff97203578979 /modules
parent969bd8256e5b4f1007d3cc653723d4ad50a92528 (diff)
add encoder and decoder classes
Diffstat (limited to 'modules')
-rw-r--r--modules/textual_inversion/textual_inversion.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d7813084..44d4e08b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -16,6 +16,27 @@ import json
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+class EmbeddingEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, torch.Tensor):
+ return {'EMBEDDINGTENSOR':obj.cpu().detach().numpy().tolist()}
+ return json.JSONEncoder.default(self, o)
+
+class EmbeddingDecoder(json.JSONDecoder):
+ def __init__(self, *args, **kwargs):
+ json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+ def object_hook(self, d):
+ if 'EMBEDDINGTENSOR' in d:
+ return torch.from_numpy(np.array(d['EMBEDDINGTENSOR']))
+ return d
+
+def embeddingToB64(data):
+ d = json.dumps(data,cls=EmbeddingEncoder)
+ return base64.b64encode(d.encode())
+
+def EmbeddingFromB64(data):
+ d = base64.b64decode(data)
+ return json.loads(d,cls=EmbeddingDecoder)
class Embedding:
def __init__(self, vec, name, step=None):