aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorC43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com>2022-10-05 14:30:57 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2022-10-06 13:42:21 +0300
commit71901b3d3bea1d035bf4a7229d19356b4b062151 (patch)
tree3f0414230478a70dc4d1d623840bb459c99a32f9
parentc1a068ed0acc788774afc1541ca69342fd1d94ad (diff)
add karras scheduling variants
-rw-r--r--modules/sd_samplers.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 2e1f7715..8d6eb762 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -26,6 +26,17 @@ samplers_k_diffusion = [
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
]
+if opts.show_karras_scheduler_variants:
+ k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2
+ k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral
+ k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms
+ samplers_k_diffusion_ka = [
+ ('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']),
+ ('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']),
+ ('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']),
+ ]
+ samplers_k_diffusion.extend(samplers_k_diffusion_ka)
+
samplers_data_k_diffusion = [
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
for label, funcname, aliases in samplers_k_diffusion
@@ -345,6 +356,8 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
+ elif self.funcname.endswith('ka'):
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0]