aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-03-25 12:03:26 +0300
committerGitHub <noreply@github.com>2023-03-25 12:03:26 +0300
commit956ed9a737e9f548336fb274901e5f43683736f8 (patch)
treeec48b6da539a97d2974db0774999c2c89d6a94f7 /modules/sd_models.py
parent8d2c582e3ea99e107df57a4e142acc28a6318d55 (diff)
parent4cbbb881ee530d9b9ba18027e2b0057e6a2c4ee1 (diff)
Merge pull request #8780 from Brawlence/master
Unload and re-load checkpoint to VRAM on request (API & Manual)
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6410b09a..86218c08 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -494,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
checkpoints_loaded.clear()
- load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model
try:
@@ -517,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
return sd_model
+
+def unload_model_weights(sd_model=None, info=None):
+ from modules import lowvram, devices, sd_hijack
+ timer = Timer()
+
+ if shared.sd_model:
+
+ # shared.sd_model.cond_stage_model.to(devices.cpu)
+ # shared.sd_model.first_stage_model.to(devices.cpu)
+ shared.sd_model.to(devices.cpu)
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ sd_model = None
+ gc.collect()
+ devices.torch_gc()
+ torch.cuda.empty_cache()
+
+ print(f"Unloaded weights {timer.summary()}.")
+
+ return sd_model \ No newline at end of file