Diffstat (limited to 'modules')
-rw-r--r--  modules/api/api.py                                3
-rw-r--r--  modules/devices.py                                5
-rw-r--r--  modules/hypernetworks/hypernetwork.py             6
-rw-r--r--  modules/images.py                                10
-rw-r--r--  modules/img2img.py                               29
-rw-r--r--  modules/mac_specific.py                          18
-rw-r--r--  modules/textual_inversion/textual_inversion.py    6
-rw-r--r--  modules/txt2img.py                               11
-rw-r--r--  modules/ui_settings.py                           21
9 files changed, 72 insertions, 37 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 5793bb44..1804a383 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -598,7 +598,8 @@ class Api:
}
def refresh_checkpoints(self):
- shared.refresh_checkpoints()
+ with self.queue_lock:
+ shared.refresh_checkpoints()
def create_embedding(self, args: dict):
try:
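A minimal sketch of the serialization this hunk adds, with a module-level threading.Lock standing in for the lock the Api instance already holds (the lock name and refresh callback here are illustrative assumptions):

    import threading

    queue_lock = threading.Lock()  # stand-in for the Api's existing queue_lock

    def refresh_checkpoints(refresh_fn=lambda: print("refreshing checkpoint list")):
        # take the same lock that guards generation requests, so a refresh
        # cannot rescan the checkpoint list while a queued job is still using it
        with queue_lock:
            refresh_fn()

    refresh_checkpoints()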
diff --git a/modules/devices.py b/modules/devices.py
index c5ad950f..57e51da3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -54,8 +54,9 @@ def torch_gc():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
- elif has_mps() and hasattr(torch.mps, 'empty_cache'):
- torch.mps.empty_cache()
+
+ if has_mps():
+ mac_specific.torch_mps_gc()
def enable_tf32():
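For context, a condensed sketch of what torch_gc looks like after this hunk; the CUDA branch is abbreviated and the MPS helpers are passed in as parameters rather than imported, so this is not the real module wiring:

    import torch

    def torch_gc(has_mps, torch_mps_gc):
        # has_mps: callable reporting whether the MPS backend is usable
        # torch_mps_gc: the helper added in modules/mac_specific.py below
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        # MPS cleanup is delegated to the platform helper, which decides on
        # its own whether emptying the cache is safe at this point in a run
        if has_mps():
            torch_mps_gc()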
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 51941c11..79670b87 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -3,6 +3,7 @@ import glob
import html
import os
import inspect
+from contextlib import closing
import modules.textual_inversion.dataset
import torch
@@ -711,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
+ with closing(p):
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
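The contextlib.closing pattern used here (and again in the img2img, textual-inversion and txt2img hunks below), reduced to a self-contained example; the Job class is a hypothetical stand-in for the processing object p:

    from contextlib import closing

    class Job:
        def close(self):
            # in the real code this is where per-run resources get released
            print("job closed")

    with closing(Job()) as p:
        pass  # processing.process_images(p) runs here in the real hunks
    # p.close() has now been called, even if processing raised an exception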
diff --git a/modules/images.py b/modules/images.py
index b5412548..4bdedb7f 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -306,12 +306,14 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ if fill_height > 0:
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+ if fill_width > 0:
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
return res
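Illustrative arithmetic for the new guard (values chosen for illustration): when the resized image already covers the target on that axis, the fill strip has zero height or width, the resize/paste of an empty strip is not meaningful, and it is now skipped.

    height, src_h = 512, 512                # hypothetical target and resized-source heights
    fill_height = height // 2 - src_h // 2  # 0: the resized image already covers the full height
    if fill_height > 0:
        pass                                # only then are the top/bottom fill strips built and pasted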
diff --git a/modules/img2img.py b/modules/img2img.py
index ef87eb0f..664e2688 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,4 +1,5 @@
import os
+from contextlib import closing
from pathlib import Path
import numpy as np
@@ -9,6 +10,7 @@ from modules import sd_samplers, images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
+from modules.images import save_image
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
@@ -112,18 +114,18 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
proc = process_images(p)
for n, processed_image in enumerate(proc.images):
- filename = image_path.name
+ filename = image_path.stem
+ infotext = proc.infotext(p, n)
relpath = os.path.dirname(os.path.relpath(image, input_dir))
if n > 0:
- left, right = os.path.splitext(filename)
- filename = f"{left}-{n}{right}"
+ filename += f"-{n}"
if not save_normally:
os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
if processed_image.mode == 'RGBA':
processed_image = processed_image.convert("RGB")
- processed_image.save(os.path.join(output_dir, relpath, filename))
+ save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
@@ -217,18 +219,17 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if mask:
p.extra_generation_params["Mask blur"] = mask_blur
- if is_batch:
- assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+ with closing(p):
+ if is_batch:
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
- processed = Processed(p, [], p.seed, "")
- else:
- processed = modules.scripts.scripts_img2img.run(p, *args)
- if processed is None:
- processed = process_images(p)
-
- p.close()
+ processed = Processed(p, [], p.seed, "")
+ else:
+ processed = modules.scripts.scripts_img2img.run(p, *args)
+ if processed is None:
+ processed = process_images(p)
shared.total_tqdm.clear()
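A small sketch of the batch-output filename handling this hunk switches to; the path and index are hypothetical, and the output extension is now left to save_image (driven by opts.samples_format) rather than reused from the input file:

    from pathlib import Path

    image_path = Path("/inputs/pets/cat.png")  # hypothetical batch input
    n = 2                                      # index of this image in the batch output

    filename = image_path.stem                 # "cat" -- input extension dropped
    if n > 0:
        filename += f"-{n}"                    # "cat-2"
    print(filename)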
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 735847f5..9ceb43ba 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,8 +1,12 @@
+import logging
+
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
+log = logging.getLogger(__name__)
+
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
@@ -19,9 +23,23 @@ def check_for_mps() -> bool:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
has_mps = check_for_mps()
+def torch_mps_gc() -> None:
+ try:
+ from modules.shared import state
+ if state.current_latent is not None:
+ log.debug("`current_latent` is set, skipping MPS garbage collection")
+ return
+ from torch.mps import empty_cache
+ empty_cache()
+ except Exception:
+ log.warning("MPS garbage collection failed", exc_info=True)
+
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
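To see when the new helper skips collection, the module's logger can be turned up to DEBUG; a minimal sketch, assuming no logging handlers have been configured yet:

    import logging

    logging.basicConfig()  # make sure at least one handler exists
    logging.getLogger("modules.mac_specific").setLevel(logging.DEBUG)
    # the "`current_latent` is set, skipping MPS garbage collection" message
    # will now be printed whenever torch_mps_gc() bails out mid-generation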
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index bb6f211c..cbe975b7 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,5 +1,6 @@
import os
from collections import namedtuple
+from contextlib import closing
import torch
import tqdm
@@ -584,8 +585,9 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
preview_text = p.prompt
- processed = processing.process_images(p)
- image = processed.images[0] if len(processed.images) > 0 else None
+ with closing(p):
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 6aa79f23..d0be2e73 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,3 +1,5 @@
+from contextlib import closing
+
import modules.scripts
from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
@@ -53,12 +55,11 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
- processed = modules.scripts.scripts_txt2img.run(p, *args)
-
- if processed is None:
- processed = processing.process_images(p)
+ with closing(p):
+ processed = modules.scripts.scripts_txt2img.run(p, *args)
- p.close()
+ if processed is None:
+ processed = processing.process_images(p)
shared.total_tqdm.clear()
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 0c560b30..a6076bf3 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -260,13 +260,20 @@ class UiSettings:
component = self.component_dict[k]
info = opts.data_labels[k]
- change_handler = component.release if hasattr(component, 'release') else component.change
- change_handler(
- fn=lambda value, k=k: self.run_settings_single(value, key=k),
- inputs=[component],
- outputs=[component, self.text_settings],
- show_progress=info.refresh is not None,
- )
+ if isinstance(component, gr.Textbox):
+ methods = [component.submit, component.blur]
+ elif hasattr(component, 'release'):
+ methods = [component.release]
+ else:
+ methods = [component.change]
+
+ for method in methods:
+ method(
+ fn=lambda value, k=k: self.run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, self.text_settings],
+ show_progress=info.refresh is not None,
+ )
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(