aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r--modules/sd_samplers_common.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 15f27970..fa3614ff 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -202,8 +202,9 @@ class Sampler:
self.conditioning_key = shared.sd_model.model.conditioning_key
- self.model_wrap = None
+ self.p = None
self.model_wrap_cfg = None
+ self.sampler_extra_args = None
def callback_state(self, d):
step = d['i']
@@ -215,6 +216,7 @@ class Sampler:
shared.total_tqdm.update()
def launch_sampling(self, steps, func):
+ self.model_wrap_cfg.steps = steps
state.sampling_steps = steps
state.sampling_step = 0
@@ -234,6 +236,8 @@ class Sampler:
return p.steps
def initialize(self, p) -> dict:
+ self.p = p
+ self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0