aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-28 10:49:07 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-28 10:49:07 +0300
commit2ab64ec81a270c516816b5035860361ee145b9db (patch)
tree24d749e5d2e3e21b816b5bdd74b73842f9e58c1a /modules
parent15f333a266c20319e2b95a47a8834adf7b914aec (diff)
emergency fix for #1199
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_samplers.py25
1 files changed, 13 insertions, 12 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 17faeab1..a1183997 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -3,6 +3,7 @@ import numpy as np
import torch
import tqdm
from PIL import Image
+import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
@@ -38,11 +39,11 @@ samplers = [
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
sampler_extra_params = {
- 'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'],
- 'sample_euler_ancestral':['eta'],
- 'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'],
- 'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'],
- 'sample_dpm_2_ancestral':['eta'],
+ 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_euler_ancestral': ['eta'],
+ 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_dpm_2_ancestral': ['eta'],
}
def setup_img2img_steps(p, steps=None):
@@ -231,7 +232,7 @@ class KDiffusionSampler:
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
- self.extra_params = sampler_extra_params.get(funcname,[])
+ self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
self.sampler_noise_index = 0
@@ -278,9 +279,9 @@ class KDiffusionSampler:
k_diffusion.sampling.torch = TorchHijack(self)
extra_params_kwargs = {}
- for val in self.extra_params:
- if hasattr(p,val):
- extra_params_kwargs[val] = getattr(p,val)
+ for param_name in self.extra_params:
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+ extra_params_kwargs[param_name] = getattr(p, param_name)
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
@@ -300,9 +301,9 @@ class KDiffusionSampler:
k_diffusion.sampling.torch = TorchHijack(self)
extra_params_kwargs = {}
- for val in self.extra_params:
- if hasattr(p,val):
- extra_params_kwargs[val] = getattr(p,val)
+ for param_name in self.extra_params:
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+ extra_params_kwargs[param_name] = getattr(p, param_name)
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)