aboutsummaryrefslogtreecommitdiff
path: root/modules/lowvram.py
diff options
context:
space:
mode:
authorJairo Correa <jn.j41r0@gmail.com>2022-11-01 04:01:49 -0300
committerJairo Correa <jn.j41r0@gmail.com>2022-11-01 04:01:49 -0300
commitaf758e97fa2c4c853042f121af4e974be01e6696 (patch)
treed2775792cce1a7084fe24ef8e5225a94d04a4bc0 /modules/lowvram.py
parent5c9b3625fa03f18649e1843b5e9f2df2d4de94f9 (diff)
Unload sd_model before loading the other
Diffstat (limited to 'modules/lowvram.py')
-rw-r--r--modules/lowvram.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/modules/lowvram.py b/modules/lowvram.py
index f327c3df..a4652cb1 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
- def first_stage_model_encode_wrap(self, encoder, x):
- send_me_to_gpu(self, None)
- return encoder(x)
- def first_stage_model_decode_wrap(self, decoder, z):
- send_me_to_gpu(self, None)
- return decoder(z)
+ first_stage_model = sd_model.first_stage_model
+ first_stage_model_encode = sd_model.first_stage_model.encode
+ first_stage_model_decode = sd_model.first_stage_model.decode
+
+ def first_stage_model_encode_wrap(x):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_encode(x)
+
+ def first_stage_model_decode_wrap(z):
+ send_me_to_gpu(first_stage_model, None)
+ return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
- sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
- sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+ sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram: