aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--javascript/hints.js81
-rw-r--r--modules/generation_parameters_copypaste.py16
-rw-r--r--modules/processing.py4
-rw-r--r--modules/script_callbacks.py20
-rw-r--r--modules/sd_hijack.py20
-rw-r--r--modules/sd_models.py4
-rw-r--r--modules/sd_samplers_kdiffusion.py40
-rw-r--r--modules/sd_unet.py92
-rw-r--r--modules/shared.py5
-rw-r--r--modules/shared_items.py11
-rw-r--r--modules/ui_tempdir.py14
-rw-r--r--scripts/xyz_grid.py6
-rw-r--r--webui.py4
-rwxr-xr-xwebui.sh9
14 files changed, 282 insertions, 44 deletions
diff --git a/javascript/hints.js b/javascript/hints.js
index 46f342cb..05ae5f22 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -116,17 +116,25 @@ var titles = {
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
};
-function updateTooltipForSpan(span) {
- if (span.title) return; // already has a title
+function updateTooltip(element) {
+ if (element.title) return; // already has a title
- let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
+ let text = element.textContent;
+ let tooltip = localization[titles[text]] || titles[text];
if (!tooltip) {
- tooltip = localization[titles[span.value]] || titles[span.value];
+ let value = element.value;
+ if (value) tooltip = localization[titles[value]] || titles[value];
}
if (!tooltip) {
- for (const c of span.classList) {
+ // Gradio dropdown options have `data-value`.
+ let dataValue = element.dataset.value;
+ if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue];
+ }
+
+ if (!tooltip) {
+ for (const c of element.classList) {
if (c in titles) {
tooltip = localization[titles[c]] || titles[c];
break;
@@ -135,34 +143,53 @@ function updateTooltipForSpan(span) {
}
if (tooltip) {
- span.title = tooltip;
+ element.title = tooltip;
}
}
-function updateTooltipForSelect(select) {
- if (select.onchange != null) return;
+// Nodes to check for adding tooltips.
+const tooltipCheckNodes = new Set();
+// Timer for debouncing tooltip check.
+let tooltipCheckTimer = null;
- select.onchange = function() {
- select.title = localization[titles[select.value]] || titles[select.value] || "";
- };
+function processTooltipCheckNodes() {
+ for (const node of tooltipCheckNodes) {
+ updateTooltip(node);
+ }
+ tooltipCheckNodes.clear();
}
-var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
-
-onUiUpdate(function(m) {
- m.forEach(function(record) {
- record.addedNodes.forEach(function(node) {
- if (observedTooltipElements[node.tagName]) {
- updateTooltipForSpan(node);
- }
- if (node.tagName == "SELECT") {
- updateTooltipForSelect(node);
+onUiUpdate(function(mutationRecords) {
+ for (const record of mutationRecords) {
+ if (record.type === "childList" && record.target.classList.contains("options")) {
+ // This smells like a Gradio dropdown menu having changed,
+ // so let's enqueue an update for the input element that shows the current value.
+ let wrap = record.target.parentNode;
+ let input = wrap?.querySelector("input");
+ if (input) {
+ input.title = ""; // So we'll even have a chance to update it.
+ tooltipCheckNodes.add(input);
}
-
- if (node.querySelectorAll) {
- node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
- node.querySelectorAll('select').forEach(updateTooltipForSelect);
+ }
+ for (const node of record.addedNodes) {
+ if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) {
+ if (!node.title) {
+ if (
+ node.tagName === "SPAN" ||
+ node.tagName === "BUTTON" ||
+ node.tagName === "P" ||
+ node.tagName === "INPUT" ||
+ (node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item
+ ) {
+ tooltipCheckNodes.add(node);
+ }
+ }
+ node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n));
}
- });
- });
+ }
+ }
+ if (tooltipCheckNodes.size) {
+ clearTimeout(tooltipCheckTimer);
+ tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
+ }
});
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d5f0a49b..81aef502 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -306,6 +306,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "RNG" not in res:
res["RNG"] = "GPU"
+ if "Schedule type" not in res:
+ res["Schedule type"] = "Automatic"
+
+ if "Schedule max sigma" not in res:
+ res["Schedule max sigma"] = 0
+
+ if "Schedule min sigma" not in res:
+ res["Schedule min sigma"] = 0
+
+ if "Schedule rho" not in res:
+ res["Schedule rho"] = 0
+
return res
@@ -318,6 +330,10 @@ infotext_to_setting_name_mapping = [
('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
+ ('Schedule type', 'k_sched_type'),
+ ('Schedule max sigma', 'sigma_max'),
+ ('Schedule min sigma', 'sigma_min'),
+ ('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
diff --git a/modules/processing.py b/modules/processing.py
index 29a3743f..b75f2515 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -674,6 +674,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
+ sd_unet.apply_unet()
+
if state.job_count == -1:
state.job_count = p.n_iter
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 40f388a5..d2728e12 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -111,6 +111,7 @@ callback_map = dict(
callbacks_before_ui=[],
callbacks_on_reload=[],
callbacks_list_optimizers=[],
+ callbacks_list_unets=[],
)
@@ -271,6 +272,18 @@ def list_optimizers_callback():
return res
+def list_unets_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_unets']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_unets')
+
+ return res
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -430,3 +443,10 @@ def on_list_optimizers(callback):
to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback)
+
+
+def on_list_unets(callback):
+ """register a function to be called when UI is making a list of alternative options for unet.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
+
+ add_callback(callback_map['callbacks_list_unets'], callback)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f93df0a6..487dfd60 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -43,7 +43,7 @@ def list_optimizers():
optimizers.extend(new_optimizers)
-def apply_optimizations():
+def apply_optimizations(option=None):
global current_optimizer
undo_optimizations()
@@ -60,7 +60,7 @@ def apply_optimizations():
current_optimizer.undo()
current_optimizer = None
- selection = shared.opts.cross_attention_optimization
+ selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
@@ -72,12 +72,13 @@ def apply_optimizations():
matching_optimizer = optimizers[0]
if matching_optimizer is not None:
- print(f"Applying optimization: {matching_optimizer.name}... ", end='')
+ print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply()
print("done.")
current_optimizer = matching_optimizer
return current_optimizer.name
else:
+ print("Disabling attention optimization")
return ''
@@ -155,9 +156,9 @@ class StableDiffusionModelHijack:
def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
- def apply_optimizations(self):
+ def apply_optimizations(self, option=None):
try:
- self.optimization_method = apply_optimizations()
+ self.optimization_method = apply_optimizations(option)
except Exception as e:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
@@ -194,6 +195,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
+ if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
+ ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
+
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
@@ -215,6 +221,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3e7fc7e3..232eb9c4 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -533,6 +533,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
+ sd_unet.apply_unet("None")
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 638e0ac9..e9ba2c61 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -44,6 +44,14 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
+k_diffusion_scheduler = {
+ 'Automatic': None,
+ 'karras': k_diffusion.sampling.get_sigmas_karras,
+ 'exponential': k_diffusion.sampling.get_sigmas_exponential,
+ 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
+}
+
class CFGDenoiser(torch.nn.Module):
"""
@@ -265,6 +273,13 @@ class KDiffusionSampler:
try:
return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
except sd_samplers_common.InterruptedException:
return self.last_latent
@@ -304,6 +319,31 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
+ elif opts.k_sched_type != "Automatic":
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
+ sigmas_kwargs = {
+ 'sigma_min': sigma_min,
+ 'sigma_max': sigma_max,
+ }
+
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+ p.extra_generation_params["Schedule type"] = opts.k_sched_type
+
+ if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
+ sigmas_kwargs['sigma_min'] = opts.sigma_min
+ p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
+ if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
+ sigmas_kwargs['sigma_max'] = opts.sigma_max
+ p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
+
+ default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
+
+ if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
+ sigmas_kwargs['rho'] = opts.rho
+ p.extra_generation_params["Schedule rho"] = opts.rho
+
+ sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
new file mode 100644
index 00000000..6d708ad2
--- /dev/null
+++ b/modules/sd_unet.py
@@ -0,0 +1,92 @@
+import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
+
+from modules import script_callbacks, shared, devices
+
+unet_options = []
+current_unet_option = None
+current_unet = None
+
+
+def list_unets():
+ new_unets = script_callbacks.list_unets_callback()
+
+ unet_options.clear()
+ unet_options.extend(new_unets)
+
+
+def get_unet_option(option=None):
+ option = option or shared.opts.sd_unet
+
+ if option == "None":
+ return None
+
+ if option == "Automatic":
+ name = shared.sd_model.sd_checkpoint_info.model_name
+
+ options = [x for x in unet_options if x.model_name == name]
+
+ option = options[0].label if options else "None"
+
+ return next(iter([x for x in unet_options if x.label == option]), None)
+
+
+def apply_unet(option=None):
+ global current_unet_option
+ global current_unet
+
+ new_option = get_unet_option(option)
+ if new_option == current_unet_option:
+ return
+
+ if current_unet is not None:
+ print(f"Dectivating unet: {current_unet.option.label}")
+ current_unet.deactivate()
+
+ current_unet_option = new_option
+ if current_unet_option is None:
+ current_unet = None
+
+ if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ shared.sd_model.model.diffusion_model.to(devices.device)
+
+ return
+
+ shared.sd_model.model.diffusion_model.to(devices.cpu)
+ devices.torch_gc()
+
+ current_unet = current_unet_option.create_unet()
+ current_unet.option = current_unet_option
+ print(f"Activating unet: {current_unet.option.label}")
+ current_unet.activate()
+
+
+class SdUnetOption:
+ model_name = None
+ """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
+
+ label = None
+ """name of the unet in UI"""
+
+ def create_unet(self):
+ """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
+ raise NotImplementedError()
+
+
+class SdUnet(torch.nn.Module):
+ def forward(self, x, timesteps, context, *args, **kwargs):
+ raise NotImplementedError()
+
+ def activate(self):
+ pass
+
+ def deactivate(self):
+ pass
+
+
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+ if current_unet is not None:
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+ return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+
diff --git a/modules/shared.py b/modules/shared.py
index 0897f937..daab38dc 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -403,6 +403,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -517,6 +518,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 2a8713c8..7f306a06 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -29,3 +29,14 @@ def cross_attention_optimizations():
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
+def sd_unet_items():
+ import modules.sd_unet
+
+ return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
+
+
+def refresh_unet_list():
+ import modules.sd_unet
+
+ modules.sd_unet.list_unets()
+
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index f05049e1..7f6b42ae 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -4,6 +4,7 @@ from collections import namedtuple
from pathlib import Path
import gradio as gr
+import gradio.components
from PIL import PngImagePlugin
@@ -31,13 +32,16 @@ def check_tmp_file(gradio, filename):
return False
-def save_pil_to_file(pil_image, dir=None):
+def save_pil_to_file(self, pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
+ filename = already_saved_as
- file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
- return file_obj
+ if not shared.opts.save_images_add_number:
+ filename += f'?{os.path.getmtime(already_saved_as)}'
+
+ return filename
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
@@ -51,11 +55,11 @@ def save_pil_to_file(pil_image, dir=None):
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
- return file_obj
+ return file_obj.name
# override save to file function so that it also writes PNG info
-gr.processing_utils.save_pil_to_file = save_pil_to_file
+gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
def on_tmpdir_changed():
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index da820b39..7821cc65 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
@@ -220,6 +220,10 @@ axis_options = [
AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
+ AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
+ AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
+ AxisOption("Schedule rho", float, apply_override("rho")),
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")),
diff --git a/webui.py b/webui.py
index f9210f41..1e3ff061 100644
--- a/webui.py
+++ b/webui.py
@@ -58,6 +58,7 @@ import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
+import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
@@ -291,6 +292,9 @@ def initialize_rest(*, reload_script_modules=False):
modules.sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
+ modules.sd_unet.list_unets()
+ startup_timer.record("scripts list_unets")
+
def load_model():
"""
Accesses shared.sd_model property to load model.
diff --git a/webui.sh b/webui.sh
index ab52ac3b..607557b1 100755
--- a/webui.sh
+++ b/webui.sh
@@ -124,9 +124,12 @@ case "$gpu_info" in
*)
;;
esac
-if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+if ! echo "$gpu_info" | grep -q "NVIDIA";
then
- export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+ if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+ then
+ export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+ fi
fi
for preq in "${GIT}" "${python_cmd}"
@@ -190,7 +193,7 @@ fi
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
- TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
+ TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
if [[ ! -z "${TCMALLOC}" ]]; then
echo "Using TCMalloc: ${TCMALLOC}"
export LD_PRELOAD="${TCMALLOC}"