author    AUTOMATIC <16777216c@gmail.com>  2022-10-26 09:44:02 +0300
committer AUTOMATIC <16777216c@gmail.com>  2022-10-26 09:44:02 +0300
commit    cbb857b675cf0f169b21515c29da492b513cc8c4 (patch)
tree      18ada7f614e3bb16049938fefae073bf91cbfa4d /modules/textual_inversion
parent    ee73341f04128cf81fb6a55c1942d35d20c9016e (diff)
enable creating embedding with --medvram
Diffstat (limited to 'modules/textual_inversion')
 modules/textual_inversion/textual_inversion.py | 3 +++
 1 file changed, 3 insertions(+), 0 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 529ed3e2..647ffe3e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -157,6 +157,9 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
 
+    with devices.autocast():
+        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active
+
     ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
     embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
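
Why the dummy forward pass: with --lowvram/--medvram active, the model's submodules stay on the CPU and are only moved to the GPU when they are actually called. Creating an embedding reads the token-embedding weights directly, which never triggers that move, so without the patch the weights could still be on the CPU while ids are sent to devices.device. The sketch below illustrates the lazy-move mechanism with a standard PyTorch forward pre-hook; setup_for_low_vram is a hypothetical name for illustration, and the webui's actual implementation (modules/lowvram.py) is more involved.

# Minimal sketch of the lazy-GPU mechanism behind lowvram/medvram.
# setup_for_low_vram here is hypothetical, not the webui's real function.
import torch
import torch.nn as nn

def setup_for_low_vram(module: nn.Module, device: torch.device):
    # Keep weights on the CPU until the module is actually called.
    module.to("cpu")

    def send_to_gpu(mod, _inputs):
        # Forward pre-hook: move the module to the GPU on first use.
        mod.to(device)

    module.register_forward_pre_hook(send_to_gpu)

# cond_model([""]) performs a forward pass, firing a pre-hook like the one
# above, so the token_embedding weights read immediately afterwards are
# already on devices.device when the new embedding vector is built.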