aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/processing.py28
-rw-r--r--modules/sd_models.py6
-rw-r--r--modules/sd_samplers_timesteps.py2
-rw-r--r--modules/shared_options.py2
4 files changed, 37 insertions, 1 deletions
diff --git a/modules/processing.py b/modules/processing.py
index b30df60d..846e4796 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -898,6 +898,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+ def rescale_zero_terminal_snr_abar(alphas_cumprod):
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas_bar[-1] = 4.8973451890853435e-08
+ return alphas_bar
+
+ if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index d0046f88..50bc209e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -401,6 +401,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.no_half:
model.float()
+ model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
@@ -414,7 +415,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
+ alphas_cumprod = model.alphas_cumprod
+ model.alphas_cumprod = None
model.half()
+ model.alphas_cumprod = alphas_cumprod
+ model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
@@ -691,6 +696,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
weight_dtype_conversion = {
'first_stage_model': None,
+ 'alphas_cumprod': None,
'': torch.float16,
}
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index f8afa8bd..777dd8d0 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
- return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+ return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 281591da..ce06f022 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -220,6 +220,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd"
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
+ "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
}))
options_templates.update(options_section(('interrogate', "Interrogate"), {
@@ -358,6 +359,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
+ 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
}))
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {