From 70650f87a42615a62568a896403156d0065621b4 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 23 May 2023 11:34:51 +0800 Subject: Use better way to impl --- modules/sd_samplers_kdiffusion.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 969ef02b..5fea08b0 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -295,6 +295,13 @@ class KDiffusionSampler: k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + if opts.custom_k_sched: + p.extra_generation_params["Enable Custom KDiffusion Schedule"] = True + p.extra_generation_params["KDiffusion Scheduler Type"] = opts.k_sched_type + p.extra_generation_params["KDiffusion Scheduler sigma_max"] = opts.sigma_max + p.extra_generation_params["KDiffusion Scheduler sigma_min"] = opts.sigma_min + p.extra_generation_params["KDiffusion Scheduler rho"] = opts.rho + extra_params_kwargs = {} for param_name in self.extra_params: if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: @@ -318,15 +325,15 @@ class KDiffusionSampler: if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) - elif p.enable_custom_k_sched: + elif opts.custom_k_sched: sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - sigmas_func = k_diffusion_scheduler[p.k_sched_type] + sigmas_func = k_diffusion_scheduler[opts.k_sched_type] sigmas_kwargs = { - 'sigma_min': p.sigma_min or sigma_min, - 'sigma_max': p.sigma_max or sigma_max + 'sigma_min': opts.sigma_min or sigma_min, + 'sigma_max': opts.sigma_max or sigma_max } - if p.k_sched_type != 'exponential': - sigmas_kwargs['rho'] = p.rho + if opts.k_sched_type != 'exponential': + sigmas_kwargs['rho'] = opts.rho sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device) elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) -- cgit v1.2.1