aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_timesteps.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_timesteps.py')
-rw-r--r--modules/sd_samplers_timesteps.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index b17a8f93..8cc7d384 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
- return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+ return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
@@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
+ self.model_wrap = self.model_wrap_cfg.inner_model
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -132,8 +133,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
@@ -157,8 +157,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
}
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples