author    AUTOMATIC1111 <16777216c@gmail.com>  2022-12-03 08:29:56 +0300
committer GitHub <noreply@github.com>          2022-12-03 08:29:56 +0300
commit    ae81b377d4e745a30ce019b82a2f67e3531770a7 (patch)
tree      72495a41ecd26b582afd98ac79818217aa279517 /modules
parent    c3777777d0ae87dbceb47ef17627ec98391ff869 (diff)
parent    67efee33a6c65e58b3f6c788993d0e68a33e4fd0 (diff)
Merge pull request #5165 from klimaleksus/fix-sequential-vae
Make VAE step sequential to prevent VRAM spikes, will fix #3059, #2082, #2561, #3462
Diffstat (limited to 'modules')
-rw-r--r--  modules/processing.py | 4
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index edceb532..fd995b8a 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -530,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             with devices.autocast():
                 samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
 
-            samples_ddim = samples_ddim.to(devices.dtype_vae)
-            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
+            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+            x_samples_ddim = torch.stack(x_samples_ddim).float()
 
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
             del samples_ddim
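
The change above swaps a single batched decode_first_stage call for a per-sample loop: each latent is cast to the VAE dtype, decoded on its own, and moved to CPU before the next one is decoded, so peak VRAM is bounded by one image's decoder activations rather than by the whole batch. A minimal standalone sketch of that pattern follows; decode here is a hypothetical stand-in for the real VAE decoder, not code from this repository.

import torch

def decode(latent_batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for decode_first_stage(): a real VAE upsamples
    # 64x64 latents to 512x512 images and allocates large activations per sample.
    return latent_batch.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2)

def decode_batched(latents: torch.Tensor) -> torch.Tensor:
    # Old behaviour: decode the whole batch at once; peak memory grows with batch size.
    return decode(latents).float()

def decode_sequential(latents: torch.Tensor) -> torch.Tensor:
    # New behaviour: decode one latent at a time and offload each result to CPU,
    # mirroring the list comprehension + torch.stack in the patch above.
    images = [decode(latents[i:i+1])[0].cpu() for i in range(latents.size(0))]
    return torch.stack(images).float()

if __name__ == "__main__":
    latents = torch.randn(4, 4, 64, 64)  # a batch of 4 latents
    assert torch.equal(decode_batched(latents).cpu(), decode_sequential(latents))

Both paths produce the same tensor; only the peak memory profile differs, which is why the patch can resolve the linked out-of-memory reports without changing output images.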