aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py5
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/devices.py26
-rw-r--r--modules/infotext_versions.py5
-rw-r--r--modules/initialize.py3
-rw-r--r--modules/interrogate.py2
-rw-r--r--modules/launch_utils.py9
-rw-r--r--modules/npu_specific.py31
-rw-r--r--modules/options.py2
-rw-r--r--modules/processing.py8
-rw-r--r--modules/scripts.py28
-rw-r--r--modules/sd_samplers_cfg_denoiser.py70
-rw-r--r--modules/sd_samplers_common.py7
-rw-r--r--modules/sd_samplers_kdiffusion.py6
-rw-r--r--modules/sd_samplers_timesteps.py6
-rw-r--r--modules/shared.py3
-rw-r--r--modules/shared_options.py4
-rw-r--r--modules/styles.py111
-rw-r--r--modules/textual_inversion/textual_inversion.py1
-rw-r--r--modules/txt2img.py4
-rw-r--r--modules/ui_extra_networks.py31
-rw-r--r--modules/ui_extra_networks_checkpoints.py2
-rw-r--r--modules/ui_postprocessing.py3
-rw-r--r--modules/ui_prompt_styles.py9
-rw-r--r--modules/ui_toprow.py22
-rw-r--r--modules/upscaler_utils.py5
26 files changed, 275 insertions, 130 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 59e46335..4e656082 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -230,6 +230,7 @@ class Api:
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
+ self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
@@ -747,6 +748,10 @@ class Api:
"skipped": convert_embeddings(db.skipped_embeddings),
}
+ def refresh_embeddings(self):
+ with self.queue_lock:
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
def refresh_checkpoints(self):
with self.queue_lock:
shared.refresh_checkpoints()
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index e58059a1..f1251b6c 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -88,7 +88,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anythin
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
-parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
+parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index dfffaf24..28c0c54d 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,8 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors, shared
-from modules import torch_utils
+from modules import errors, shared, npu_specific
if sys.platform == "darwin":
from modules import mac_specific
@@ -58,6 +57,9 @@ def get_optimal_device_name():
if has_xpu():
return xpu_specific.get_xpu_device_string()
+ if npu_specific.has_npu:
+ return npu_specific.get_npu_device_string()
+
return "cpu"
@@ -85,6 +87,16 @@ def torch_gc():
if has_xpu():
xpu_specific.torch_xpu_gc()
+ if npu_specific.has_npu:
+ torch_npu_set_device()
+ npu_specific.torch_npu_gc()
+
+
+def torch_npu_set_device():
+ # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+ if npu_specific.has_npu:
+ torch.npu.set_device(0)
+
def enable_tf32():
if torch.cuda.is_available():
@@ -141,7 +153,12 @@ def manual_cast_forward(target_dtype):
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
- org_dtype = torch_utils.get_param(self).dtype
+ org_dtype = target_dtype
+ for param in self.parameters():
+ if param.dtype != target_dtype:
+ org_dtype = param.dtype
+ break
+
if org_dtype != target_dtype:
self.to(target_dtype)
result = self.org_forward(*args, **kwargs)
@@ -170,7 +187,7 @@ def manual_cast(target_dtype):
continue
applied = True
org_forward = module_type.forward
- if module_type == torch.nn.MultiheadAttention and has_xpu():
+ if module_type == torch.nn.MultiheadAttention:
module_type.forward = manual_cast_forward(torch.float32)
else:
module_type.forward = manual_cast_forward(target_dtype)
@@ -252,4 +269,3 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
-
diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py
index a5afeebf..23b45c3f 100644
--- a/modules/infotext_versions.py
+++ b/modules/infotext_versions.py
@@ -31,9 +31,12 @@ def backcompat(d):
if ver is None:
return
- if ver < v160:
+ if ver < v160 and '[' in d.get('Prompt', ''):
d["Old prompt editing timelines"] = True
+ if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
+ d["Pad conds v0"] = True
+
if ver < v170_tsnr:
d["Downcast alphas_cumprod"] = True
diff --git a/modules/initialize.py b/modules/initialize.py
index 7c1ac99e..f7313ff4 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -142,13 +142,14 @@ def initialize_rest(*, reload_script_modules=False):
its optimization may be None because the list of optimizaers has neet been filled
by that time, so we apply optimization again.
"""
+ from modules import devices
+ devices.torch_npu_set_device()
shared.sd_model # noqa: B018
if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations()
- from modules import devices
devices.first_time_calculation()
if not shared.cmd_opts.skip_load_model_at_start:
Thread(target=load_model).start()
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 35a627ca..c93e7aa8 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -17,7 +17,7 @@ clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"])
-re_topn = re.compile(r"\.top(\d+)\.")
+re_topn = re.compile(r"\.top(\d+)$")
def category_types():
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 8e58d714..107c72b0 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -251,7 +251,6 @@ def list_extensions(settings_file):
except Exception:
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
- settings = {}
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
@@ -339,6 +338,7 @@ def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
+ requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
@@ -422,6 +422,13 @@ def prepare_environment():
run_pip(f"install -r \"{requirements_file}\"", "requirements")
startup_timer.record("install requirements")
+ if not os.path.isfile(requirements_file_for_npu):
+ requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
+
+ if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
+ run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
+ startup_timer.record("install requirements_for_npu")
+
if not args.skip_install:
run_extensions_installers(settings_file=args.ui_settings_file)
diff --git a/modules/npu_specific.py b/modules/npu_specific.py
new file mode 100644
index 00000000..94100691
--- /dev/null
+++ b/modules/npu_specific.py
@@ -0,0 +1,31 @@
+import importlib
+import torch
+
+from modules import shared
+
+
+def check_for_npu():
+ if importlib.util.find_spec("torch_npu") is None:
+ return False
+ import torch_npu
+
+ try:
+ # Will raise a RuntimeError if no NPU is found
+ _ = torch_npu.npu.device_count()
+ return torch.npu.is_available()
+ except RuntimeError:
+ return False
+
+
+def get_npu_device_string():
+ if shared.cmd_opts.device_id is not None:
+ return f"npu:{shared.cmd_opts.device_id}"
+ return "npu:0"
+
+
+def torch_npu_gc():
+ with torch.npu.device(get_npu_device_string()):
+ torch.npu.empty_cache()
+
+
+has_npu = check_for_npu()
diff --git a/modules/options.py b/modules/options.py
index 503b40e9..35ccade2 100644
--- a/modules/options.py
+++ b/modules/options.py
@@ -198,6 +198,8 @@ class Options:
try:
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
+ except FileNotFoundError:
+ self.data = {}
except Exception:
errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
os.replace(filename, os.path.join(script_path, "tmp", "config.json"))
diff --git a/modules/processing.py b/modules/processing.py
index 72d8093b..52f00bfb 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1005,7 +1005,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image = pp.image
mask_for_overlay = getattr(p, "mask_for_overlay", None)
- overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None
+
+ if not shared.opts.overlay_inpaint:
+ overlay_image = None
+ elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
+ overlay_image = p.overlay_images[i]
+ else:
+ overlay_image = None
if p.scripts is not None:
ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
diff --git a/modules/scripts.py b/modules/scripts.py
index 060069cf..94690a22 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -939,22 +939,34 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running setup: {script.filename}", exc_info=True)
- def set_named_arg(self, args, script_type, arg_elem_id, value):
- script = next((x for x in self.scripts if type(x).__name__ == script_type), None)
+ def set_named_arg(self, args, script_name, arg_elem_id, value, fuzzy=False):
+ """Locate an arg of a specific script in script_args and set its value
+ Args:
+ args: all script args of process p, p.script_args
+ script_name: the name target script name to
+ arg_elem_id: the elem_id of the target arg
+ value: the value to set
+ fuzzy: if True, arg_elem_id can be a substring of the control.elem_id else exact match
+ Returns:
+ Updated script args
+ when script_name in not found or arg_elem_id is not found in script controls, raise RuntimeError
+ """
+ script = next((x for x in self.scripts if x.name == script_name), None)
if script is None:
- return
+ raise RuntimeError(f"script {script_name} not found")
for i, control in enumerate(script.controls):
- if arg_elem_id in control.elem_id:
+ if arg_elem_id in control.elem_id if fuzzy else arg_elem_id == control.elem_id:
index = script.args_from + i
- if isinstance(args, list):
+ if isinstance(args, tuple):
+ return args[:index] + (value,) + args[index + 1:]
+ elif isinstance(args, list):
args[index] = value
return args
- elif isinstance(args, tuple):
- return args[:index] + (value,) + args[index+1:]
else:
- return None
+ raise RuntimeError(f"args is not a list or tuple, but {type(args)}")
+ raise RuntimeError(f"arg_elem_id {arg_elem_id} not found in script {script_name}")
scripts_txt2img: ScriptRunner = None
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 6d76aa96..941dff4b 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -53,6 +53,7 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
+ self.padded_cond_uncond_v0 = False
self.sampler = sampler
self.model_wrap = None
self.p = None
@@ -91,6 +92,62 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['cond'] = c
self.sampler.sampler_extra_args['uncond'] = uc
+ def pad_cond_uncond(self, cond, uncond):
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ cond = pad_cond(cond, -num_repeats, empty)
+ self.padded_cond_uncond = True
+ elif num_repeats > 0:
+ uncond = pad_cond(uncond, num_repeats, empty)
+ self.padded_cond_uncond = True
+
+ return cond, uncond
+
+ def pad_cond_uncond_v0(self, cond, uncond):
+ """
+ Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
+
+ If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
+ If 'uncond' is a tensor, it is padded directly.
+
+ If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
+ is repeated to match the number of columns in 'cond'.
+
+ If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
+ to match the number of columns in 'cond'.
+
+ Args:
+ cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
+ uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
+
+ Returns:
+ tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
+
+ Note:
+ This is the padding that was always used in DDIM before version 1.6.0
+ """
+
+ is_dict_cond = isinstance(uncond, dict)
+ uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
+
+ if uncond_vec.shape[1] < cond.shape[1]:
+ last_vector = uncond_vec[:, -1:]
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
+ uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
+ self.padded_cond_uncond_v0 = True
+ elif uncond_vec.shape[1] > cond.shape[1]:
+ uncond_vec = uncond_vec[:, :cond.shape[1]]
+ self.padded_cond_uncond_v0 = True
+
+ if is_dict_cond:
+ uncond['crossattn'] = uncond_vec
+ else:
+ uncond = uncond_vec
+
+ return cond, uncond
+
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -162,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = sigma_in[:-batch_size]
self.padded_cond_uncond = False
+ self.padded_cond_uncond_v0 = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
- empty = shared.sd_model.cond_stage_model_empty_prompt
- num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
- if num_repeats < 0:
- tensor = pad_cond(tensor, -num_repeats, empty)
- self.padded_cond_uncond = True
- elif num_repeats > 0:
- uncond = pad_cond(uncond, num_repeats, empty)
- self.padded_cond_uncond = True
+ tensor, uncond = self.pad_cond_uncond(tensor, uncond)
+ elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
+ tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 58efcad2..6bd38e12 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -335,3 +335,10 @@ class Sampler:
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()
+
+ def add_infotext(self, p):
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
+ if self.model_wrap_cfg.padded_cond_uncond_v0:
+ p.extra_generation_params["Pad conds v0"] = True
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 8a8c87e0..337106c0 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -187,8 +187,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
@@ -234,8 +233,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index 777dd8d0..8cc7d384 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -133,8 +133,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
@@ -158,8 +157,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
}
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
- if self.model_wrap_cfg.padded_cond_uncond:
- p.extra_generation_params["Pad conds"] = True
+ self.add_infotext(p)
return samples
diff --git a/modules/shared.py b/modules/shared.py
index 63661939..ccdca4e7 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,3 +1,4 @@
+import os
import sys
import gradio as gr
@@ -11,7 +12,7 @@ parser = shared_cmd_options.parser
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
parallel_processing_allowed = True
-styles_filename = cmd_opts.styles_file
+styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
diff --git a/modules/shared_options.py b/modules/shared_options.py
index ec5cb026..bdd066c4 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -201,6 +201,7 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), {
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
+ "overlay_inpaint": OptionInfo(True, "Overlay original for inpaint").info("when inpainting, overlay the original image over the areas that weren't inpainted."),
}))
options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
@@ -209,7 +210,8 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
- "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
+ "pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; ignored if the above is set; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
diff --git a/modules/styles.py b/modules/styles.py
index 026c4300..60bd8a7f 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,16 +1,16 @@
+from pathlib import Path
+from modules import errors
import csv
-import fnmatch
import os
-import os.path
import typing
import shutil
class PromptStyle(typing.NamedTuple):
name: str
- prompt: str
- negative_prompt: str
- path: str = None
+ prompt: str | None
+ negative_prompt: str | None
+ path: str | None = None
def merge_prompts(style_prompt: str, prompt: str) -> str:
@@ -79,14 +79,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
class StyleDatabase:
- def __init__(self, path: str):
+ def __init__(self, paths: list[str | Path]):
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
- self.path = path
-
- folder, file = os.path.split(self.path)
- filename, _, ext = file.partition('*')
- self.default_path = os.path.join(folder, filename + ext)
+ self.paths = paths
+ self.all_styles_files: list[Path] = []
+
+ folder, file = os.path.split(self.paths[0])
+ if '*' in file or '?' in file:
+ # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
+ self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
+ self.paths.insert(0, self.default_path)
+ else:
+ self.default_path = Path(self.paths[0])
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -99,57 +104,58 @@ class StyleDatabase:
"""
self.styles.clear()
- path, filename = os.path.split(self.path)
-
- if "*" in filename:
- fileglob = filename.split("*")[0] + "*.csv"
- filelist = []
- for file in os.listdir(path):
- if fnmatch.fnmatch(file, fileglob):
- filelist.append(file)
- # Add a visible divider to the style list
- half_len = round(len(file) / 2)
- divider = f"{'-' * (20 - half_len)} {file.upper()}"
- divider = f"{divider} {'-' * (40 - len(divider))}"
- self.styles[divider] = PromptStyle(
- f"{divider}", None, None, "do_not_save"
+ # scans for all styles files
+ all_styles_files = []
+ for pattern in self.paths:
+ folder, file = os.path.split(pattern)
+ if '*' in file or '?' in file:
+ found_files = Path(folder).glob(file)
+ [all_styles_files.append(file) for file in found_files]
+ else:
+ # if os.path.exists(pattern):
+ all_styles_files.append(Path(pattern))
+
+ # Remove any duplicate entries
+ seen = set()
+ self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
+
+ for styles_file in self.all_styles_files:
+ if len(all_styles_files) > 1:
+ # add divider when more than styles file
+ # '---------------- STYLES ----------------'
+ divider = f' {styles_file.stem.upper()} '.center(40, '-')
+ self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
+ if styles_file.is_file():
+ self.load_from_csv(styles_file)
+
+ def load_from_csv(self, path: str | Path):
+ try:
+ with open(path, "r", encoding="utf-8-sig", newline="") as file:
+ reader = csv.DictReader(file, skipinitialspace=True)
+ for row in reader:
+ # Ignore empty rows or rows starting with a comment
+ if not row or row["name"].startswith("#"):
+ continue
+ # Support loading old CSV format with "name, text"-columns
+ prompt = row["prompt"] if "prompt" in row else row["text"]
+ negative_prompt = row.get("negative_prompt", "")
+ # Add style to database
+ self.styles[row["name"]] = PromptStyle(
+ row["name"], prompt, negative_prompt, str(path)
)
- # Add styles from this CSV file
- self.load_from_csv(os.path.join(path, file))
- if len(filelist) == 0:
- print(f"No styles found in {path} matching {fileglob}")
- return
- elif not os.path.exists(self.path):
- print(f"Style database not found: {self.path}")
- return
- else:
- self.load_from_csv(self.path)
-
- def load_from_csv(self, path: str):
- with open(path, "r", encoding="utf-8-sig", newline="") as file:
- reader = csv.DictReader(file, skipinitialspace=True)
- for row in reader:
- # Ignore empty rows or rows starting with a comment
- if not row or row["name"].startswith("#"):
- continue
- # Support loading old CSV format with "name, text"-columns
- prompt = row["prompt"] if "prompt" in row else row["text"]
- negative_prompt = row.get("negative_prompt", "")
- # Add style to database
- self.styles[row["name"]] = PromptStyle(
- row["name"], prompt, negative_prompt, path
- )
+ except Exception:
+ errors.report(f'Error loading styles from {path}: ', exc_info=True)
def get_style_paths(self) -> set:
"""Returns a set of all distinct paths of files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
- self.styles[style.name] = style._replace(path=self.default_path)
+ self.styles[style.name] = style._replace(path=str(self.default_path))
# Create a list of all distinct paths, including the default path
style_paths = set()
- style_paths.add(self.default_path)
+ style_paths.add(str(self.default_path))
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)
@@ -177,7 +183,6 @@ class StyleDatabase:
def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
- _ = path
style_paths = self.get_style_paths()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c6bcab15..6d815c0b 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -150,6 +150,7 @@ class EmbeddingDatabase:
return embedding
def get_expected_shape(self):
+ devices.torch_npu_set_device()
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 4efcb4c3..fc56b8a8 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -60,10 +60,10 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
assert len(gallery) > 0, 'No image to upscale'
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
- p = txt2img_create_processing(id_task, request, *args)
- p.enable_hr = True
+ p = txt2img_create_processing(id_task, request, *args, force_enable_hr=True)
p.batch_size = 1
p.n_iter = 1
+ # txt2img_upscale attribute that signifies this is called by txt2img_upscale
p.txt2img_upscale = True
geninfo = json.loads(generation_info)
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 325d848e..c03b9f08 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -134,8 +134,8 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
errors.display(e, "creating item for extra network")
item = page.items.get(name)
- page.read_user_metadata(item)
- item_html = page.create_item_html(tabname, item)
+ page.read_user_metadata(item, use_cache=False)
+ item_html = page.create_item_html(tabname, item, shared.html("extra-networks-card.html"))
return JSONResponse({"html": item_html})
@@ -173,9 +173,9 @@ class ExtraNetworksPage:
def refresh(self):
pass
- def read_user_metadata(self, item):
+ def read_user_metadata(self, item, use_cache=True):
filename = item.get("filename", None)
- metadata = extra_networks.get_user_metadata(filename, lister=self.lister)
+ metadata = extra_networks.get_user_metadata(filename, lister=self.lister if use_cache else None)
desc = metadata.get("description", None)
if desc is not None:
@@ -559,7 +559,7 @@ class ExtraNetworksPage:
"date_created": int(mtime),
"date_modified": int(ctime),
"name": pth.name.lower(),
- "path": str(pth.parent).lower(),
+ "path": str(pth).lower(),
}
def find_preview(self, path):
@@ -638,6 +638,7 @@ def pages_in_preferred_order(pages):
return sorted(pages, key=lambda x: tab_scores[x.name])
+
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
ui = ExtraNetworksUi()
ui.pages = []
@@ -648,8 +649,6 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs = []
- button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_extra_refresh_internal", visible=False)
-
for page in ui.stored_extra_pages:
with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab:
with gr.Column(elem_id=f"{tabname}_{page.extra_networks_tabname}_prompts", elem_classes=["extra-page-prompts"]):
@@ -678,6 +677,15 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
)
tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False)
+ def refresh():
+ for pg in ui.stored_extra_pages:
+ pg.refresh()
+ create_html()
+ return ui.pages_contents
+
+ button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_{page.extra_networks_tabname}_extra_refresh_internal", visible=False)
+ button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages).then(fn=lambda: None, _js="function(){ " + f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');" + " }")
+
def create_html():
ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
@@ -686,16 +694,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
create_html()
return ui.pages_contents
- def refresh():
- for pg in ui.stored_extra_pages:
- pg.refresh()
- create_html()
- return ui.pages_contents
-
interface.load(fn=pages_html, inputs=[], outputs=ui.pages)
- # NOTE: Event is manually fired in extraNetworks.js:extraNetworksTreeRefreshOnClick()
- # button is unused and hidden at all times. Only used in order to fire this event.
- button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
return ui
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index a8c33671..d69d144d 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -30,7 +30,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_terms": search_terms,
- "onclick": html.escape(f"return selectCheckpoint('{name}');"),
+ "onclick": html.escape(f"return selectCheckpoint({ui_extra_networks.quote_js(name)})"),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": checkpoint.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index ff22a178..7261c2df 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,13 +1,14 @@
import gradio as gr
from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
import modules.infotext_utils as parameters_copypaste
+from modules.ui_components import ResizeHandleRow
def create_ui():
dummy_component = gr.Label(visible=False)
tab_index = gr.Number(value=0, visible=False)
- with gr.Row(equal_height=False, variant='compact'):
+ with ResizeHandleRow(equal_height=False, variant='compact'):
with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
index 0d74c23f..d67e3f17 100644
--- a/modules/ui_prompt_styles.py
+++ b/modules/ui_prompt_styles.py
@@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt):
if not name:
return gr.update(visible=False)
- style = styles.PromptStyle(name, prompt, negative_prompt)
+ existing_style = shared.prompt_styles.styles.get(name)
+ path = existing_style.path if existing_style is not None else None
+
+ style = styles.PromptStyle(name, prompt, negative_prompt, path)
shared.prompt_styles.styles[style.name] = style
- shared.prompt_styles.save_styles(shared.styles_filename)
+ shared.prompt_styles.save_styles()
return gr.update(visible=True)
@@ -34,7 +37,7 @@ def delete_style(name):
return
shared.prompt_styles.styles.pop(name, None)
- shared.prompt_styles.save_styles(shared.styles_filename)
+ shared.prompt_styles.save_styles()
return '', '', ''
diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py
index fbe705be..30cf1b1b 100644
--- a/modules/ui_toprow.py
+++ b/modules/ui_toprow.py
@@ -17,6 +17,7 @@ class Toprow:
button_deepbooru = None
interrupt = None
+ interrupting = None
skip = None
submit = None
@@ -96,15 +97,10 @@ class Toprow:
with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box:
self.submit_box = submit_box
- self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt")
- self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip")
- self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary')
-
- self.skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
+ self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt", tooltip="End generation immediately or after completing current batch")
+ self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip", tooltip="Stop generation of current batch and continues onto next batch")
+ self.interrupting = gr.Button('Interrupting...', elem_id=f"{self.id_part}_interrupting", elem_classes="generate-box-interrupting", tooltip="Interrupting generation...")
+ self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary', tooltip="Right click generate forever menu")
def interrupt_function():
if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current:
@@ -113,11 +109,9 @@ class Toprow:
else:
shared.state.interrupt()
- self.interrupt.click(
- fn=interrupt_function,
- inputs=[],
- outputs=[],
- )
+ self.skip.click(fn=shared.state.skip)
+ self.interrupt.click(fn=interrupt_function, _js='function(){ showSubmitInterruptingPlaceholder("' + self.id_part + '"); }')
+ self.interrupting.click(fn=interrupt_function)
def create_tools_row(self):
with gr.Row(elem_id=f"{self.id_part}_tools"):
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index afed8b40..b5e5a80c 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -6,7 +6,7 @@ import torch
import tqdm
from PIL import Image
-from modules import images, shared, torch_utils
+from modules import devices, images, shared, torch_utils
logger = logging.getLogger(__name__)
@@ -44,7 +44,8 @@ def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
with torch.no_grad():
tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
tensor = tensor.to(device=param.device, dtype=param.dtype)
- return torch_bgr_to_pil_image(model(tensor))
+ with devices.without_autocast():
+ return torch_bgr_to_pil_image(model(tensor))
def upscale_with_model(