aboutsummaryrefslogtreecommitdiff
path: root/modules/lowvram.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/lowvram.py')
-rw-r--r--modules/lowvram.py51
1 files changed, 37 insertions, 14 deletions
diff --git a/modules/lowvram.py b/modules/lowvram.py
index d95bcfbf..da4f33a8 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
- # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
-
- # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
- # send the model to GPU. Then put modules back. the modules will be in CPU.
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+ to_remain_in_cpu = [
+ (sd_model, 'first_stage_model'),
+ (sd_model, 'depth_model'),
+ (sd_model, 'embedder'),
+ (sd_model, 'model'),
+ (sd_model, 'embedder'),
+ ]
+
+ is_sdxl = hasattr(sd_model, 'conditioner')
+ is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
+
+ if is_sdxl:
+ to_remain_in_cpu.append((sd_model, 'conditioner'))
+ elif is_sd2:
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
+ else:
+ to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
+
+ # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
+ stored = []
+ for obj, field in to_remain_in_cpu:
+ module = getattr(obj, field, None)
+ stored.append(module)
+ setattr(obj, field, None)
+
+ # send the model to GPU.
sd_model.to(devices.device)
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+
+ # put modules back. the modules will be in CPU.
+ for (obj, field), module in zip(to_remain_in_cpu, stored):
+ setattr(obj, field, module)
# register hooks for those the first three models
- sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+ if is_sdxl:
+ sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
+ elif is_sd2:
+ sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+ else:
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
@@ -75,10 +102,6 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
- if hasattr(sd_model.cond_stage_model, 'model'):
- sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
- del sd_model.cond_stage_model.transformer
-
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else: