aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py57
1 files changed, 7 insertions, 50 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 9a29f1ae..adb6883e 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -2,18 +2,12 @@ from collections import deque
import torch
import inspect
import k_diffusion.sampling
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis
from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-# imports for functions that previously were here and are used by other modules
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
-
-
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
('Euler', 'sample_euler', ['k_euler'], {}),
@@ -40,50 +34,6 @@ samplers_data_k_diffusion = [
if hasattr(k_diffusion.sampling, funcname)
]
-all_samplers = [
- *samplers_data_k_diffusion,
- sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
- sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
-]
-all_samplers_map = {x.name: x for x in all_samplers}
-
-samplers = []
-samplers_for_img2img = []
-samplers_map = {}
-
-
-def create_sampler(name, model):
- if name is not None:
- config = all_samplers_map.get(name, None)
- else:
- config = all_samplers[0]
-
- assert config is not None, f'bad sampler name: {name}'
-
- sampler = config.constructor(model)
- sampler.config = config
-
- return sampler
-
-
-def set_samplers():
- global samplers, samplers_for_img2img
-
- hidden = set(opts.hide_samplers)
- hidden_img2img = set(opts.hide_samplers + ['PLMS'])
-
- samplers = [x for x in all_samplers if x.name not in hidden]
- samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
-
- samplers_map.clear()
- for sampler in all_samplers:
- samplers_map[sampler.name.lower()] = sampler.name
- for alias in sampler.aliases:
- samplers_map[alias.lower()] = sampler.name
-
-
-set_samplers()
-
sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
@@ -92,6 +42,13 @@ sampler_extra_params = {
class CFGDenoiser(torch.nn.Module):
+ """
+ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
+ that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+ instead of one. Originally, the second prompt is just an empty string, but we use non-empty
+ negative prompt.
+ """
+
def __init__(self, model):
super().__init__()
self.inner_model = model