diff options
-rw-r--r-- | .eslintrc.js | 2 | ||||
-rw-r--r-- | javascript/dragdrop.js | 53 | ||||
-rw-r--r-- | javascript/imageParams.js | 18 | ||||
-rw-r--r-- | javascript/token-counters.js | 83 | ||||
-rw-r--r-- | javascript/ui.js | 71 | ||||
-rw-r--r-- | modules/api/api.py | 2 | ||||
-rw-r--r-- | modules/devices.py | 18 | ||||
-rw-r--r-- | modules/sd_models.py | 5 | ||||
-rw-r--r-- | modules/sd_samplers_kdiffusion.py | 10 | ||||
-rw-r--r-- | modules/shared.py | 2 | ||||
-rw-r--r-- | modules/ui.py | 6 | ||||
-rw-r--r-- | webui.py | 4 |
12 files changed, 170 insertions, 104 deletions
diff --git a/.eslintrc.js b/.eslintrc.js index 944cc869..218f5609 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -84,5 +84,7 @@ module.exports = { // imageviewer.js modalPrevImage: "readonly", modalNextImage: "readonly", + // token-counters.js + setupTokenCounters: "readonly", } }; diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 77a24a07..5803daea 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -48,12 +48,27 @@ function dropReplaceImage(imgWrap, files) { } } +function eventHasFiles(e) { + if (!e.dataTransfer || !e.dataTransfer.files) return false; + if (e.dataTransfer.files.length > 0) return true; + if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true; + + return false; +} + +function dragDropTargetIsPrompt(target) { + if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true; + if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true; + return false; +} + window.document.addEventListener('dragover', e => { const target = e.composedPath()[0]; - const imgWrap = target.closest('[data-testid="image"]'); - if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { - return; - } + if (!eventHasFiles(e)) return; + + var targetImage = target.closest('[data-testid="image"]'); + if (!dragDropTargetIsPrompt(target) && !targetImage) return; + e.stopPropagation(); e.preventDefault(); e.dataTransfer.dropEffect = 'copy'; @@ -61,17 +76,31 @@ window.document.addEventListener('dragover', e => { window.document.addEventListener('drop', e => { const target = e.composedPath()[0]; - if (target.placeholder.indexOf("Prompt") == -1) { - return; + if (!eventHasFiles(e)) return; + + if (dragDropTargetIsPrompt(target)) { + e.stopPropagation(); + e.preventDefault(); + + let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; + + const imgParent = gradioApp().getElementById(prompt_target); + const files = e.dataTransfer.files; + const fileInput = imgParent.querySelector('input[type="file"]'); + if (fileInput) { + fileInput.files = files; + fileInput.dispatchEvent(new Event('change')); + } } - const imgWrap = target.closest('[data-testid="image"]'); - if (!imgWrap) { + + var targetImage = target.closest('[data-testid="image"]'); + if (targetImage) { + e.stopPropagation(); + e.preventDefault(); + const files = e.dataTransfer.files; + dropReplaceImage(targetImage, files); return; } - e.stopPropagation(); - e.preventDefault(); - const files = e.dataTransfer.files; - dropReplaceImage(imgWrap, files); }); window.addEventListener('paste', e => { diff --git a/javascript/imageParams.js b/javascript/imageParams.js deleted file mode 100644 index 057e2d39..00000000 --- a/javascript/imageParams.js +++ /dev/null @@ -1,18 +0,0 @@ -window.onload = (function() { - window.addEventListener('drop', e => { - const target = e.composedPath()[0]; - if (target.placeholder.indexOf("Prompt") == -1) return; - - let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; - - e.stopPropagation(); - e.preventDefault(); - const imgParent = gradioApp().getElementById(prompt_target); - const files = e.dataTransfer.files; - const fileInput = imgParent.querySelector('input[type="file"]'); - if (fileInput) { - fileInput.files = files; - fileInput.dispatchEvent(new Event('change')); - } - }); -}); diff --git a/javascript/token-counters.js b/javascript/token-counters.js new file mode 100644 index 00000000..9d81a723 --- /dev/null +++ b/javascript/token-counters.js @@ -0,0 +1,83 @@ +let promptTokenCountDebounceTime = 800; +let promptTokenCountTimeouts = {}; +var promptTokenCountUpdateFunctions = {}; + +function update_txt2img_tokens(...args) { + // Called from Gradio + update_token_counter("txt2img_token_button"); + if (args.length == 2) { + return args[0]; + } + return args; +} + +function update_img2img_tokens(...args) { + // Called from Gradio + update_token_counter("img2img_token_button"); + if (args.length == 2) { + return args[0]; + } + return args; +} + +function update_token_counter(button_id) { + if (opts.disable_token_counters) { + return; + } + if (promptTokenCountTimeouts[button_id]) { + clearTimeout(promptTokenCountTimeouts[button_id]); + } + promptTokenCountTimeouts[button_id] = setTimeout( + () => gradioApp().getElementById(button_id)?.click(), + promptTokenCountDebounceTime, + ); +} + + +function recalculatePromptTokens(name) { + promptTokenCountUpdateFunctions[name]?.(); +} + +function recalculate_prompts_txt2img() { + // Called from Gradio + recalculatePromptTokens('txt2img_prompt'); + recalculatePromptTokens('txt2img_neg_prompt'); + return Array.from(arguments); +} + +function recalculate_prompts_img2img() { + // Called from Gradio + recalculatePromptTokens('img2img_prompt'); + recalculatePromptTokens('img2img_neg_prompt'); + return Array.from(arguments); +} + +function setupTokenCounting(id, id_counter, id_button) { + var prompt = gradioApp().getElementById(id); + var counter = gradioApp().getElementById(id_counter); + var textarea = gradioApp().querySelector(`#${id} > label > textarea`); + + if (opts.disable_token_counters) { + counter.style.display = "none"; + return; + } + + if (counter.parentElement == prompt.parentElement) { + return; + } + + prompt.parentElement.insertBefore(counter, prompt); + prompt.parentElement.style.position = "relative"; + + promptTokenCountUpdateFunctions[id] = function() { + update_token_counter(id_button); + }; + textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]); +} + +function setupTokenCounters() { + setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button'); + setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button'); + setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button'); + setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button'); +} diff --git a/javascript/ui.js b/javascript/ui.js index 648a5290..800a2ae6 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -248,27 +248,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { } -var promptTokecountUpdateFuncs = {}; - -function recalculatePromptTokens(name) { - if (promptTokecountUpdateFuncs[name]) { - promptTokecountUpdateFuncs[name](); - } -} - -function recalculate_prompts_txt2img() { - recalculatePromptTokens('txt2img_prompt'); - recalculatePromptTokens('txt2img_neg_prompt'); - return Array.from(arguments); -} - -function recalculate_prompts_img2img() { - recalculatePromptTokens('img2img_prompt'); - recalculatePromptTokens('img2img_neg_prompt'); - return Array.from(arguments); -} - - var opts = {}; onUiUpdate(function() { if (Object.keys(opts).length != 0) return; @@ -302,28 +281,7 @@ onUiUpdate(function() { json_elem.parentElement.style.display = "none"; - function registerTextarea(id, id_counter, id_button) { - var prompt = gradioApp().getElementById(id); - var counter = gradioApp().getElementById(id_counter); - var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - - if (counter.parentElement == prompt.parentElement) { - return; - } - - prompt.parentElement.insertBefore(counter, prompt); - prompt.parentElement.style.position = "relative"; - - promptTokecountUpdateFuncs[id] = function() { - update_token_counter(id_button); - }; - textarea.addEventListener("input", promptTokecountUpdateFuncs[id]); - } - - registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button'); - registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button'); - registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button'); - registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button'); + setupTokenCounters(); var show_all_pages = gradioApp().getElementById('settings_show_all_pages'); var settings_tabs = gradioApp().querySelector('#settings div'); @@ -354,33 +312,6 @@ onOptionsChanged(function() { }); let txt2img_textarea, img2img_textarea = undefined; -let wait_time = 800; -let token_timeouts = {}; - -function update_txt2img_tokens(...args) { - update_token_counter("txt2img_token_button"); - if (args.length == 2) { - return args[0]; - } - return args; -} - -function update_img2img_tokens(...args) { - update_token_counter( - "img2img_token_button" - ); - if (args.length == 2) { - return args[0]; - } - return args; -} - -function update_token_counter(button_id) { - if (token_timeouts[button_id]) { - clearTimeout(token_timeouts[button_id]); - } - token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); -} function restart_reload() { document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; diff --git a/modules/api/api.py b/modules/api/api.py index eee99bbb..6a456861 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -700,4 +700,4 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port) + uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) diff --git a/modules/devices.py b/modules/devices.py index d8a34a0f..1ed6ffdc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,5 +1,7 @@ import sys import contextlib +from functools import lru_cache + import torch from modules import errors @@ -154,3 +156,19 @@ def test_for_nans(x, where): message += " Use --disable-nan-check commandline argument to disable this check." raise NansException(message) + + +@lru_cache +def first_time_calculation(): + """ + just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and + spends about 2.7 seconds doing that, at least wih NVidia. + """ + + x = torch.zeros((1, 1)).to(device, dtype) + linear = torch.nn.Linear(1, 1).to(device, dtype) + linear(x) + + x = torch.zeros((1, 1, 3, 3)).to(device, dtype) + conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) + conv2d(x) diff --git a/modules/sd_models.py b/modules/sd_models.py index b1afbaa7..91b3eb11 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -508,6 +508,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks")
+ with devices.autocast(), torch.no_grad():
+ sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+
+ timer.record("calculate empty prompt")
+
print(f"Model loaded in {timer.summary()}.")
return sd_model
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 59982fc9..638e0ac9 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -125,6 +125,16 @@ class CFGDenoiser(torch.nn.Module): x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
+ # TODO add infotext entry
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+ elif num_repeats > 0:
+ uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
diff --git a/modules/shared.py b/modules/shared.py index 3099d1d2..0897f937 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -423,6 +423,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -487,6 +488,7 @@ options_templates.update(options_section(('ui', "User interface"), { "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order").needs_restart(),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
+ "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
}))
options_templates.update(options_section(('infotext', "Infotext"), {
diff --git a/modules/ui.py b/modules/ui.py index e62182da..001b9792 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -616,7 +616,8 @@ def create_ui(): outputs=[
txt2img_prompt,
txt_prompt_img
- ]
+ ],
+ show_progress=False,
)
enable_hr.change(
@@ -902,7 +903,8 @@ def create_ui(): outputs=[
img2img_prompt,
img2img_prompt_img
- ]
+ ],
+ show_progress=False,
)
img2img_args = dict(
@@ -20,7 +20,7 @@ import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors # noqa: F401
+from modules import paths, timer, import_hook, errors, devices # noqa: F401
startup_timer = timer.Timer()
@@ -295,6 +295,8 @@ def initialize_rest(*, reload_script_modules=False): # (when reloading, this does nothing)
Thread(target=lambda: shared.sd_model).start()
+ Thread(target=devices.first_time_calculation).start()
+
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
|