aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/processing.py16
-rw-r--r--modules/sd_samplers.py28
-rw-r--r--modules/shared.py7
3 files changed, 44 insertions, 7 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 3abf3181..8d043f4d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -78,7 +78,14 @@ class StableDiffusionProcessing:
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
-
+
+ self.ddim_eta = opts.ddim_eta
+ self.ddim_discretize = opts.ddim_discretize
+ self.s_churn = opts.s_churn
+ self.s_tmin = opts.s_tmin
+ self.s_tmax = float('inf') # not representable as a standard ui option
+ self.s_noise = opts.s_noise
+
if not seed_enable_extras:
self.subseed = -1
self.subseed_strength = 0
@@ -117,6 +124,13 @@ class Processed:
self.extra_generation_params = p.extra_generation_params
self.index_of_first_image = index_of_first_image
+ self.ddim_eta = p.ddim_eta
+ self.ddim_discretize = p.ddim_discretize
+ self.s_churn = p.s_churn
+ self.s_tmin = p.s_tmin
+ self.s_tmax = p.s_tmax
+ self.s_noise = p.s_noise
+
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 1fc9d18c..666ee1ee 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -37,6 +37,11 @@ samplers = [
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
+sampler_extra_params = {
+ 'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'],
+ 'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'],
+ 'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'],
+}
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
@@ -120,9 +125,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with cetain step counts, like 9
try:
- self.sampler.make_schedule(ddim_num_steps=steps, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
@@ -149,9 +154,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with cetin step counts, like 9
try:
- samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+ samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
except Exception:
- samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+ samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
return samples_ddim
@@ -224,6 +229,7 @@ class KDiffusionSampler:
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.extra_params = sampler_extra_params.get(funcname,[])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
self.sampler_noise_index = 0
@@ -269,7 +275,12 @@ class KDiffusionSampler:
if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
- return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+ extra_params_kwargs = {}
+ for val in self.extra_params:
+ if hasattr(p,val):
+ extra_params_kwargs[val] = getattr(p,val)
+
+ return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
@@ -286,7 +297,12 @@ class KDiffusionSampler:
if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
- samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+ extra_params_kwargs = {}
+ for val in self.extra_params:
+ if hasattr(p,val):
+ extra_params_kwargs[val] = getattr(p,val)
+
+ samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples
diff --git a/modules/shared.py b/modules/shared.py
index bd030fe8..870fb3b9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -220,6 +220,13 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
}))
+options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
+ "ddim_eta": OptionInfo(0.0, "DDIM eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}),
+ 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+}))
class Options:
data = None