aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-06 02:09:01 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-06 02:09:01 +0300
commita243bc7859b7ab92a28d28c11b0ed5525fa0d6ba (patch)
tree8da54414ef8918317a6b7eab14f6ada3bed47d0e /modules/sd_samplers.py
parentb6763fb8847df5a5678f37137e7a702569e5c925 (diff)
added progressbar
added an option to disable progressbar added interrupt support to DDIM/PLMS
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r--modules/sd_samplers.py36
1 files changed, 30 insertions, 6 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 6f028f5f..896e8b3f 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,10 +1,12 @@
from collections import namedtuple
+
+import ldm.models.diffusion.ddim
import torch
import tqdm
import k_diffusion.sampling
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -29,8 +31,8 @@ samplers_data_k_diffusion = [
samplers = [
*samplers_data_k_diffusion,
- SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSampler, model), []),
- SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model), []),
+ SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
+ SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
@@ -43,6 +45,23 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
+def extended_tdqm(sequence, *args, desc=None, **kwargs):
+ state.sampling_steps = len(sequence)
+ state.sampling_step = 0
+
+ for x in tqdm.tqdm(sequence, *args, desc=state.job, **kwargs):
+ if state.interrupted:
+ break
+
+ yield x
+
+ state.sampling_step += 1
+
+
+ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
+ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
+
+
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
@@ -102,13 +121,18 @@ class CFGDenoiser(torch.nn.Module):
return denoised
-def extended_trange(*args, **kwargs):
- for x in tqdm.trange(*args, desc=state.job, **kwargs):
+def extended_trange(count, *args, **kwargs):
+ state.sampling_steps = count
+ state.sampling_step = 0
+
+ for x in tqdm.trange(count, *args, desc=state.job, **kwargs):
if state.interrupted:
break
yield x
+ state.sampling_step += 1
+
class KDiffusionSampler:
def __init__(self, funcname, sd_model):