aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--javascript/edit-attention.js78
-rw-r--r--modules/api/api.py11
-rw-r--r--modules/initialize_util.py6
-rw-r--r--modules/sd_models.py13
-rw-r--r--modules/shared_options.py5
-rw-r--r--modules/ui_settings.py24
6 files changed, 88 insertions, 49 deletions
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index 794453bf..45d9a788 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -26,6 +26,7 @@ function keyupEditAttention(event) {
// Set the selection to the text between the parenthesis
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
+ if (!/.*:-?[\d.]+/s.test(parenContent)) return false;
const lastColon = parenContent.lastIndexOf(":");
selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon;
@@ -66,40 +67,75 @@ function keyupEditAttention(event) {
var closeCharacter = ')';
var delta = opts.keyedit_precision_attention;
- if (selectionStart > 0 && text[selectionStart - 1] == '<') {
+ if (selectionStart > 0 && /<.*:-?[\d.]+>/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(">") + 1))) {
closeCharacter = '>';
delta = opts.keyedit_precision_extra;
- } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
-
+ } else if (selectionStart > 0 && /\(.*\)|\[.*\]/s.test(text.slice(selectionStart - 1, selectionEnd + 1))) {
+ let start = text[selectionStart - 1];
+ let end = text[selectionEnd];
+ if (opts.keyedit_convert) {
+ let numParen = 0;
+
+ while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) {
+ numParen++;
+ }
+
+ if (start == "(") {
+ weight = 1.1 ** numParen;
+ } else {
+ weight = (1 / 1.1) ** numParen;
+ }
+
+ weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention;
+
+ text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen);
+ selectionStart -= numParen - 1;
+ selectionEnd -= numParen - 1;
+ } else {
+ closeCharacter = null;
+ if (isPlus) {
+ text = text.slice(0, selectionStart) + start + text.slice(selectionStart, selectionEnd) + end + text.slice(selectionEnd);
+ selectionStart++;
+ selectionEnd++;
+ } else {
+ text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 1);
+ selectionStart--;
+ selectionEnd--;
+ }
+ }
+ } else if (selectionStart == 0 || !/\(.*:-?[\d.]+\)/s.test(text.slice(selectionStart - 1, selectionEnd + text.slice(selectionEnd).indexOf(")") + 1))) {
// do not include spaces at the end
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
- selectionEnd -= 1;
+ selectionEnd--;
}
+
if (selectionStart == selectionEnd) {
return;
}
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
- selectionStart += 1;
- selectionEnd += 1;
+ selectionStart++;
+ selectionEnd++;
}
- var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
- var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end));
- if (isNaN(weight)) return;
-
- weight += isPlus ? delta : -delta;
- weight = parseFloat(weight.toPrecision(12));
- if (String(weight).length == 1) weight += ".0";
-
- if (closeCharacter == ')' && weight == 1) {
- var endParenPos = text.substring(selectionEnd).indexOf(')');
- text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
- selectionStart--;
- selectionEnd--;
- } else {
- text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
+ if (closeCharacter) {
+ var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+ var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + end));
+ if (isNaN(weight)) return;
+
+ weight += isPlus ? delta : -delta;
+ weight = parseFloat(weight.toPrecision(12));
+ if (Number.isInteger(weight)) weight += ".0";
+
+ if (closeCharacter == ')' && weight == 1) {
+ var endParenPos = text.substring(selectionEnd).indexOf(')');
+ text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
+ selectionStart--;
+ selectionEnd--;
+ } else {
+ text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
+ }
}
target.focus();
diff --git a/modules/api/api.py b/modules/api/api.py
index efedafa4..09083874 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin,Image
-from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
+from PIL import PngImagePlugin, Image
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -541,12 +540,12 @@ class Api:
return {}
def unloadapi(self):
- unload_model_weights()
+ sd_models.unload_model_weights()
return {}
def reloadapi(self):
- reload_model_weights()
+ sd_models.send_model_to_device(shared.sd_model)
return {}
@@ -566,7 +565,7 @@ class Api:
def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
- if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+ if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items():
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2894eee4..2e9b6d89 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -150,10 +150,14 @@ def dumpstacks():
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
+
+ from modules import shared
+
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
- dumpstacks()
+ if shared.opts.dump_stacks_on_signal:
+ dumpstacks()
os._exit(0)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index c8efeedc..3b6cdea1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,7 +1,6 @@
import collections
import os.path
import sys
-import gc
import threading
import torch
@@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- timer = Timer()
-
- if model_data.sd_model:
- model_data.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
- model_data.sd_model = None
- sd_model = None
- gc.collect()
- devices.torch_gc()
-
- print(f"Unloaded weights {timer.summary()}.")
+ send_model_to_cpu(sd_model or shared.sd_model)
return sd_model
diff --git a/modules/shared_options.py b/modules/shared_options.py
index ce395302..32bf7353 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -112,6 +112,7 @@ options_templates.update(options_section(('system', "System"), {
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
+ "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
}))
options_templates.update(options_section(('API', "API"), {
@@ -258,8 +259,9 @@ options_templates.update(options_section(('ui', "User interface"), {
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
- "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"),
+ "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~()[]<>| ", "Ctrl+up/down word delimiters"),
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
+ "keyedit_convert": OptionInfo(True, "Convert (attention) to (attention:1.1)"),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
@@ -334,4 +336,3 @@ options_templates.update(options_section((None, "Hidden options"), {
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
-
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 74a3aef3..e054d00a 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -1,6 +1,6 @@
import gradio as gr
-from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
from modules.call_queue import wrap_gradio_call
from modules.shared import opts
from modules.ui_components import FormRow
@@ -177,8 +177,8 @@ class UiSettings:
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
with gr.Row():
- unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
- reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
with gr.Row():
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
@@ -194,16 +194,26 @@ class UiSettings:
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+ def call_func_and_return_text(func, text):
+ def handler():
+ t = timer.Timer()
+ func()
+ t.record(text)
+
+ return f'{text} in {t.total:.1f}s'
+
+ return handler
+
unload_sd_model.click(
- fn=sd_models.unload_model_weights,
+ fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
inputs=[],
- outputs=[]
+ outputs=[self.result]
)
reload_sd_model.click(
- fn=sd_models.reload_model_weights,
+ fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
inputs=[],
- outputs=[]
+ outputs=[self.result]
)
request_notifications.click(