aboutsummaryrefslogtreecommitdiff
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
authorZac Liu <liuguang@baai.ac.cn>2022-12-06 09:16:15 +0800
committerGitHub <noreply@github.com>2022-12-06 09:16:15 +0800
commit3ebf977a6e4f478ab918e44506974beee32da276 (patch)
treef68456207e5cd78718ec1e9c588ecdc22d568d81 /modules/hypernetworks/hypernetwork.py
parent231fb72872191ffa8c446af1577c9003b3d19d4f (diff)
parent44c46f0ed395967cd3830dd481a2db759fda5b3b (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8466887f..c406ffb3 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -433,7 +433,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@@ -495,7 +498,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.state.interrupted:
break
- with torch.autocast("cuda"):
+ with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)
@@ -612,10 +615,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return hypernetwork, filename