-rw-r--r--  modules/api/api.py                   10
-rw-r--r--  modules/devices.py                   56
-rw-r--r--  modules/launch_utils.py               3
-rw-r--r--  modules/script_callbacks.py           5
-rw-r--r--  modules/sd_hijack_utils.py           12
-rw-r--r--  modules/sd_models.py                  6
-rw-r--r--  modules/sd_samplers.py                3
-rw-r--r--  modules/sd_samplers_cfg_denoiser.py   2
-rw-r--r--  modules/sd_samplers_lcm.py          104
-rw-r--r--  modules/sd_vae.py                     3
-rw-r--r--  modules/shared_init.py                1
-rw-r--r--  modules/sysinfo.py                    2
-rw-r--r--  modules/ui.py                         4
-rw-r--r--  modules/xpu_specific.py              20
-rw-r--r--  scripts/postprocessing_upscale.py     4
-rw-r--r--  style.css                             2
16 files changed, 201 insertions, 36 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 9d1292e9..59e46335 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -879,7 +879,15 @@ class Api:
def launch(self, server_name, port, root_path):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
+ uvicorn.run(
+ self.app,
+ host=server_name,
+ port=port,
+ timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
+ root_path=root_path,
+ ssl_keyfile=shared.cmd_opts.tls_keyfile,
+ ssl_certfile=shared.cmd_opts.tls_certfile
+ )
def kill_webui(self):
restart.stop_program()
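With this change the API server honors the --tls-keyfile/--tls-certfile command-line options instead of silently serving plain HTTP. A minimal standalone sketch of the same uvicorn call outside the webui (host, port, and certificate paths below are placeholders, not values from the commit):

import uvicorn
from fastapi import FastAPI

app = FastAPI()

# uvicorn enables HTTPS when the ssl_keyfile/ssl_certfile options are
# set; leaving them None (the default for the webui's cmd_opts) keeps
# plain HTTP, which is why they can be passed through unconditionally.
uvicorn.run(
    app,
    host="127.0.0.1",
    port=7861,
    ssl_keyfile="/path/to/key.pem",
    ssl_certfile="/path/to/cert.pem",
)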
diff --git a/modules/devices.py b/modules/devices.py
index ff279ac5..0321d12c 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -110,6 +110,7 @@ device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
+dtype_inference: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -131,21 +132,44 @@ patch_module_list = [
]
-def manual_cast_forward(self, *args, **kwargs):
- org_dtype = torch_utils.get_param(self).dtype
- self.to(dtype)
- args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
- kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
- result = self.org_forward(*args, **kwargs)
- self.to(org_dtype)
- return result
+def manual_cast_forward(target_dtype):
+ def forward_wrapper(self, *args, **kwargs):
+ if any(
+ isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
+ for arg in args
+ ):
+ args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+
+ org_dtype = torch_utils.get_param(self).dtype
+ if org_dtype != target_dtype:
+ self.to(target_dtype)
+ result = self.org_forward(*args, **kwargs)
+ if org_dtype != target_dtype:
+ self.to(org_dtype)
+
+ if target_dtype != dtype_inference:
+ if isinstance(result, tuple):
+ result = tuple(
+ i.to(dtype_inference)
+ if isinstance(i, torch.Tensor)
+ else i
+ for i in result
+ )
+ elif isinstance(result, torch.Tensor):
+ result = result.to(dtype_inference)
+ return result
+ return forward_wrapper
@contextlib.contextmanager
-def manual_cast():
+def manual_cast(target_dtype):
for module_type in patch_module_list:
org_forward = module_type.forward
- module_type.forward = manual_cast_forward
+ if module_type == torch.nn.MultiheadAttention and has_xpu():
+ module_type.forward = manual_cast_forward(torch.float32)
+ else:
+ module_type.forward = manual_cast_forward(target_dtype)
module_type.org_forward = org_forward
try:
yield None
@@ -161,15 +185,15 @@ def autocast(disable=False):
if fp8 and device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
- if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
- return manual_cast()
+ if fp8 and dtype_inference == torch.float32:
+ return manual_cast(dtype)
- if has_mps() and shared.cmd_opts.precision != "full":
- return manual_cast()
-
- if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+ if dtype == torch.float32 or dtype_inference == torch.float32:
return contextlib.nullcontext()
+ if has_xpu() or has_mps() or cuda_no_autocast():
+ return manual_cast(dtype)
+
return torch.autocast("cuda")
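The core of the devices.py change is the switch from one shared manual_cast_forward to a factory parameterized by target dtype, so MultiheadAttention on XPU can be pinned to float32 while other modules follow the configured dtype. A self-contained sketch of the underlying patch-and-restore pattern, using only torch.nn.Linear for brevity (the real code patches every type in patch_module_list and also restores the module's own parameter dtype):

import contextlib
import torch

@contextlib.contextmanager
def cast_linear_to(target_dtype):
    # Swap torch.nn.Linear.forward for a wrapper that casts tensor
    # inputs to target_dtype, then restore the original forward on
    # exit -- the same shape as devices.manual_cast().
    org_forward = torch.nn.Linear.forward
    def forward_wrapper(self, *args, **kwargs):
        args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a
                for a in args]
        return org_forward(self, *args, **kwargs)
    torch.nn.Linear.forward = forward_wrapper
    try:
        yield
    finally:
        torch.nn.Linear.forward = org_forward

with cast_linear_to(torch.float16):
    layer = torch.nn.Linear(4, 4).half()
    out = layer(torch.randn(1, 4))  # float32 input, cast before the matmul
assert out.dtype == torch.float16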
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index c2a7ae93..8e58d714 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -344,11 +344,13 @@ def prepare_environment():
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
+ assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
+ assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
@@ -405,6 +407,7 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
+ git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 9ed7ad21..a54cb3eb 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -41,7 +41,7 @@ class ExtraNoiseParams:
class CFGDenoiserParams:
- def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -63,6 +63,9 @@ class CFGDenoiserParams:
self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt"""
+ self.denoiser = denoiser
+ """Current CFGDenoiser object with processing parameters"""
+
class CFGDenoisedParams:
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py
index f8684475..79bf6e46 100644
--- a/modules/sd_hijack_utils.py
+++ b/modules/sd_hijack_utils.py
@@ -11,10 +11,14 @@ class CondFunc:
break
except ImportError:
pass
- for attr_name in func_path[i:-1]:
- resolved_obj = getattr(resolved_obj, attr_name)
- orig_func = getattr(resolved_obj, func_path[-1])
- setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+ try:
+ for attr_name in func_path[i:-1]:
+ resolved_obj = getattr(resolved_obj, attr_name)
+ orig_func = getattr(resolved_obj, func_path[-1])
+ setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+ except AttributeError:
+ print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
+ pass
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
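The try/except above makes CondFunc hijacks tolerate attribute paths that do not exist in the installed library version, warning instead of crashing at import time. The resolution logic it guards boils down to this standalone sketch (the function name and the os.path.exists demo path are illustrative, not part of the commit):

import importlib

def resolve_and_patch(path, replacement):
    # Import the longest importable prefix of a dotted path, walk the
    # remaining attributes, then swap the final one for `replacement`.
    # Returns the original attribute, or None if the path is missing.
    parts = path.split('.')
    for i in range(len(parts) - 1, 0, -1):
        try:
            obj = importlib.import_module('.'.join(parts[:i]))
            break
        except ImportError:
            pass
    try:
        for name in parts[i:-1]:
            obj = getattr(obj, name)
        original = getattr(obj, parts[-1])
        setattr(obj, parts[-1], replacement)
        return original
    except AttributeError:
        print(f"Warning: failed to resolve {path}")
        return None

original = resolve_and_patch('os.path.exists', lambda p: True)
importlib.import_module('os.path').exists = original  # undo the demo patch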
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 50bc209e..2c045771 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
sd_hijack.model_hijack.hijack(sd_model)
timer.record("hijack")
- script_callbacks.model_loaded_callback(sd_model)
- timer.record("script callbacks")
-
if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
+ script_callbacks.model_loaded_callback(sd_model)
+ timer.record("script callbacks")
+
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 45faae62..a58528a0 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,4 +1,4 @@
-from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
+from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
@@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image #
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
*sd_samplers_timesteps.samplers_data_timesteps,
+ *sd_samplers_lcm.samplers_data_lcm,
]
all_samplers_map = {x.name: x for x in all_samplers}
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index eb9d5daf..6d76aa96 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -146,7 +146,7 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py
new file mode 100644
index 00000000..59839b72
--- /dev/null
+++ b/modules/sd_samplers_lcm.py
@@ -0,0 +1,104 @@
+import torch
+
+from k_diffusion import utils, sampling
+from k_diffusion.external import DiscreteEpsDDPMDenoiser
+from k_diffusion.sampling import default_noise_sampler, trange
+
+from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common
+
+
+class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
+ def __init__(self, model):
+ timesteps = 1000
+ original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM)
+ self.skip_steps = timesteps // original_timesteps
+
+ alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
+ for x in range(original_timesteps):
+ alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
+
+ super().__init__(model, alphas_cumprod_valid, quantize=None)
+
+
+ def get_sigmas(self, n=None,):
+ if n is None:
+ return sampling.append_zero(self.sigmas.flip(0))
+
+ start = self.sigma_to_t(self.sigma_max)
+ end = self.sigma_to_t(self.sigma_min)
+
+ t = torch.linspace(start, end, n, device=shared.sd_model.device)
+
+ return sampling.append_zero(self.t_to_sigma(t))
+
+
+ def sigma_to_t(self, sigma, quantize=None):
+ log_sigma = sigma.log()
+ dists = log_sigma - self.log_sigmas[:, None]
+ return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
+
+
+ def t_to_sigma(self, timestep):
+ t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
+ return super().t_to_sigma(t)
+
+
+ def get_eps(self, *args, **kwargs):
+ return self.inner_model.apply_model(*args, **kwargs)
+
+
+ def get_scaled_out(self, sigma, output, input):
+ sigma_data = 0.5
+ scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0
+
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+
+ return c_out * output + c_skip * input
+
+
+ def forward(self, input, sigma, **kwargs):
+ c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
+ return self.get_scaled_out(sigma, input + eps * c_out, input)
+
+
+def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
+ extra_args = {} if extra_args is None else extra_args
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+ x = denoised
+ if sigmas[i + 1] > 0:
+ x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
+ return x
+
+
+class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):
+ @property
+ def inner_model(self):
+ if self.model_wrap is None:
+ denoiser = LCMCompVisDenoiser
+ self.model_wrap = denoiser(shared.sd_model)
+
+ return self.model_wrap
+
+
+class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):
+ def __init__(self, funcname, sd_model, options=None):
+ super().__init__(funcname, sd_model, options)
+ self.model_wrap_cfg = CFGDenoiserLCM(self)
+ self.model_wrap = self.model_wrap_cfg.inner_model
+
+
+samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})]
+samplers_data_lcm = [
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_lcm
+]
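The sampler loop in sample_lcm is unusually simple compared with the other k-diffusion samplers: each step jumps straight to the model's denoised prediction and, if steps remain, re-injects fresh noise at the next sigma. A toy trace of that loop with a stand-in denoiser (the 0.9 damping model and the sigma values are illustrative only; real use wraps shared.sd_model in LCMCompVisDenoiser):

import torch

def toy_denoiser(x, sigma):
    # Stand-in for the wrapped LCM model's one-shot prediction.
    return 0.9 * x

x = torch.randn(1, 4, 8, 8)
sigmas = torch.tensor([14.6, 3.0, 0.7, 0.0])
for i in range(len(sigmas) - 1):
    x = toy_denoiser(x, sigmas[i])                   # jump to the prediction
    if sigmas[i + 1] > 0:
        x = x + sigmas[i + 1] * torch.randn_like(x)  # re-noise for next step
# the final step (sigma 0) adds no noise, so x is the finished sample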
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 31306d8b..43687e48 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
if not sd_model.lowvram:
sd_model.to(devices.device)
+ script_callbacks.model_loaded_callback(sd_model)
+
print("VAE weights loaded.")
return sd_model
diff --git a/modules/shared_init.py b/modules/shared_init.py
index 586be342..935e3a21 100644
--- a/modules/shared_init.py
+++ b/modules/shared_init.py
@@ -29,6 +29,7 @@ def initialize():
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
+ devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
shared.device = devices.device
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index 5abf616b..f336251e 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -24,9 +24,11 @@ environment_whitelist = {
"XFORMERS_PACKAGE",
"CLIP_PACKAGE",
"OPENCLIP_PACKAGE",
+ "ASSETS_REPO",
"STABLE_DIFFUSION_REPO",
"K_DIFFUSION_REPO",
"BLIP_REPO",
+ "ASSETS_COMMIT_HASH",
"STABLE_DIFFUSION_COMMIT_HASH",
"K_DIFFUSION_COMMIT_HASH",
"BLIP_COMMIT_HASH",
diff --git a/modules/ui.py b/modules/ui.py
index 2d2e333b..a716a040 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -1223,3 +1223,5 @@ def setup_ui_api(app):
app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
+ import fastapi.staticfiles
+ app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets")
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index 4e11125b..2971dbc3 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention(
return torch.reshape(result, (*N, L, Ev))
+def is_xpu_device(device: str | torch.device = None):
+ if device is None:
+ return False
+ if isinstance(device, str):
+ return device.startswith("xpu")
+ return device.type == "xpu"
+
+
if has_xpu:
- # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
- CondFunc('torch.Generator',
- lambda orig_func, device=None: torch.xpu.Generator(device),
- lambda orig_func, device=None: device is not None and device.type == "xpu")
+ try:
+ # torch.Generator supports "xpu" device since 2.1
+ torch.Generator("xpu")
+ except RuntimeError:
+ # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
+ CondFunc('torch.Generator',
+ lambda orig_func, device=None: torch.xpu.Generator(device),
+ lambda orig_func, device=None: is_xpu_device(device))
# W/A for some OPs that could not handle different input dtypes
CondFunc('torch.nn.functional.layer_norm',
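The xpu_specific.py change replaces an unconditional hijack with feature detection: probe the native API once and install the workaround only if it raises. The same probe in isolation (on builds without native XPU Generator support the except branch is taken, mirroring the RuntimeError the commit catches):

import torch

# Probe the feature instead of parsing version strings: if this call
# raises RuntimeError, torch.Generator lacks native XPU support and
# the CondFunc fallback to torch.xpu.Generator is installed.
try:
    torch.Generator("xpu")
    needs_generator_workaround = False
except RuntimeError:
    needs_generator_workaround = True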
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index ed709688..a57f9d4a 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -26,8 +26,8 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="upscaling_column_size", scale=4):
- upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
- upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
+ upscaling_resize_w = gr.Slider(minimum=64, maximum=8192, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Slider(minimum=64, maximum=8192, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
diff --git a/style.css b/style.css
index 6d4c8a0d..4957c523 100644
--- a/style.css
+++ b/style.css
@@ -1,6 +1,6 @@
/* temporary fix to load default gradio font in frontend instead of backend */
-@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
+@import url('webui-assets/css/sourcesanspro.css');
/* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */