From b25c126ccdbc4da22ade46597a9addf808998989 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:38:53 -0500 Subject: Protect alphas_cumprod from downcasting --- modules/sd_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 841402e8..de80a493 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -387,7 +387,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.upcast_sampling and depth_model: model.depth_model = None + alphas_cumprod = model.alphas_cumprod + model.alphas_cumprod = None model.half() + model.alphas_cumprod = alphas_cumprod + model.alphas_cumprod_original = alphas_cumprod model.first_stage_model = vae if depth_model: model.depth_model = depth_model @@ -642,6 +646,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): else: weight_dtype_conversion = { 'first_stage_model': None, + 'alphas_cumprod': None, '': torch.float16, } -- cgit v1.2.1 From 588a52891dca4d030ca7028dd9c0b56022a68b57 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:40:23 -0500 Subject: Add options for zero terminal SNR --- modules/shared_options.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/shared_options.py b/modules/shared_options.py index 04e68a71..51596777 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -218,6 +218,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd" "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -335,6 +336,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise schedule for sampling").info("for use with zero terminal SNR trained models") })) options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { -- cgit v1.2.1 From 6d0a8dcd892f7ad9b399fed6edbad6ede13c5f69 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:42:07 -0500 Subject: Implement zero terminal SNR schedule option --- modules/processing.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index ac58ef86..c88eec70 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -863,6 +863,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" + + def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return alphas_bar + + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + print("rescaling noise schedule for zero snr") + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) -- cgit v1.2.1 From ec6ee5c13bf3453f8703e225a191333a9bbcf10a Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:10:27 -0500 Subject: Fix infotext for ztSNR --- modules/shared_options.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 51596777..bc3d56de 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -218,7 +218,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd" "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), - "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.") + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -336,7 +336,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), - 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise schedule for sampling").info("for use with zero terminal SNR trained models") + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models") })) options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { -- cgit v1.2.1 From ffa7f8201d849636bb327b3b40298e7c169ff204 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:10:43 -0500 Subject: Lint --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c88eec70..f3883d5b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -863,7 +863,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - + def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar_sqrt = alphas_cumprod.sqrt() @@ -881,7 +881,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_bar = alphas_bar_sqrt**2 # Revert sqrt alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) if opts.use_downcasted_alpha_bar: -- cgit v1.2.1 From de79597ab9894965e3702939b8536ec3dcc3c859 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:33:32 -0500 Subject: Only apply ztSNR related code if alphas_cumprod exists --- modules/processing.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f3883d5b..7e73d7e2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,15 +882,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - print("rescaling noise schedule for zero snr") - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + print("rescaling noise schedule for zero snr") + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) -- cgit v1.2.1 From 668ae34e21df848ef4909b8b49c4142a3674701b Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 22:48:31 -0500 Subject: remove debug print --- modules/processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 7e73d7e2..d73c8bfc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -890,7 +890,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) if opts.sd_noise_schedule == "Zero Terminal SNR": p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - print("rescaling noise schedule for zero snr") p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): -- cgit v1.2.1 From 309a606c2fa645b6b8623f96ea56117e685a47fb Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:07:45 -0500 Subject: ensure that original alpha bar always exists --- modules/processing.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index d73c8bfc..bfa59038 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,15 +882,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod + + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) -- cgit v1.2.1 From 81c4ddf6ebebe6f18338de3b0391da1d8521a525 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:11:00 -0500 Subject: fix linting --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index bfa59038..eeccea74 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod - + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) if opts.use_downcasted_alpha_bar: -- cgit v1.2.1 From 4a43334376d9e116f7a1446f042f9af9c0484fc6 Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:05:42 -0500 Subject: Revert 309a606c --- modules/processing.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index eeccea74..d73c8bfc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,17 +882,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod - - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) -- cgit v1.2.1 From dc1adeecdd02f3fb910481e808a6d60a77100fea Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:06:56 -0500 Subject: Create alphas_cumprod_original on full precision path --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index de80a493..976c7d5b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -374,6 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.no_half: model.float() + model.alphas_cumprod_original = alphas_cumprod devices.dtype_unet = torch.float32 timer.record("apply float()") else: -- cgit v1.2.1 From 78acdcf677a96894651ff0d7d8287f2a994f3781 Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:09:18 -0500 Subject: fix variable --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 976c7d5b..5a19a00a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -374,7 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.no_half: model.float() - model.alphas_cumprod_original = alphas_cumprod + model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 timer.record("apply float()") else: -- cgit v1.2.1 From 5381405eaa1e809e5cfb97522bd4c19d3c946079 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 9 Dec 2023 14:09:28 -0500 Subject: re-derive sqrt alpha bar and sqrt one minus alphabar This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent. --- modules/sd_samplers_timesteps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93..c4bd5c12 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module): self.inner_model = model def predict_eps_from_z_and_v(self, x_t, t, v): - return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t def forward(self, input, timesteps, **kwargs): model_output = self.inner_model.apply_model(input, timesteps, **kwargs) -- cgit v1.2.1