aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/processing.py13
-rw-r--r--modules/sd_samplers.py19
-rw-r--r--modules/shared.py15
-rw-r--r--webui.py4
4 files changed, 35 insertions, 16 deletions
diff --git a/modules/processing.py b/modules/processing.py
index d8c6b8d5..e01c8b3f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -11,9 +11,8 @@ import cv2
from skimage import exposure
import modules.sd_hijack
-from modules import devices, prompt_parser, masking
+from modules import devices, prompt_parser, masking, sd_samplers
from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
@@ -110,7 +109,7 @@ class Processed:
self.width = p.width
self.height = p.height
self.sampler_index = p.sampler_index
- self.sampler = samplers[p.sampler_index].name
+ self.sampler = sd_samplers.samplers[p.sampler_index].name
self.cfg_scale = p.cfg_scale
self.steps = p.steps
self.batch_size = p.batch_size
@@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params = {
"Steps": p.steps,
- "Sampler": samplers[p.sampler_index].name,
+ "Sampler": sd_samplers.samplers[p.sampler_index].name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
@@ -478,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.firstphase_height_truncated = int(scale * self.height)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
- self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
@@ -521,7 +520,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
# GC now before running the next img2img to prevent running out of memory
@@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = None
def init(self, all_prompts, all_seeds, all_subseeds):
- self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
+ self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
crop_region = None
if self.image_mask is not None:
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index d27c547b..2e1f7715 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -32,12 +32,27 @@ samplers_data_k_diffusion = [
if hasattr(k_diffusion.sampling, funcname)
]
-samplers = [
+all_samplers = [
*samplers_data_k_diffusion,
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]
-samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
+
+samplers = []
+samplers_for_img2img = []
+
+
+def set_samplers():
+ global samplers, samplers_for_img2img
+
+ hidden = set(opts.hide_samplers)
+ hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
+
+ samplers = [x for x in all_samplers if x.name not in hidden]
+ samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
+
+
+set_samplers()
sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
diff --git a/modules/shared.py b/modules/shared.py
index bab0fe6e..ca2e4c74 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,6 +13,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
+from modules import sd_samplers
from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
@@ -238,14 +239,16 @@ options_templates.update(options_section(('ui', "User interface"), {
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
- "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
- 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
+ "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
+ 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
}))
+
class Options:
data = None
data_labels = options_templates
diff --git a/webui.py b/webui.py
index 47848ba5..9ef12427 100644
--- a/webui.py
+++ b/webui.py
@@ -2,7 +2,7 @@ import os
import threading
import time
import importlib
-from modules import devices
+from modules import devices, sd_samplers
from modules.paths import script_path
import signal
import threading
@@ -109,6 +109,8 @@ def webui():
time.sleep(0.5)
break
+ sd_samplers.set_samplers()
+
print('Reloading Custom Scripts')
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
print('Reloading modules: modules.ui')