aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_common.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-08 22:09:40 +0300
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-08 22:09:51 +0300
commitf8ff8c0638997fd0aef217db1505598846f14782 (patch)
treee94ba153369a657df92857b7c342d5d245a0a8b4 /modules/sd_samplers_common.py
parent54c3e5c913b17622bed4ff4d03df488b80611e21 (diff)
merge errors
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r--modules/sd_samplers_common.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 15f27970..fa3614ff 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -202,8 +202,9 @@ class Sampler:
self.conditioning_key = shared.sd_model.model.conditioning_key
- self.model_wrap = None
+ self.p = None
self.model_wrap_cfg = None
+ self.sampler_extra_args = None
def callback_state(self, d):
step = d['i']
@@ -215,6 +216,7 @@ class Sampler:
shared.total_tqdm.update()
def launch_sampling(self, steps, func):
+ self.model_wrap_cfg.steps = steps
state.sampling_steps = steps
state.sampling_step = 0
@@ -234,6 +236,8 @@ class Sampler:
return p.steps
def initialize(self, p) -> dict:
+ self.p = p
+ self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0