aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-05-23 11:34:51 +0800
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-05-23 11:34:51 +0800
commit70650f87a42615a62568a896403156d0065621b4 (patch)
tree957faca70f469dd6f8e926c70b5c44622a2505e6 /modules/sd_samplers_kdiffusion.py
parent1846ad36a3bd2a60bc9dc59a60e16d3ca7a559fe (diff)
Use better way to impl
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 969ef02b..5fea08b0 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -295,6 +295,13 @@ class KDiffusionSampler:
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+ if opts.custom_k_sched:
+ p.extra_generation_params["Enable Custom KDiffusion Schedule"] = True
+ p.extra_generation_params["KDiffusion Scheduler Type"] = opts.k_sched_type
+ p.extra_generation_params["KDiffusion Scheduler sigma_max"] = opts.sigma_max
+ p.extra_generation_params["KDiffusion Scheduler sigma_min"] = opts.sigma_min
+ p.extra_generation_params["KDiffusion Scheduler rho"] = opts.rho
+
extra_params_kwargs = {}
for param_name in self.extra_params:
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
@@ -318,15 +325,15 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
- elif p.enable_custom_k_sched:
+ elif opts.custom_k_sched:
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
- sigmas_func = k_diffusion_scheduler[p.k_sched_type]
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
sigmas_kwargs = {
- 'sigma_min': p.sigma_min or sigma_min,
- 'sigma_max': p.sigma_max or sigma_max
+ 'sigma_min': opts.sigma_min or sigma_min,
+ 'sigma_max': opts.sigma_max or sigma_max
}
- if p.k_sched_type != 'exponential':
- sigmas_kwargs['rho'] = p.rho
+ if opts.k_sched_type != 'exponential':
+ sigmas_kwargs['rho'] = opts.rho
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())