aboutsummaryrefslogtreecommitdiff
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4eb75cb5..e28c357a 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -269,6 +269,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
pin_memory = shared.opts.pin_memory
@@ -279,6 +280,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
@@ -316,7 +318,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if shared.state.interrupted:
break
- with torch.autocast("cuda"):
+ with devices.autocast():
# c = stack_conds(batch.cond).to(devices.device)
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)
@@ -450,6 +452,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename