aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers.py
diff options
context:
space:
mode:
authorDepFA <35278260+dfaker@users.noreply.github.com>2022-09-26 09:56:47 +0100
committerAUTOMATIC1111 <16777216c@gmail.com>2022-09-27 09:30:45 +0300
commit2ab3d593f9091689cdef07442df0213ef3242603 (patch)
tree5f71d578b6b9b321a891f5ab8a40b732614932d9 /modules/sd_samplers.py
parent6b78833e3331ce9f9cbd85e2c75a1b11aefecf1c (diff)
pass extra KDiffusionSampler function parameters
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r--modules/sd_samplers.py20
1 files changed, 18 insertions, 2 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 39fdca70..2ac44f6c 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -37,6 +37,11 @@ samplers = [
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
+sampler_extra_params = {
+ 'sample_euler':['s_churn','s_tmin','s_noise'],
+ 'sample_heun' :['s_churn','s_tmin','s_noise'],
+ 'sample_dpm_2':['s_churn','s_tmin','s_noise'],
+}
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
@@ -224,6 +229,7 @@ class KDiffusionSampler:
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.extra_params = sampler_extra_params.get(funcname,[])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
self.sampler_noise_index = 0
@@ -269,7 +275,12 @@ class KDiffusionSampler:
if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
- return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+ extra_params_kwargs = {}
+ for val in self.extra_params:
+ if hasattr(opts,val):
+ extra_params_kwargs[val] = getattr(opts,val)
+
+ return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
@@ -286,7 +297,12 @@ class KDiffusionSampler:
if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
- samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+ extra_params_kwargs = {}
+ for val in self.extra_params:
+ if hasattr(opts,val):
+ extra_params_kwargs[val] = getattr(opts,val)
+
+ samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples