From 59a2b9e5afc27d2fda72069ca0635070535d18fe Mon Sep 17 00:00:00 2001 From: Greendayle Date: Wed, 5 Oct 2022 20:50:10 +0200 Subject: deepdanbooru interrogator --- ... your deepbooru release project folder here.txt | 0 modules/deepbooru.py | 60 ++++++++++++++++++++++ modules/ui.py | 24 +++++++-- requirements.txt | 3 ++ requirements_versions.txt | 3 ++ style.css | 7 ++- 6 files changed, 91 insertions(+), 6 deletions(-) create mode 100644 models/deepbooru/Put your deepbooru release project folder here.txt create mode 100644 modules/deepbooru.py diff --git a/models/deepbooru/Put your deepbooru release project folder here.txt b/models/deepbooru/Put your deepbooru release project folder here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/deepbooru.py b/modules/deepbooru.py new file mode 100644 index 00000000..958b1c3d --- /dev/null +++ b/modules/deepbooru.py @@ -0,0 +1,60 @@ +import os.path +from concurrent.futures import ProcessPoolExecutor + +import numpy as np +import deepdanbooru as dd +import tensorflow as tf + + +def _load_tf_and_return_tags(pil_image, threshold): + this_folder = os.path.dirname(__file__) + model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') + if not os.path.exists(model_path): + return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru" + + tags = dd.project.load_tags_from_project(model_path) + model = dd.project.load_model_from_project( + model_path, compile_model=True + ) + + width = model.input_shape[2] + height = model.input_shape[1] + image = np.array(pil_image) + image = tf.image.resize( + image, + size=(height, width), + method=tf.image.ResizeMethod.AREA, + preserve_aspect_ratio=True, + ) + image = image.numpy() # EagerTensor to np.array + image = dd.image.transform_and_pad_image(image, width, height) + image = image / 255.0 + image_shape = image.shape + image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2])) + + y = model.predict(image)[0] + + result_dict = {} + + for i, tag in enumerate(tags): + result_dict[tag] = y[i] + + + + result_tags_out = [] + result_tags_print = [] + for tag in tags: + if result_dict[tag] >= threshold: + result_tags_out.append(tag) + result_tags_print.append(f'{result_dict[tag]} {tag}') + + print('\n'.join(sorted(result_tags_print, reverse=True))) + + return ', '.join(result_tags_out) + + +def get_deepbooru_tags(pil_image, threshold=0.5): + with ProcessPoolExecutor() as executor: + f = executor.submit(_load_tf_and_return_tags, pil_image, threshold) + ret = f.result() # will rethrow any exceptions + return ret \ No newline at end of file diff --git a/modules/ui.py b/modules/ui.py index 20dc8c37..ae98219a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,6 +23,7 @@ import gradio.utils import gradio.routes from modules import sd_hijack +from modules.deepbooru import get_deepbooru_tags from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared @@ -312,6 +313,11 @@ def interrogate(image): return gr_show(True) if prompt is None else prompt +def interrogate_deepbooru(image): + prompt = get_deepbooru_tags(image) + return gr_show(True) if prompt is None else prompt + + def create_seed_inputs(): with gr.Row(): with gr.Box(): @@ -439,15 +445,17 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(): + with gr.Row(scale=1): if is_img2img: - interrogate = gr.Button('Interrogate', elem_id="interrogate") + interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") else: interrogate = None + deepbooru = None prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") save_style = gr.Button('Create style', elem_id="style_create") - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -476,7 +484,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.txt2img with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) with gr.Row(elem_id='txt2img_progress_row'): @@ -628,7 +636,7 @@ def create_ui(wrap_gradio_gpu_call): token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): with gr.Column(scale=1): @@ -785,6 +793,12 @@ def create_ui(wrap_gradio_gpu_call): outputs=[img2img_prompt], ) + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + save.click( fn=wrap_gradio_call(save_files), _js="(x, y, z) => [x, y, selected_gallery_index()]", diff --git a/requirements.txt b/requirements.txt index 631fe616..cab101f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,6 @@ resize-right torchdiffeq kornia lark +deepdanbooru +tensorflow +tensorflow-io diff --git a/requirements_versions.txt b/requirements_versions.txt index fdff2687..811953c6 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -22,3 +22,6 @@ resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 +git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] +tensorflow==2.10.0 +tensorflow-io==0.27.0 diff --git a/style.css b/style.css index 39586bf1..2fd351f9 100644 --- a/style.css +++ b/style.css @@ -103,7 +103,12 @@ #style_apply, #style_create, #interrogate{ margin: 0.75em 0.25em 0.25em 0.25em; - min-width: 3em; + min-width: 5em; +} + +#style_apply, #style_create, #deepbooru{ + margin: 0.75em 0.25em 0.25em 0.25em; + min-width: 5em; } #style_pos_col, #style_neg_col{ -- cgit v1.2.1 From 1506fab29ad54beb9f52236912abc432209c8089 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Wed, 5 Oct 2022 21:15:08 +0200 Subject: removing problematic tag --- modules/deepbooru.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 958b1c3d..841cb9c5 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -38,13 +38,12 @@ def _load_tf_and_return_tags(pil_image, threshold): for i, tag in enumerate(tags): result_dict[tag] = y[i] - - - result_tags_out = [] result_tags_print = [] for tag in tags: if result_dict[tag] >= threshold: + if tag.startswith("rating:"): + continue result_tags_out.append(tag) result_tags_print.append(f'{result_dict[tag]} {tag}') -- cgit v1.2.1 From 17a99baf0c929e5df4dfc4b2a96aa3890a141112 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Wed, 5 Oct 2022 22:05:24 +0200 Subject: better model search --- modules/deepbooru.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 841cb9c5..a64fd9cd 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -9,8 +9,15 @@ import tensorflow as tf def _load_tf_and_return_tags(pil_image, threshold): this_folder = os.path.dirname(__file__) model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') - if not os.path.exists(model_path): - return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru" + + model_good = False + for path_candidate in [model_path, os.path.dirname(model_path)]: + if os.path.exists(os.path.join(path_candidate, 'project.json')): + model_path = path_candidate + model_good = True + if not model_good: + return ("Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/" + "deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru") tags = dd.project.load_tags_from_project(model_path) model = dd.project.load_model_from_project( -- cgit v1.2.1 From 4320f386d9641c7c234589c4cb0c0c6cbeb156ad Mon Sep 17 00:00:00 2001 From: Greendayle Date: Wed, 5 Oct 2022 22:39:32 +0200 Subject: removing underscores and colons --- modules/deepbooru.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index a64fd9cd..fb5018a6 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -56,7 +56,7 @@ def _load_tf_and_return_tags(pil_image, threshold): print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out) + return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') def get_deepbooru_tags(pil_image, threshold=0.5): -- cgit v1.2.1 From f174fb29228a04955fb951b32b0bab79e33ec2b8 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:21:49 +0300 Subject: add xformers attention --- modules/sd_hijack_optimizations.py | 39 +++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..da1b76e1 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,9 @@ import math import torch from torch import einsum - +import xformers.ops +import functorch +xformers._is_functorch_available=True from ldm.util import default from einops import rearrange @@ -92,6 +94,41 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) +def _maybe_init(self, x): + """ + Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x + : B, Head, Length + """ + if self.attention_op is not None: + return + _, M, K = x.shape + try: + self.attention_op = xformers.ops.AttentionOpDispatch( + dtype=x.dtype, + device=x.device, + k=K, + attn_bias_type=type(None), + has_dropout=False, + kv_len=M, + q_len=M, + ).op + except NotImplementedError as err: + raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") + +def xformers_attention_forward(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) + v_in = self.to_v(context) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + self._maybe_init(q) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) -- cgit v1.2.1 From 2eb911b056ce6ff4434f673366782ed34f2b2f12 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:22:28 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a6fa890c..6221ed5a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -20,12 +20,17 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): - ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + if cmd_opts.opt_split_attention: + ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + elif not cmd_opts.disable_opt_xformers_attention: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init + ldm.modules.attention.CrossAttention.attention_op = None + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.1 From da4ab2707b4cb0611cf181ba248a271d1937433e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:23:06 +0300 Subject: Update shared.py --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..8cc3b2fe 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -43,6 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) +parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.1 From cd8bb597c6bcb6c59b538b7a1ab8f2face764fc5 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:23:25 +0300 Subject: Update requirements.txt --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 631fe616..304a066a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,5 @@ resize-right torchdiffeq kornia lark +functorch +#xformers? -- cgit v1.2.1 From 35d6b231628d18d53d166c3a92fea1523e88d51e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:31:53 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6221ed5a..a006c0a3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -20,17 +20,16 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): + ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 if cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward elif not cmd_opts.disable_opt_xformers_attention: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None - ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.1 From 5303df24282ba06abb34a423f2967354d37d078e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 06:01:14 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a006c0a3..ddacb0ad 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -23,10 +23,10 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - if cmd_opts.opt_split_attention: + elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention: + elif not cmd_opts.disable_opt_xformers_attention and not cmd_opts.opt_split_attention: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None -- cgit v1.2.1 From 5e3ff846c56dc8e1d5c76ea04a8f2f74d7da07fc Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 06:38:01 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ddacb0ad..cbdb9d3c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -26,7 +26,7 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention and not cmd_opts.opt_split_attention: + elif not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None -- cgit v1.2.1 From bad7cb29cecac51c5c0f39afec332b007ed73133 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 10:17:52 +0300 Subject: added support for hypernetworks (???) --- modules/hypernetwork.py | 55 ++++++++++++++++++++++++++++++++++++++ modules/sd_hijack_optimizations.py | 17 ++++++++++-- modules/shared.py | 9 ++++++- scripts/xy_grid.py | 10 +++++++ 4 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 modules/hypernetwork.py diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py new file mode 100644 index 00000000..9ed1eed9 --- /dev/null +++ b/modules/hypernetwork.py @@ -0,0 +1,55 @@ +import glob +import os +import torch +from modules import devices + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + self.load_state_dict(state_dict, strict=True) + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, filename): + self.filename = filename + self.name = os.path.splitext(os.path.basename(filename))[0] + self.layers = {} + + state_dict = torch.load(filename, map_location='cpu') + for size, sd in state_dict.items(): + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + +def load_hypernetworks(path): + res = {} + + for filename in glob.iglob(path + '**/*.pt', recursive=True): + hn = Hypernetwork(filename) + res[hn.name] = hn + + return res + +def apply(self, x, context=None, mask=None, original=None): + + + if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: + if context.shape[1] == 77 and CrossAttention.noise_cond: + context = context + (torch.randn_like(context) * 0.1) + h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] + k = self.to_k(h_k(context)) + v = self.to_v(h_v(context)) + else: + k = self.to_k(context) + v = self.to_v(context) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..d9cca485 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,8 @@ from torch import einsum from ldm.util import default from einops import rearrange +from modules import shared + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): @@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) + + k_in *= self.scale + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..879d8424 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -76,6 +76,12 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file +hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) + + +def selected_hypernetwork(): + return hypernetworks.get(opts.sd_hypernetwork, None) + class State: interrupted = False @@ -206,6 +212,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6344e612..c0c364df 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -77,6 +77,11 @@ def apply_checkpoint(p, x, xs): modules.sd_models.reload_model_weights(shared.sd_model, info) +def apply_hypernetwork(p, x, xs): + hn = shared.hypernetworks.get(x, None) + opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -122,6 +127,7 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), @@ -193,6 +199,8 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 + initial_hn = opts.sd_hypernetwork + def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -300,4 +308,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) + opts.data["sd_hypernetwork"] = initial_hn + return processed -- cgit v1.2.1 From d15b3ec0013c10f02f0fb80e8448bac8872a151f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 10:40:22 +0300 Subject: support loading VAE --- modules/sd_models.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 5f992064..8f794b47 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -134,6 +134,14 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location="cpu") + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + + model.first_stage_model.load_state_dict(vae_dict) + model.sd_model_hash = sd_model_hash model.sd_model_checkpint = checkpoint_file -- cgit v1.2.1 From 97bc0b9504572d2df80598d0b694703bcd626de6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 13:22:50 +0300 Subject: do not stop working on failed hypernetwork load --- modules/hypernetwork.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 9ed1eed9..c5cf4afa 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -1,5 +1,8 @@ import glob import os +import sys +import traceback + import torch from modules import devices @@ -36,8 +39,12 @@ def load_hypernetworks(path): res = {} for filename in glob.iglob(path + '**/*.pt', recursive=True): - hn = Hypernetwork(filename) - res[hn.name] = hn + try: + hn = Hypernetwork(filename) + res[hn.name] = hn + except Exception: + print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) return res -- cgit v1.2.1 From f7c787eb7c295c27439f4fbdf78c26b8389560be Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 16:39:51 +0300 Subject: make it possible to use hypernetworks without opt split attention --- modules/hypernetwork.py | 42 ++++++++++++++++++++++++++++++++++-------- modules/sd_hijack.py | 6 ++++-- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index c5cf4afa..c7b86682 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -4,7 +4,12 @@ import sys import traceback import torch -from modules import devices + +from ldm.util import default +from modules import devices, shared +import torch +from torch import einsum +from einops import rearrange, repeat class HypernetworkModule(torch.nn.Module): @@ -48,15 +53,36 @@ def load_hypernetworks(path): return res -def apply(self, x, context=None, mask=None, original=None): +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) - if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: - if context.shape[1] == 77 and CrossAttention.noise_cond: - context = context + (torch.randn_like(context) * 0.1) - h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] - k = self.to_k(h_k(context)) - v = self.to_v(h_v(context)) + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k = self.to_k(hypernetwork_layers[0](context)) + v = self.to_v(hypernetwork_layers[1](context)) else: k = self.to_k(context) v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a6fa890c..d68f89cc 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared +from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork from modules.shared import opts, device, cmd_opts import ldm.modules.attention @@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): + undo_optimizations() + ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: @@ -30,7 +32,7 @@ def apply_optimizations(): def undo_optimizations(): - ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -- cgit v1.2.1 From 54fa613c8391e3973cca9d94cdf539061932508b Mon Sep 17 00:00:00 2001 From: Greendayle Date: Fri, 7 Oct 2022 20:37:43 +0200 Subject: loading tf only in interrogation process --- modules/deepbooru.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index fb5018a6..79dc59bd 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,12 +1,13 @@ import os.path from concurrent.futures import ProcessPoolExecutor -import numpy as np -import deepdanbooru as dd -import tensorflow as tf def _load_tf_and_return_tags(pil_image, threshold): + import deepdanbooru as dd + import tensorflow as tf + import numpy as np + this_folder = os.path.dirname(__file__) model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') -- cgit v1.2.1 From fa2ea648db81f5723bb5d722f2fe0ebd7dfc319a Mon Sep 17 00:00:00 2001 From: Greendayle Date: Fri, 7 Oct 2022 20:46:38 +0200 Subject: even more powerfull fix --- modules/deepbooru.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 79dc59bd..60094336 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -60,8 +60,13 @@ def _load_tf_and_return_tags(pil_image, threshold): return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') +def subprocess_init_no_cuda(): + import os + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + def get_deepbooru_tags(pil_image, threshold=0.5): - with ProcessPoolExecutor() as executor: - f = executor.submit(_load_tf_and_return_tags, pil_image, threshold) + with ProcessPoolExecutor(initializer=subprocess_init_no_cuda) as executor: + f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) ret = f.result() # will rethrow any exceptions return ret \ No newline at end of file -- cgit v1.2.1 From 5f12e7efd92ad802742f96788b4be3249ad02829 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Fri, 7 Oct 2022 20:58:30 +0200 Subject: linux test --- modules/deepbooru.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 60094336..781b2249 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -1,6 +1,6 @@ import os.path from concurrent.futures import ProcessPoolExecutor - +from multiprocessing import get_context def _load_tf_and_return_tags(pil_image, threshold): @@ -66,7 +66,8 @@ def subprocess_init_no_cuda(): def get_deepbooru_tags(pil_image, threshold=0.5): - with ProcessPoolExecutor(initializer=subprocess_init_no_cuda) as executor: + context = get_context('spawn') + with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor: f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) ret = f.result() # will rethrow any exceptions return ret \ No newline at end of file -- cgit v1.2.1 From c9cc65b201679ea43c763b0d85e749d40bbc5433 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 04:09:18 +0300 Subject: switch to the proper way of calling xformers --- modules/sd_hijack_optimizations.py | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index da1b76e1..7fb4a45e 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -94,39 +94,17 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def _maybe_init(self, x): - """ - Initialize the attention operator, if required We expect the head dimension to be exposed here, meaning that x - : B, Head, Length - """ - if self.attention_op is not None: - return - _, M, K = x.shape - try: - self.attention_op = xformers.ops.AttentionOpDispatch( - dtype=x.dtype, - device=x.device, - k=K, - attn_bias_type=type(None), - has_dropout=False, - kv_len=M, - q_len=M, - ).op - except NotImplementedError as err: - raise NotImplementedError(f"Please install xformers with the flash attention / cutlass components.\n{err}") - def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) v_in = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - self._maybe_init(q) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) def cross_attention_attnblock_forward(self, x): -- cgit v1.2.1 From b70eaeb2005a5a9593119e7fd32b8072c2a208d5 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 04:10:35 +0300 Subject: delete broken and unnecessary aliases --- modules/sd_hijack.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cbdb9d3c..0e99c319 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,16 +21,14 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.opt_split_attention_v1: + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init - ldm.modules.attention.CrossAttention.attention_op = None - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward def undo_optimizations(): -- cgit v1.2.1 From a958f9b3fdea95c01d360aba1b6fe0ce3ea6b349 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Fri, 7 Oct 2022 20:05:47 -0300 Subject: edit-attention browser compatibility and readme typo --- README.md | 2 +- javascript/edit-attention.js | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a14a6330..0516c2cd 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - Attention, specify parts of text that the model should pay more attention to - a man in a ((tuxedo)) - will pay more attention to tuxedo - a man in a (tuxedo:1.21) - alternative syntax - - select text and press ctrl+up or ctrl+down to aduotmatically adjust attention to selected text + - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text - Loopback, run img2img processing multiple times - X/Y plot, a way to draw a 2 dimensional plot of images with different parameters - Textual Inversion diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index c67ed579..0280c603 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -1,5 +1,5 @@ addEventListener('keydown', (event) => { - let target = event.originalTarget; + let target = event.originalTarget || event.composedPath()[0]; if (!target.hasAttribute("placeholder")) return; if (!target.placeholder.toLowerCase().includes("prompt")) return; -- cgit v1.2.1 From f2055cb1d4ce45d7aaacc49d8ab5bec7791a8f47 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 8 Oct 2022 01:47:02 -0400 Subject: Add hypernetwork support to split cross attention v1 * Add hypernetwork support to split_cross_attention_forward_v1 * Fix device check in esrgan_model.py to use devices.device_esrgan instead of shared.device --- modules/esrgan_model.py | 2 +- modules/sd_hijack_optimizations.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index d17e730f..28548124 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -111,7 +111,7 @@ class UpscalerESRGAN(Upscaler): print("Unable to load %s from %s" % (self.model_path, filename)) return None - pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None) + pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) pretrained_net = fix_model_layers(crt_model, pretrained_net) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index d9cca485..3351c740 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -12,13 +12,22 @@ from modules import shared def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) for i in range(0, q.shape[0], 2): @@ -31,6 +40,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) del s2 + del q, k, v r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 -- cgit v1.2.1 From e21e4732531299ef4895baccdb7a6493a3886924 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 8 Oct 2022 05:34:17 +0100 Subject: Context Menus --- javascript/contextMenus.js | 165 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 javascript/contextMenus.js diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js new file mode 100644 index 00000000..99d1d3f7 --- /dev/null +++ b/javascript/contextMenus.js @@ -0,0 +1,165 @@ + +contextMenuInit = function(){ + let eventListenerApplied=false; + let menuSpecs = new Map(); + + const uid = function(){ + return Date.now().toString(36) + Math.random().toString(36).substr(2); + } + + function showContextMenu(event,element,menuEntries){ + let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft; + let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop; + + let oldMenu = gradioApp().querySelector('#context-menu') + if(oldMenu){ + oldMenu.remove() + } + + let tabButton = gradioApp().querySelector('button') + let baseStyle = window.getComputedStyle(tabButton) + + const contextMenu = document.createElement('nav') + contextMenu.id = "context-menu" + contextMenu.style.background = baseStyle.background + contextMenu.style.color = baseStyle.color + contextMenu.style.fontFamily = baseStyle.fontFamily + contextMenu.style.top = posy+'px' + contextMenu.style.left = posx+'px' + + + + const contextMenuList = document.createElement('ul') + contextMenuList.className = 'context-menu-items'; + contextMenu.append(contextMenuList); + + menuEntries.forEach(function(entry){ + let contextMenuEntry = document.createElement('a') + contextMenuEntry.innerHTML = entry['name'] + contextMenuEntry.addEventListener("click", function(e) { + entry['func'](); + }) + contextMenuList.append(contextMenuEntry); + + }) + + gradioApp().getRootNode().appendChild(contextMenu) + + let menuWidth = contextMenu.offsetWidth + 4; + let menuHeight = contextMenu.offsetHeight + 4; + + let windowWidth = window.innerWidth; + let windowHeight = window.innerHeight; + + if ( (windowWidth - posx) < menuWidth ) { + contextMenu.style.left = windowWidth - menuWidth + "px"; + } + + if ( (windowHeight - posy) < menuHeight ) { + contextMenu.style.top = windowHeight - menuHeight + "px"; + } + + } + + function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){ + + currentItems = menuSpecs.get(targetEmementSelector) + + if(!currentItems){ + currentItems = [] + menuSpecs.set(targetEmementSelector,currentItems); + } + let newItem = {'id':targetEmementSelector+'_'+uid(), + 'name':entryName, + 'func':entryFunction, + 'isNew':true} + + currentItems.push(newItem) + return newItem['id'] + } + + function removeContextMenuOption(uid){ + + } + + function addContextMenuEventListener(){ + if(eventListenerApplied){ + return; + } + gradioApp().addEventListener("click", function(e) { + let source = e.composedPath()[0] + if(source.id && source.indexOf('check_progress')>-1){ + return + } + + let oldMenu = gradioApp().querySelector('#context-menu') + if(oldMenu){ + oldMenu.remove() + } + }); + gradioApp().addEventListener("contextmenu", function(e) { + let oldMenu = gradioApp().querySelector('#context-menu') + if(oldMenu){ + oldMenu.remove() + } + menuSpecs.forEach(function(v,k) { + if(e.composedPath()[0].matches(k)){ + showContextMenu(e,e.composedPath()[0],v) + e.preventDefault() + return + } + }) + }); + eventListenerApplied=true + + } + + return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener] +} + +initResponse = contextMenuInit() +appendContextMenuOption = initResponse[0] +removeContextMenuOption = initResponse[1] +addContextMenuEventListener = initResponse[2] + + +//Start example Context Menu Items +generateOnRepeatId = appendContextMenuOption('#txt2img_generate','Generate forever',function(){ + let genbutton = gradioApp().querySelector('#txt2img_generate'); + let interruptbutton = gradioApp().querySelector('#txt2img_interrupt'); + if(!interruptbutton.offsetParent){ + genbutton.click(); + } + clearInterval(window.generateOnRepeatInterval) + window.generateOnRepeatInterval = setInterval(function(){ + if(!interruptbutton.offsetParent){ + genbutton.click(); + } + }, + 500)} +) + +cancelGenerateForever = function(){ + clearInterval(window.generateOnRepeatInterval) + let interruptbutton = gradioApp().querySelector('#txt2img_interrupt'); + if(interruptbutton.offsetParent){ + interruptbutton.click(); + } +} + +appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) +appendContextMenuOption('#txt2img_generate','Cancel generate forever',cancelGenerateForever) + +appendContextMenuOption('#roll','Roll three', + function(){ + let rollbutton = gradioApp().querySelector('#roll'); + setTimeout(function(){rollbutton.click()},100) + setTimeout(function(){rollbutton.click()},200) + setTimeout(function(){rollbutton.click()},300) + } +) +//End example Context Menu Items + +onUiUpdate(function(){ + addContextMenuEventListener() +}); \ No newline at end of file -- cgit v1.2.1 From 83749bfc72923b946abb825ebf4fdcc8b6035c8e Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 8 Oct 2022 05:35:03 +0100 Subject: context menu styling --- style.css | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/style.css b/style.css index da0729a2..50c5e557 100644 --- a/style.css +++ b/style.css @@ -410,4 +410,31 @@ input[type="range"]{ #img2img_image div.h-60{ height: 480px; -} \ No newline at end of file +} + +#context-menu{ + z-index:9999; + position:absolute; + display:block; + padding:0px 0; + border:2px solid #a55000; + border-radius:8px; + box-shadow:1px 1px 2px #CE6400; + width: 200px; +} + +.context-menu-items{ + list-style: none; + margin: 0; + padding: 0; +} + +.context-menu-items a{ + display:block; + padding:5px; + cursor:pointer; +} + +.context-menu-items a:hover{ + background: #a55000; +} -- cgit v1.2.1 From 21679435e531e729a4aea494e6cb9b7152ecdf75 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 8 Oct 2022 05:46:42 +0100 Subject: implement removal --- javascript/contextMenus.js | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index 99d1d3f7..2d82269f 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -79,7 +79,13 @@ contextMenuInit = function(){ } function removeContextMenuOption(uid){ - + menuSpecs.forEach(function(v,k) { + let index = -1 + v.forEach(function(e,ei){if(e['id']==uid){index=ei}}) + if(index>=0){ + v.splice(index, 1); + } + }) } function addContextMenuEventListener(){ @@ -148,7 +154,8 @@ cancelGenerateForever = function(){ } appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever) -appendContextMenuOption('#txt2img_generate','Cancel generate forever',cancelGenerateForever) +appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever) + appendContextMenuOption('#roll','Roll three', function(){ @@ -162,4 +169,4 @@ appendContextMenuOption('#roll','Roll three', onUiUpdate(function(){ addContextMenuEventListener() -}); \ No newline at end of file +}); -- cgit v1.2.1 From 87db6f01cc6b118fe0c82c36c6686d72d060c417 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 10:15:29 +0300 Subject: add info about cross attention javascript shortcut code --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0516c2cd..d6e1d50b 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - Attention, specify parts of text that the model should pay more attention to - a man in a ((tuxedo)) - will pay more attention to tuxedo - a man in a (tuxedo:1.21) - alternative syntax - - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text + - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user) - Loopback, run img2img processing multiple times - X/Y plot, a way to draw a 2 dimensional plot of images with different parameters - Textual Inversion -- cgit v1.2.1 From 5d54f35c583bd5a3b0ee271a862827f1ca81ef09 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:55:02 +0300 Subject: add xformers attnblock and hypernetwork support --- modules/sd_hijack_optimizations.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 7fb4a45e..c78d5838 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -98,8 +98,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) - v_in = self.to_v(context) + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) @@ -169,3 +175,13 @@ def cross_attention_attnblock_forward(self, x): h3 += x return h3 + + def xformers_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_).contiguous() + k1 = self.k(h_).contiguous() + v = self.v(h_).contiguous() + out = xformers.ops.memory_efficient_attention(q1, k1, v) + out = self.proj_out(out) + return x+out -- cgit v1.2.1 From 76a616fa6b814c681eaf6edc87eb3001b8c2b6be Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:55:38 +0300 Subject: Update sd_hijack_optimizations.py --- modules/sd_hijack_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c78d5838..ee58c7e4 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -176,7 +176,7 @@ def cross_attention_attnblock_forward(self, x): return h3 - def xformers_attnblock_forward(self, x): +def xformers_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) q1 = self.q(h_).contiguous() -- cgit v1.2.1 From 91d66f5520df416db718103d460550ad495e952d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:56:01 +0300 Subject: use new attnblock for xformers path --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 0e99c319..3da8c8ce 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -23,7 +23,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif cmd_opts.opt_split_attention: -- cgit v1.2.1 From 616b7218f7c469d25c138634472017a7e18e742e Mon Sep 17 00:00:00 2001 From: leko Date: Fri, 7 Oct 2022 23:09:21 +0800 Subject: fix: handles when state_dict does not exist --- modules/sd_models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8f794b47..9409d070 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -122,7 +122,11 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): pl_sd = torch.load(checkpoint_file, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] + + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd model.load_state_dict(sd, strict=False) -- cgit v1.2.1 From 706d5944a075a6523ea7f00165d630efc085ca22 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 13:38:57 +0300 Subject: let user choose his own prompt token count limit --- modules/processing.py | 6 ++++++ modules/sd_hijack.py | 13 +++++++------ modules/shared.py | 5 +++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f773a30e..d814d5ac 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -123,6 +123,7 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp + self.max_prompt_tokens = opts.max_prompt_tokens self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -141,6 +142,7 @@ class Processed: self.all_subseeds = all_subseeds or [self.subseed] self.infotexts = infotexts or [info] + def js(self): obj = { "prompt": self.prompt, @@ -169,6 +171,7 @@ class Processed: "infotexts": self.infotexts, "styles": self.styles, "job_timestamp": self.job_timestamp, + "max_prompt_tokens": self.max_prompt_tokens, } return json.dumps(obj) @@ -266,6 +269,8 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size + max_tokens = getattr(p, 'max_prompt_tokens', opts.max_prompt_tokens) + generation_params = { "Steps": p.steps, "Sampler": sd_samplers.samplers[p.sampler_index].name, @@ -281,6 +286,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Max tokens": (None if max_tokens == shared.vanilla_max_prompt_tokens else max_tokens) } generation_params.update(p.extra_generation_params) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d68f89cc..340329c0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -18,7 +18,6 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward - def apply_optimizations(): undo_optimizations() @@ -83,7 +82,7 @@ class StableDiffusionModelHijack: layer.padding_mode = 'circular' if enable else 'zeros' def tokenize(self, text): - max_length = self.clip.max_length - 2 + max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, max_length @@ -94,7 +93,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped = wrapped self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer - self.max_length = wrapped.max_length self.token_mults = {} tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] @@ -116,7 +114,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -191,7 +189,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = self.wrapped.max_length # you get to stay at 77 used_custom_terms = [] remade_batch_tokens = [] overflowing_words = [] @@ -268,8 +266,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76] + position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) + tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens) + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise diff --git a/modules/shared.py b/modules/shared.py index 879d8424..864e772c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -118,8 +118,8 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename) interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -# This was moved to webui.py with the other model "setup" calls. -# modules.sd_models.list_models() + +vanilla_max_prompt_tokens = 77 def realesrgan_models_names(): @@ -221,6 +221,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), + "max_prompt_tokens": OptionInfo(vanilla_max_prompt_tokens, f"Max prompt token count. Two tokens are reserved for for start and end. Default is {vanilla_max_prompt_tokens}. Setting this to a different value will result in different pictures for same seed.", gr.Number, {"precision": 0}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.1 From 786d9f63aaa4515df82eb2cf357ea92f3dae1e29 Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Tue, 4 Oct 2022 22:56:30 -0500 Subject: Add button to skip the current iteration --- javascript/hints.js | 1 + javascript/progressbar.js | 20 ++++++++++++++------ modules/img2img.py | 4 ++++ modules/processing.py | 4 ++++ modules/shared.py | 5 +++++ modules/ui.py | 8 ++++++++ style.css | 14 ++++++++++++-- webui.py | 1 + 8 files changed, 49 insertions(+), 8 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index 8adcd983..8e352e94 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -35,6 +35,7 @@ titles = { "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.", "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.", + "Skip": "Stop processing current image and continue processing.", "Interrupt": "Stop processing images and return any results accumulated so far.", "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.", diff --git a/javascript/progressbar.js b/javascript/progressbar.js index f9e9290e..4395a215 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,8 +1,9 @@ // code related to showing and updating progressbar shown as the image is being made global_progressbars = {} -function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){ +function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ var progressbar = gradioApp().getElementById(id_progressbar) + var skip = id_skip ? gradioApp().getElementById(id_skip) : null var interrupt = gradioApp().getElementById(id_interrupt) if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ @@ -32,30 +33,37 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; if(!progressDiv){ + if (skip) { + skip.style.display = "none" + } interrupt.style.display = "none" } } - window.setTimeout(function(){ requestMoreProgress(id_part, id_progressbar_span, id_interrupt) }, 500) + window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500) }); mutationObserver.observe( progressbar, { childList:true, subtree:true }) } } onUiUpdate(function(){ - check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') - check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') - check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery') + check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') + check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') + check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') }) -function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ +function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ btn = gradioApp().getElementById(id_part+"_check_progress"); if(btn==null) return; btn.click(); var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; + var skip = id_skip ? gradioApp().getElementById(id_skip) : null var interrupt = gradioApp().getElementById(id_interrupt) if(progressDiv && interrupt){ + if (skip) { + skip.style.display = "block" + } interrupt.style.display = "block" } } diff --git a/modules/img2img.py b/modules/img2img.py index da212d72..e60b7e0f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -32,6 +32,10 @@ def process_batch(p, input_dir, output_dir, args): for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" + if state.skipped: + state.skipped = False + state.interrupted = False + continue if state.interrupted: break diff --git a/modules/processing.py b/modules/processing.py index d814d5ac..6805039c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -355,6 +355,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: state.job_count = p.n_iter for n in range(p.n_iter): + if state.skipped: + state.skipped = False + state.interrupted = False + if state.interrupted: break diff --git a/modules/shared.py b/modules/shared.py index 864e772c..7f802bd9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -84,6 +84,7 @@ def selected_hypernetwork(): class State: + skipped = False interrupted = False job = "" job_no = 0 @@ -96,6 +97,10 @@ class State: current_image_sampling_step = 0 textinfo = None + def skip(self): + self.skipped = True + self.interrupted = True + def interrupt(self): self.interrupted = True diff --git a/modules/ui.py b/modules/ui.py index 4f18126f..e3e62fdd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -191,6 +191,7 @@ def wrap_gradio_call(func, extra_outputs=None): # last item is always HTML res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + shared.state.skipped = False shared.state.interrupted = False shared.state.job_count = 0 @@ -411,9 +412,16 @@ def create_toprow(is_img2img): with gr.Column(scale=1): with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + interrupt.click( fn=lambda: shared.state.interrupt(), inputs=[], diff --git a/style.css b/style.css index 50c5e557..6904fc50 100644 --- a/style.css +++ b/style.css @@ -393,10 +393,20 @@ input[type="range"]{ #txt2img_interrupt, #img2img_interrupt{ position: absolute; - width: 100%; + width: 50%; height: 72px; background: #b4c0cc; - border-radius: 8px; + border-radius: 0px; + display: none; +} + +#txt2img_skip, #img2img_skip{ + position: absolute; + width: 50%; + right: 0px; + height: 72px; + background: #b4c0cc; + border-radius: 0px; display: none; } diff --git a/webui.py b/webui.py index 480360fe..3b4cf5e9 100644 --- a/webui.py +++ b/webui.py @@ -58,6 +58,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): shared.state.current_latent = None shared.state.current_image = None shared.state.current_image_sampling_step = 0 + shared.state.skipped = False shared.state.interrupted = False shared.state.textinfo = None -- cgit v1.2.1 From 00117a07efbbe8482add12262a179326541467de Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Sat, 8 Oct 2022 05:33:21 -0500 Subject: check specifically for skipped --- modules/img2img.py | 2 -- modules/processing.py | 3 +-- modules/sd_samplers.py | 4 ++-- modules/shared.py | 1 - 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index e60b7e0f..24126774 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -34,8 +34,6 @@ def process_batch(p, input_dir, output_dir, args): state.job = f"{i+1} out of {len(images)}" if state.skipped: state.skipped = False - state.interrupted = False - continue if state.interrupted: break diff --git a/modules/processing.py b/modules/processing.py index 6805039c..3657fe69 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -357,7 +357,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: for n in range(p.n_iter): if state.skipped: state.skipped = False - state.interrupted = False if state.interrupted: break @@ -385,7 +384,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) - if state.interrupted: + if state.interrupted or state.skipped: # if we are interruped, sample returns just noise # use the image collected previously in sampler loop diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index df17e93c..13a8b322 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -106,7 +106,7 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs): seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs) for x in seq: - if state.interrupted: + if state.interrupted or state.skipped: break yield x @@ -254,7 +254,7 @@ def extended_trange(sampler, count, *args, **kwargs): seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs) for x in seq: - if state.interrupted: + if state.interrupted or state.skipped: break if sampler.stop_at is not None and x > sampler.stop_at: diff --git a/modules/shared.py b/modules/shared.py index 7f802bd9..ca462628 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -99,7 +99,6 @@ class State: def skip(self): self.skipped = True - self.interrupted = True def interrupt(self): self.interrupted = True -- cgit v1.2.1 From 4999eb2ef9b30e8c42ca7e4a94d4bbffe4d1f015 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 14:25:47 +0300 Subject: do not let user choose his own prompt token count limit --- README.md | 1 + modules/processing.py | 5 ----- modules/sd_hijack.py | 25 ++++++++++++------------- modules/shared.py | 3 --- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index d6e1d50b..ef9b5e31 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once - separate prompts using uppercase `AND` - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` +- No token limit for prompts (original stable diffusion lets you use up to 75 tokens) ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/modules/processing.py b/modules/processing.py index 3657fe69..d5162ddc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -123,7 +123,6 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp - self.max_prompt_tokens = opts.max_prompt_tokens self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -171,7 +170,6 @@ class Processed: "infotexts": self.infotexts, "styles": self.styles, "job_timestamp": self.job_timestamp, - "max_prompt_tokens": self.max_prompt_tokens, } return json.dumps(obj) @@ -269,8 +267,6 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size - max_tokens = getattr(p, 'max_prompt_tokens', opts.max_prompt_tokens) - generation_params = { "Steps": p.steps, "Sampler": sd_samplers.samplers[p.sampler_index].name, @@ -286,7 +282,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), - "Max tokens": (None if max_tokens == shared.vanilla_max_prompt_tokens else max_tokens) } generation_params.update(p.extra_generation_params) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 340329c0..2c1332c9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -36,6 +36,13 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward +def get_target_prompt_token_count(token_count): + if token_count < 75: + return 75 + + return math.ceil(token_count / 10) * 10 + + class StableDiffusionModelHijack: fixes = None comments = [] @@ -84,7 +91,7 @@ class StableDiffusionModelHijack: def tokenize(self, text): max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return remade_batch_tokens[0], token_count, max_length + return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): @@ -114,7 +121,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -146,19 +152,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): used_custom_terms.append((embedding.name, embedding.checksum())) i += embedding_length_in_tokens - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + prompt_target_length = get_target_prompt_token_count(token_count) + tokens_to_add = prompt_target_length - len(remade_tokens) + 1 - multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add + multipliers = [1.0] + multipliers + [1.0] * tokens_to_add return remade_tokens, fixes, multipliers, token_count diff --git a/modules/shared.py b/modules/shared.py index ca462628..475d7e52 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -123,8 +123,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -vanilla_max_prompt_tokens = 77 - def realesrgan_models_names(): import modules.realesrgan_model @@ -225,7 +223,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - "max_prompt_tokens": OptionInfo(vanilla_max_prompt_tokens, f"Max prompt token count. Two tokens are reserved for for start and end. Default is {vanilla_max_prompt_tokens}. Setting this to a different value will result in different pictures for same seed.", gr.Number, {"precision": 0}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.1 From 4201fd14f5769a4cf6723d2bc5495c3c84a2cd00 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 14:42:34 +0300 Subject: install xformers --- launch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/launch.py b/launch.py index 75edb66a..f3fbe16a 100644 --- a/launch.py +++ b/launch.py @@ -124,6 +124,9 @@ if not is_installed("gfpgan"): if not is_installed("clip"): run_pip(f"install {clip_package}", "clip") +if not is_installed("xformers"): + run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") + os.makedirs(dir_repos, exist_ok=True) git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) -- cgit v1.2.1 From 3f166be1b60ff2ab33a6d2646809ec3f48796303 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 14:42:50 +0300 Subject: Update requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 304a066a..81641d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,3 @@ torchdiffeq kornia lark functorch -#xformers? -- cgit v1.2.1 From 77f4237d1c3af1756e7dab2699e3dcebad5619d6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 15:25:59 +0300 Subject: fix bugs related to variable prompt lengths --- modules/sd_hijack.py | 14 +++++++++----- modules/sd_samplers.py | 35 ++++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2c1332c9..7e7fde0f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -89,7 +89,6 @@ class StableDiffusionModelHijack: layer.padding_mode = 'circular' if enable else 'zeros' def tokenize(self, text): - max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) @@ -174,7 +173,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if line in cache: remade_tokens, fixes, multipliers = cache[line] else: - remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + token_count = max(current_token_count, token_count) cache[line] = (remade_tokens, fixes, multipliers) @@ -265,15 +265,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76] + target_token_count = get_target_prompt_token_count(token_count) + 2 + + position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76] position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) - tokens = torch.asarray(remade_batch_tokens).to(device) + remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] + tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers = torch.asarray(batch_multipliers).to(device) + batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] + batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 13a8b322..eade0dbb 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -142,6 +142,16 @@ class VanillaStableDiffusionSampler: assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' cond = tensor + # for DDIM, shapes must match, we can't just process cond and uncond independently; + # filling unconditional_conditioning with repeats of the last vector to match length is + # not 100% correct but should work well enough + if unconditional_conditioning.shape[1] < cond.shape[1]: + last_vector = unconditional_conditioning[:, -1:] + last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) + unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) + elif unconditional_conditioning.shape[1] > cond.shape[1]: + unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] + if self.mask is not None: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec @@ -221,18 +231,29 @@ class CFGDenoiser(torch.nn.Module): x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - cond_in = torch.cat([tensor, uncond]) - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=cond_in) + if tensor.shape[1] == uncond.shape[1]: + cond_in = torch.cat([tensor, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond=cond_in) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) else: x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): + batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + b = min(a + batch_size, tensor.shape[0]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b]) + + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond) - denoised_uncond = x_out[-batch_size:] + denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) for i, conds in enumerate(conds_list): -- cgit v1.2.1 From 7001bffe0247804793dfabb69ac96d832572ccd0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 15:43:25 +0300 Subject: fix AND broken for long prompts --- modules/prompt_parser.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f00256f2..15666073 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch) + # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes + # and won't be able to torch.stack them. So this fixes that. + token_count = max([x.shape[0] for x in tensors]) + for i in range(len(tensors)): + if tensors[i].shape[0] != token_count: + last_vector = tensors[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) + tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) -- cgit v1.2.1 From 772db721a52da374d627b60994222051f26c27a7 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Fri, 7 Oct 2022 23:02:07 +0900 Subject: fix glob path in hypernetwork.py --- modules/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index c7b86682..7f062242 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -43,7 +43,7 @@ class Hypernetwork: def load_hypernetworks(path): res = {} - for filename in glob.iglob(path + '**/*.pt', recursive=True): + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): try: hn = Hypernetwork(filename) res[hn.name] = hn -- cgit v1.2.1 From 32e428ff19c28c87bb2ed362316b928b372e3a70 Mon Sep 17 00:00:00 2001 From: guaneec Date: Sat, 8 Oct 2022 16:01:34 +0800 Subject: Remove duplicate event listeners --- javascript/imageviewer.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 3a0baac8..4c0e8f4b 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -86,6 +86,9 @@ function showGalleryImage(){ if(fullImg_preview != null){ fullImg_preview.forEach(function function_name(e) { + if (e.dataset.modded) + return; + e.dataset.modded = true; if(e && e.parentElement.tagName == 'DIV'){ e.style.cursor='pointer' -- cgit v1.2.1 From 5f85a74b00c0154bfd559dc67edfa7e30342b7c9 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 7 Oct 2022 17:48:34 -0400 Subject: fix bug where when using prompt composition, hijack_comments generated before the final AND will be dropped --- modules/processing.py | 1 + modules/sd_hijack.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index d5162ddc..8240ee27 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -313,6 +313,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: os.makedirs(p.outpath_grids, exist_ok=True) modules.sd_hijack.model_hijack.apply_circular(p.tiling) + modules.sd_hijack.model_hijack.clear_comments() comments = {} diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7e7fde0f..ba808a39 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -88,6 +88,9 @@ class StableDiffusionModelHijack: for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: layer.padding_mode = 'circular' if enable else 'zeros' + def clear_comments(self): + self.comments = [] + def tokenize(self, text): _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) @@ -260,7 +263,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) self.hijack.fixes = hijack_fixes - self.hijack.comments = hijack_comments + self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) -- cgit v1.2.1 From d0e85873ac72416d32dee8720dc9e93ab3d3e236 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:13:26 +0300 Subject: check for OS and env variable --- launch.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/launch.py b/launch.py index f3fbe16a..a2089b3b 100644 --- a/launch.py +++ b/launch.py @@ -4,6 +4,7 @@ import os import sys import importlib.util import shlex +import platform dir_repos = "repositories" dir_tmp = "tmp" @@ -31,6 +32,7 @@ def extract_arg(args, name): args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') +args, xformers = extract_arg(args, '--xformers') def repo_dir(name): @@ -124,8 +126,11 @@ if not is_installed("gfpgan"): if not is_installed("clip"): run_pip(f"install {clip_package}", "clip") -if not is_installed("xformers"): - run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") +if not is_installed("xformers") and xformers: + if platform.system() == "Windows": + run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") + elif: + run_pip("install xformers", "xformers") os.makedirs(dir_repos, exist_ok=True) -- cgit v1.2.1 From 26b459a3799c5cdf71ca8ed5315a99f69c69f02c Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:20:04 +0300 Subject: default to split attention if cuda is available and xformers is not --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3da8c8ce..04adcf03 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,12 +21,12 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip or shared.xformers_available): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif cmd_opts.opt_split_attention: + elif cmd_opts.opt_split_attention or torch.cuda.is_available(): ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.1 From ddfa9a97865c732193023a71521c5b7b53d8571b Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:20:41 +0300 Subject: add xformers_available shared variable --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 8cc3b2fe..6ed4b802 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,7 @@ device = devices.device batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram - +xformers_available = False config_filename = cmd_opts.ui_settings_file -- cgit v1.2.1 From 69d0053583757ce2942d62de81e8b89e6be07840 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:21:40 +0300 Subject: update sd_hijack_opt to respect new env variables --- modules/sd_hijack_optimizations.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ee58c7e4..be09ec8f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,9 +1,14 @@ import math import torch from torch import einsum -import xformers.ops -import functorch -xformers._is_functorch_available=True +try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True +except: + print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.') + continue from ldm.util import default from einops import rearrange -- cgit v1.2.1 From ca5f0f149c29c344a6badd055b15b5e5fcd6e938 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:22:38 +0300 Subject: Update launch.py --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index a2089b3b..a592e1ba 100644 --- a/launch.py +++ b/launch.py @@ -129,7 +129,7 @@ if not is_installed("clip"): if not is_installed("xformers") and xformers: if platform.system() == "Windows": run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") - elif: + elif platform.system() == "Linux": run_pip("install xformers", "xformers") os.makedirs(dir_repos, exist_ok=True) -- cgit v1.2.1 From 7ffea1507813540b8cd9e73feb7bf23de1ac4e27 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:24:06 +0300 Subject: Update requirements_versions.txt --- requirements_versions.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_versions.txt b/requirements_versions.txt index fdff2687..fec3e9d5 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -22,3 +22,4 @@ resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 +functorch==0.2.1 -- cgit v1.2.1 From 970de9ee6891ff586821d0d80dde01c2f6c681b3 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:29:43 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 04adcf03..5b30539f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,7 +21,7 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip or shared.xformers_available): + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: -- cgit v1.2.1 From 7ff1170a2e11b6f00f587407326db0b9f8f51adf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 16:33:39 +0300 Subject: emergency fix for xformers (continue + shared) --- modules/sd_hijack_optimizations.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index e43e2c7a..05023b6f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,19 +1,19 @@ import math import torch from torch import einsum -try: - import xformers.ops - import functorch - xformers._is_functorch_available = True - shared.xformers_available = True -except: - print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.') - continue + from ldm.util import default from einops import rearrange from modules import shared +try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True +except Exception: + print('Cannot find xformers, defaulting to split attention. Try adding --xformers commandline argument to your webui-user file if you wish to install it.') # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): -- cgit v1.2.1 From dc1117233ef8f9b25ff1ac40b158f20b70ba2fcb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:02:18 +0300 Subject: simplify xfrmers options: --xformers to enable and that's it --- launch.py | 2 +- modules/sd_hijack.py | 2 +- modules/sd_hijack_optimizations.py | 20 +++++++++++++------- modules/shared.py | 2 +- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/launch.py b/launch.py index a592e1ba..61f62096 100644 --- a/launch.py +++ b/launch.py @@ -32,7 +32,7 @@ def extract_arg(args, name): args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') -args, xformers = extract_arg(args, '--xformers') +xformers = '--xformers' in args def repo_dir(name): diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5d93f7f6..91e98c16 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,7 +22,7 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available: + if cmd_opts.xformers and shared.xformers_available and not torch.version.hip: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 05023b6f..d23d733b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,4 +1,7 @@ import math +import sys +import traceback + import torch from torch import einsum @@ -7,13 +10,16 @@ from einops import rearrange from modules import shared -try: - import xformers.ops - import functorch - xformers._is_functorch_available = True - shared.xformers_available = True -except Exception: - print('Cannot find xformers, defaulting to split attention. Try adding --xformers commandline argument to your webui-user file if you wish to install it.') +if shared.cmd_opts.xformers: + try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True + except Exception: + print("Cannot import xformers", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): diff --git a/modules/shared.py b/modules/shared.py index d68df751..02cb2722 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -43,7 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) -parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization") +parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.1 From 27032c47df9c07ac21dd5b89fa7dc247bb8705b6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:10:05 +0300 Subject: restore old opt_split_attention/disable_opt_split_attention logic --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 91e98c16..335a2bcf 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif cmd_opts.opt_split_attention or torch.cuda.is_available(): + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.1 From 4f33289d0fc5aa3a197f4a4c926d03d44f0d597e Mon Sep 17 00:00:00 2001 From: Milly Date: Sat, 8 Oct 2022 22:48:15 +0900 Subject: Fixed typo --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index e3e62fdd..ffd75f6a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -946,7 +946,7 @@ def create_ui(wrap_gradio_gpu_call): custom_name = gr.Textbox(label="Custom Name (Optional)") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") - save_as_half = gr.Checkbox(value=False, label="Safe as float16") + save_as_half = gr.Checkbox(value=False, label="Save as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') with gr.Column(variant='panel'): -- cgit v1.2.1 From cfc33f99d47d1f45af15499e5965834089d11858 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:28:58 +0300 Subject: why did you do this --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 335a2bcf..ed271976 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -28,7 +28,7 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.1 From 7e639cd49855ef59e087ae9a9122756a937007eb Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:22:20 +0300 Subject: check for 3.10 --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 61f62096..1d65a779 100644 --- a/launch.py +++ b/launch.py @@ -126,7 +126,7 @@ if not is_installed("gfpgan"): if not is_installed("clip"): run_pip(f"install {clip_package}", "clip") -if not is_installed("xformers") and xformers: +if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"): if platform.system() == "Windows": run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") elif platform.system() == "Linux": -- cgit v1.2.1 From 017b6b8744f0771e498656ec043e12d5cc6969a7 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:27:21 +0300 Subject: check for ampere --- modules/sd_hijack.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ed271976..5e266d5e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,9 +22,10 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and not torch.version.hip: - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + if cmd_opts.xformers and shared.xformers_available and torch.version.cuda: + if torch.cuda.get_device_capability(shared.device) == (8, 6): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): -- cgit v1.2.1 From cc0258aea7b6605be3648900063cfa96ed7c5ffa Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:44:53 +0300 Subject: check for ampere without destroying the optimizations. again. --- modules/sd_hijack.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5e266d5e..a3e374f0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,10 +22,9 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and torch.version.cuda: - if torch.cuda.get_device_capability(shared.device) == (8, 6): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + if cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): -- cgit v1.2.1 From 34acad1628e98a5e0cbd459fa69ded915864f53d Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Fri, 7 Oct 2022 22:56:00 +0100 Subject: Add GZipMiddleware to root demo --- webui.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 3b4cf5e9..18de8e16 100644 --- a/webui.py +++ b/webui.py @@ -5,6 +5,8 @@ import importlib import signal import threading +from fastapi.middleware.gzip import GZipMiddleware + from modules.paths import script_path from modules import devices, sd_samplers @@ -93,7 +95,7 @@ def webui(): demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - demo.launch( + app,local_url,share_url = demo.launch( share=cmd_opts.share, server_name="0.0.0.0" if cmd_opts.listen else None, server_port=cmd_opts.port, @@ -102,6 +104,8 @@ def webui(): inbrowser=cmd_opts.autolaunch, prevent_thread_lock=True ) + + app.add_middleware(GZipMiddleware,minimum_size=1000) while 1: time.sleep(0.5) -- cgit v1.2.1 From a5550f0213c3f145b1c984816ebcef92c48853ee Mon Sep 17 00:00:00 2001 From: Artem Zagidulin Date: Wed, 5 Oct 2022 19:10:39 +0300 Subject: alternate prompt --- modules/prompt_parser.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 15666073..919d5d31 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -13,13 +13,14 @@ import lark schedule_parser = lark.Lark(r""" !start: (prompt | /[][():]/+)* -prompt: (emphasized | scheduled | plain | WHITESPACE)* +prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)* !emphasized: "(" prompt ")" | "(" prompt ":" prompt ")" | "[" prompt "]" scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]" +alternate: "[" prompt ("|" prompt)+ "]" WHITESPACE: /\s+/ -plain: /([^\\\[\]():]|\\.)+/ +plain: /([^\\\[\]():|]|\\.)+/ %import common.SIGNED_NUMBER -> NUMBER """) @@ -59,6 +60,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): tree.children[-1] *= steps tree.children[-1] = min(steps, int(tree.children[-1])) l.append(tree.children[-1]) + def alternate(self, tree): + l.extend(range(1, steps+1)) CollectSteps().visit(tree) return sorted(set(l)) @@ -67,6 +70,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): def scheduled(self, args): before, after, _, when = args yield before or () if step <= when else after + def alternate(self, args): + yield next(args[(step - 1)%len(args)]) def start(self, args): def flatten(x): if type(x) == str: -- cgit v1.2.1 From 01f8cb44474e454903c11718e6a4f33dbde34bb8 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Sat, 8 Oct 2022 18:02:56 +0200 Subject: made deepdanbooru optional, added to readme, automatic download of deepbooru model --- README.md | 2 ++ launch.py | 4 ++++ modules/deepbooru.py | 20 ++++++++++---------- modules/shared.py | 1 + modules/ui.py | 19 ++++++++++++------- requirements.txt | 3 --- requirements_versions.txt | 3 --- 7 files changed, 29 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index ef9b5e31..6cd7a1f9 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - separate prompts using uppercase `AND` - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) +- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. @@ -123,4 +124,5 @@ The documentation was moved from this README over to the project's [wiki](https: - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. +- DeepDanbooru - interrogator for anime diffusors https://github.com/KichangKim/DeepDanbooru - (You) diff --git a/launch.py b/launch.py index 61f62096..d46426eb 100644 --- a/launch.py +++ b/launch.py @@ -33,6 +33,7 @@ def extract_arg(args, name): args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') xformers = '--xformers' in args +deepdanbooru = '--deepdanbooru' in args def repo_dir(name): @@ -132,6 +133,9 @@ if not is_installed("xformers") and xformers: elif platform.system() == "Linux": run_pip("install xformers", "xformers") +if not is_installed("deepdanbooru") and deepdanbooru: + run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru") + os.makedirs(dir_repos, exist_ok=True) git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 781b2249..7e3c0618 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -9,16 +9,16 @@ def _load_tf_and_return_tags(pil_image, threshold): import numpy as np this_folder = os.path.dirname(__file__) - model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') - - model_good = False - for path_candidate in [model_path, os.path.dirname(model_path)]: - if os.path.exists(os.path.join(path_candidate, 'project.json')): - model_path = path_candidate - model_good = True - if not model_good: - return ("Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/" - "deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru") + model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) + if not os.path.exists(os.path.join(model_path, 'project.json')): + # there is no point importing these every time + import zipfile + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", + model_path) + with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: + zip_ref.extractall(model_path) + os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) tags = dd.project.load_tags_from_project(model_path) model = dd.project.load_model_from_project( diff --git a/modules/shared.py b/modules/shared.py index 02cb2722..c87b726e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,6 +44,7 @@ parser.add_argument("--scunet-models-path", type=str, help="Path to directory wi parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") +parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") diff --git a/modules/ui.py b/modules/ui.py index 30583fe9..c5c11c3c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,9 +23,10 @@ import gradio.utils import gradio.routes from modules import sd_hijack -from modules.deepbooru import get_deepbooru_tags from modules.paths import script_path from modules.shared import opts, cmd_opts +if cmd_opts.deepdanbooru: + from modules.deepbooru import get_deepbooru_tags import modules.shared as shared from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_hijack import model_hijack @@ -437,7 +438,10 @@ def create_toprow(is_img2img): with gr.Row(scale=1): if is_img2img: interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + if cmd_opts.deepdanbooru: + deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + else: + deepbooru = None else: interrogate = None deepbooru = None @@ -782,11 +786,12 @@ def create_ui(wrap_gradio_gpu_call): outputs=[img2img_prompt], ) - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) + if cmd_opts.deepdanbooru: + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) save.click( fn=wrap_gradio_call(save_files), diff --git a/requirements.txt b/requirements.txt index cd3953c6..81641d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,4 @@ resize-right torchdiffeq kornia lark -deepdanbooru -tensorflow -tensorflow-io functorch diff --git a/requirements_versions.txt b/requirements_versions.txt index 2d256a54..fec3e9d5 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -22,7 +22,4 @@ resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 -git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] -tensorflow==2.10.0 -tensorflow-io==0.27.0 functorch==0.2.1 -- cgit v1.2.1 From f9c5da159245bb1e7603b3c8b9e0703bcb1c2ff5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 19:05:19 +0300 Subject: add fallback for xformers_attnblock_forward --- modules/sd_hijack_optimizations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index d23d733b..dba21192 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -211,6 +211,7 @@ def cross_attention_attnblock_forward(self, x): return h3 def xformers_attnblock_forward(self, x): + try: h_ = x h_ = self.norm(h_) q1 = self.q(h_).contiguous() @@ -218,4 +219,6 @@ def xformers_attnblock_forward(self, x): v = self.v(h_).contiguous() out = xformers.ops.memory_efficient_attention(q1, k1, v) out = self.proj_out(out) - return x+out + return x + out + except NotImplementedError: + return cross_attention_attnblock_forward(self, x) -- cgit v1.2.1 From 3061cdb7b610d4ba7f1ea695d9d6364b591e5bc7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 19:22:15 +0300 Subject: add --force-enable-xformers option and also add messages to console regarding cross attention optimizations --- modules/sd_hijack.py | 6 +++++- modules/shared.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a3e374f0..307cc67d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,12 +22,16 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6): + + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): + print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: + print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + print("Applying cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward diff --git a/modules/shared.py b/modules/shared.py index 02cb2722..8f941226 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,6 +44,7 @@ parser.add_argument("--scunet-models-path", type=str, help="Path to directory wi parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") +parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.1 From 15c4278f1a18b8104e135dd82690d10cff39a2e7 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:50:01 +0100 Subject: TI preprocess wording MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I had to check the code to work out what splitting was 🤷🏿 --- modules/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index ffd75f6a..d52d74c6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -980,9 +980,9 @@ def create_ui(wrap_gradio_gpu_call): process_dst = gr.Textbox(label='Destination directory') with gr.Row(): - process_flip = gr.Checkbox(label='Flip') - process_split = gr.Checkbox(label='Split into two') - process_caption = gr.Checkbox(label='Add caption') + process_flip = gr.Checkbox(label='Create flipped copies') + process_split = gr.Checkbox(label='Split oversized images into two') + process_caption = gr.Checkbox(label='Use CLIP caption as filename') with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.1 From b458fa48fe5734a872bca83061d702609cb52940 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:56:28 +0100 Subject: Update ui.py --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index d52d74c6..b09359aa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -982,7 +982,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images into two') - process_caption = gr.Checkbox(label='Use CLIP caption as filename') + process_caption = gr.Checkbox(label='Use BLIP caption as filename') with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.1 From 1371d7608b402d6f15c200ec2f5fde4579836a05 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 14:28:22 -0400 Subject: Added ability to ignore last n layers in FrozenCLIPEmbedder --- modules/sd_hijack.py | 11 +++++++++-- modules/shared.py | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 307cc67d..f12a9696 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -281,8 +281,15 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) - z = outputs.last_hidden_state + + tmp = -opts.CLIP_ignore_last_layers + if (opts.CLIP_ignore_last_layers == 0): + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) + z = outputs.last_hidden_state + else: + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] diff --git a/modules/shared.py b/modules/shared.py index 8f941226..af8dc744 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -225,6 +225,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), + 'CLIP_ignore_last_layers': OptionInfo(0, "Ignore last layers of CLIP model", gr.Slider, {"minimum": 0, "maximum": 5, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.1 From e6e42f98df2c928c4f49351ad6b466387ce87d42 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 19:25:10 +0300 Subject: make --force-enable-xformers work without needing --xformers --- modules/sd_hijack_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index dba21192..c4396bb9 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -10,7 +10,7 @@ from einops import rearrange from modules import shared -if shared.cmd_opts.xformers: +if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: import xformers.ops import functorch -- cgit v1.2.1 From 3b2141c5fb6a3c2b8ab4b1e759a97ead77260129 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 22:21:15 +0300 Subject: add 'Ignore last layers of CLIP model' option as a parameter to the infotext --- modules/processing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 8240ee27..515fc91a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -123,6 +123,7 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp + self.clip_skip = opts.CLIP_ignore_last_layers self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -141,7 +142,6 @@ class Processed: self.all_subseeds = all_subseeds or [self.subseed] self.infotexts = infotexts or [info] - def js(self): obj = { "prompt": self.prompt, @@ -170,6 +170,7 @@ class Processed: "infotexts": self.infotexts, "styles": self.styles, "job_timestamp": self.job_timestamp, + "clip_skip": self.clip_skip, } return json.dumps(obj) @@ -267,6 +268,8 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size + clip_skip = getattr(p, 'clip_skip', opts.CLIP_ignore_last_layers) + generation_params = { "Steps": p.steps, "Sampler": sd_samplers.samplers[p.sampler_index].name, @@ -282,6 +285,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Clip skip": None if clip_skip==0 else clip_skip, } generation_params.update(p.extra_generation_params) -- cgit v1.2.1 From 610a7f4e1480c0ffeedb2a07dc27ae86bf03c3a8 Mon Sep 17 00:00:00 2001 From: Edouard Leurent Date: Sat, 8 Oct 2022 16:49:43 +0100 Subject: Break after finding the local directory of stable diffusion Otherwise, we may override it with one of the next two path (. or ..) if it is present there, and then the local paths of other modules (taming transformers, codeformers, etc.) wont be found in sd_path/../. Fix https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/1085 --- modules/paths.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/paths.py b/modules/paths.py index 606f7d66..0519caa0 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -12,6 +12,7 @@ possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), for possible_sd_path in possible_sd_paths: if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): sd_path = os.path.abspath(possible_sd_path) + break assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) -- cgit v1.2.1 From 432782163ae53e605470bcefc9a6f796c4556912 Mon Sep 17 00:00:00 2001 From: Aidan Holland Date: Sat, 8 Oct 2022 15:12:24 -0400 Subject: chore: Fix typos --- README.md | 2 +- javascript/imageviewer.js | 2 +- modules/interrogate.py | 4 ++-- modules/processing.py | 2 +- modules/scunet_model_arch.py | 4 ++-- modules/sd_models.py | 4 ++-- modules/sd_samplers.py | 4 ++-- modules/shared.py | 6 +++--- modules/swinir_model_arch.py | 2 +- modules/ui.py | 4 ++-- 10 files changed, 17 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index ef9b5e31..63dd0c18 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - Sampling method selection - Interrupt processing at any time - 4GB video card support (also reports of 2GB working) -- Correct seeds for batches +- Correct seeds for batches - Prompt length validation - get length of prompt in tokens as you type - get a warning after generation if some text was truncated diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 4c0e8f4b..6a00c0da 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -95,7 +95,7 @@ function showGalleryImage(){ e.addEventListener('click', function (evt) { if(!opts.js_modal_lightbox) return; - modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initialy_zoomed) + modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) showModal(evt) },true); } diff --git a/modules/interrogate.py b/modules/interrogate.py index eed87144..635e266e 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -140,11 +140,11 @@ class InterrogateModels: res = caption - cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext with torch.no_grad(), precision_scope("cuda"): - image_features = self.clip_model.encode_image(cilp_image).type(self.dtype) + image_features = self.clip_model.encode_image(clip_image).type(self.dtype) image_features /= image_features.norm(dim=-1, keepdim=True) diff --git a/modules/processing.py b/modules/processing.py index 515fc91a..31220881 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -386,7 +386,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.interrupted or state.skipped: - # if we are interruped, sample returns just noise + # if we are interrupted, sample returns just noise # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py index 972a2639..43ca8d36 100644 --- a/modules/scunet_model_arch.py +++ b/modules/scunet_model_arch.py @@ -40,7 +40,7 @@ class WMSA(nn.Module): Returns: attn_mask: should be (1 1 w p p), """ - # supporting sqaure. + # supporting square. attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) if self.type == 'W': return attn_mask @@ -65,7 +65,7 @@ class WMSA(nn.Module): x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) h_windows = x.size(1) w_windows = x.size(2) - # sqaure validation + # square validation # assert h_windows == w_windows x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) diff --git a/modules/sd_models.py b/modules/sd_models.py index 9409d070..a09866ce 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,7 +147,7 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): model.first_stage_model.load_state_dict(vae_dict) model.sd_model_hash = sd_model_hash - model.sd_model_checkpint = checkpoint_file + model.sd_model_checkpoint = checkpoint_file def load_model(): @@ -175,7 +175,7 @@ def reload_model_weights(sd_model, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - if sd_model.sd_model_checkpint == checkpoint_info.filename: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index eade0dbb..6e743f7e 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -181,7 +181,7 @@ class VanillaStableDiffusionSampler: self.initialize(p) - # existing code fails with cetain step counts, like 9 + # existing code fails with certain step counts, like 9 try: self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) except Exception: @@ -204,7 +204,7 @@ class VanillaStableDiffusionSampler: steps = steps or p.steps - # existing code fails with cetin step counts, like 9 + # existing code fails with certain step counts, like 9 try: samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) except Exception: diff --git a/modules/shared.py b/modules/shared.py index af8dc744..2dc092d6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -141,9 +141,9 @@ class OptionInfo: self.section = None -def options_section(section_identifer, options_dict): +def options_section(section_identifier, options_dict): for k, v in options_dict.items(): - v.section = section_identifer + v.section = section_identifier return options_dict @@ -246,7 +246,7 @@ options_templates.update(options_section(('ui', "User interface"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), - "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), })) diff --git a/modules/swinir_model_arch.py b/modules/swinir_model_arch.py index 461fb354..863f42db 100644 --- a/modules/swinir_model_arch.py +++ b/modules/swinir_model_arch.py @@ -166,7 +166,7 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. + input_resolution (tuple[int]): Input resolution. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. diff --git a/modules/ui.py b/modules/ui.py index b09359aa..b51af121 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -38,7 +38,7 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') @@ -102,7 +102,7 @@ def save_files(js_data, images, index): import csv filenames = [] - #quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it class MyObject: def __init__(self, d=None): if d is not None: -- cgit v1.2.1 From 050a6a798cec90ae2f881c2ddd3f0221e69907dc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 23:26:48 +0300 Subject: support loading .yaml config with same name as model support EMA weights in processing (????) --- modules/processing.py | 2 +- modules/sd_models.py | 30 +++++++++++++++++++++++------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 31220881..4fea6d56 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -347,7 +347,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - with torch.no_grad(): + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(all_prompts, all_seeds, all_subseeds) diff --git a/modules/sd_models.py b/modules/sd_models.py index a09866ce..cb3982b1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ from modules.paths import models_path model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) checkpoints_list = {} try: @@ -63,14 +63,20 @@ def list_models(): if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) for filename in model_list: h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) + + basename, _ = os.path.splitext(filename) + config = basename + ".yaml" + if not os.path.exists(config): + config = shared.cmd_opts.config + + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) def get_closet_checkpoint_match(searchString): @@ -116,7 +122,10 @@ def select_checkpoint(): return checkpoint_info -def load_model_weights(model, checkpoint_file, sd_model_hash): +def load_model_weights(model, checkpoint_info): + checkpoint_file = checkpoint_info.filename + sd_model_hash = checkpoint_info.hash + print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location="cpu") @@ -148,15 +157,19 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file + model.sd_checkpoint_info = checkpoint_info def load_model(): from modules import lowvram, sd_hijack checkpoint_info = select_checkpoint() - sd_config = OmegaConf.load(shared.cmd_opts.config) + if checkpoint_info.config != shared.cmd_opts.config: + print(f"Loading config from: {shared.cmd_opts.config}") + + sd_config = OmegaConf.load(checkpoint_info.config) sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -178,6 +191,9 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return + if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + return load_model() + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -185,7 +201,7 @@ def reload_model_weights(sd_model, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) + load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) -- cgit v1.2.1 From c77c89cc83c618472ad352cf8a28fde28c3a1377 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 10:23:31 +0300 Subject: make main model loading and model merger use the same code --- modules/extras.py | 6 +++--- modules/sd_models.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 1d9e64e5..ef6e6de7 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -169,9 +169,9 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int print(f"Loading {secondary_model_info.filename}...") secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') - - theta_0 = primary_model['state_dict'] - theta_1 = secondary_model['state_dict'] + + theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) + theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) theta_funcs = { "Weighted Sum": weighted_sum, diff --git a/modules/sd_models.py b/modules/sd_models.py index cb3982b1..18fb8c2e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -122,6 +122,13 @@ def select_checkpoint(): return checkpoint_info +def get_state_dict_from_checkpoint(pl_sd): + if "state_dict" in pl_sd: + return pl_sd["state_dict"] + + return pl_sd + + def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -131,11 +138,8 @@ def load_model_weights(model, checkpoint_info): pl_sd = torch.load(checkpoint_file, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") - - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd + + sd = get_state_dict_from_checkpoint(pl_sd) model.load_state_dict(sd, strict=False) -- cgit v1.2.1 From 4e569fd888f8e3c5632a072d51abbb6e4d17abd6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 10:31:47 +0300 Subject: fixed incorrect message about loading config; thanks anon! --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 18fb8c2e..2101b18d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -169,7 +169,7 @@ def load_model(): checkpoint_info = select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: - print(f"Loading config from: {shared.cmd_opts.config}") + print(f"Loading config from: {checkpoint_info.config}") sd_config = OmegaConf.load(checkpoint_info.config) sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.1 From 5ab7e88d9b0bb0125af9f7237242a00a93360ce5 Mon Sep 17 00:00:00 2001 From: aoirusann <82883326+aoirusann@users.noreply.github.com> Date: Sat, 8 Oct 2022 13:09:29 +0800 Subject: Add `Download` & `Download as zip` --- modules/ui.py | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index b51af121..fe7f10a7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -98,9 +98,10 @@ def send_gradio_gallery_to_image(x): return image_from_url_text(x[0]) -def save_files(js_data, images, index): +def save_files(js_data, images, do_make_zip, index): import csv filenames = [] + fullfns = [] #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it class MyObject: @@ -141,10 +142,22 @@ def save_files(js_data, images, index): filename = os.path.relpath(fullfn, path) filenames.append(filename) + fullfns.append(fullfn) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - return '', '', plaintext_to_html(f"Saved: {filenames[0]}") + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return fullfns, '', '', plaintext_to_html(f"Saved: {filenames[0]}") def wrap_gradio_call(func, extra_outputs=None): @@ -521,6 +534,12 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id) + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False) + with gr.Group(): html_info = gr.HTML() generation_info = gr.Textbox(visible=False) @@ -570,13 +589,15 @@ def create_ui(wrap_gradio_gpu_call): save.click( fn=wrap_gradio_call(save_files), - _js="(x, y, z) => [x, y, selected_gallery_index()]", + _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", inputs=[ generation_info, txt2img_gallery, + do_make_zip, html_info, ], outputs=[ + download_files, html_info, html_info, html_info, @@ -701,6 +722,12 @@ def create_ui(wrap_gradio_gpu_call): button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id) + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False) + with gr.Group(): html_info = gr.HTML() generation_info = gr.Textbox(visible=False) @@ -776,13 +803,15 @@ def create_ui(wrap_gradio_gpu_call): save.click( fn=wrap_gradio_call(save_files), - _js="(x, y, z) => [x, y, selected_gallery_index()]", + _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", inputs=[ generation_info, img2img_gallery, - html_info + do_make_zip, + html_info, ], outputs=[ + download_files, html_info, html_info, html_info, -- cgit v1.2.1 From 14192c5b207b16b1ec7a4c9c4ea538d1a6811a4d Mon Sep 17 00:00:00 2001 From: aoirusann Date: Sun, 9 Oct 2022 13:01:10 +0800 Subject: Support `Download` for txt files. --- modules/images.py | 39 +++++++++++++++++++++++++++++++++++++-- modules/ui.py | 5 ++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/modules/images.py b/modules/images.py index 29c5ee24..c0a90676 100644 --- a/modules/images.py +++ b/modules/images.py @@ -349,6 +349,38 @@ def get_next_sequence_number(path, basename): def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None): + '''Save an image. + + Args: + image (`PIL.Image`): + The image to be saved. + path (`str`): + The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory. + basename (`str`): + The base filename which will be applied to `filename pattern`. + seed, prompt, short_filename, + extension (`str`): + Image file extension, default is `png`. + pngsectionname (`str`): + Specify the name of the section which `info` will be saved in. + info (`str` or `PngImagePlugin.iTXt`): + PNG info chunks. + existing_info (`dict`): + Additional PNG info. `existing_info == {pngsectionname: info, ...}` + no_prompt: + TODO I don't know its meaning. + p (`StableDiffusionProcessing`) + forced_filename (`str`): + If specified, `basename` and filename pattern will be ignored. + save_to_dirs (bool): + If true, the image will be saved into a subdirectory of `path`. + + Returns: (fullfn, txt_fullfn) + fullfn (`str`): + The full path of the saved imaged. + txt_fullfn (`str` or None): + If a text file is saved for this image, this will be its full path. Otherwise None. + ''' if short_filename or prompt is None or seed is None: file_decoration = "" elif opts.save_to_dirs: @@ -424,7 +456,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg") if opts.save_txt and info is not None: - with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file: + txt_fullfn = f"{fullfn_without_extension}.txt" + with open(txt_fullfn, "w", encoding="utf8") as file: file.write(info + "\n") + else: + txt_fullfn = None - return fullfn + return fullfn, txt_fullfn diff --git a/modules/ui.py b/modules/ui.py index fe7f10a7..debd8873 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -138,11 +138,14 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) - fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) -- cgit v1.2.1 From 122d42687b97ec4df4c2a8c335d2de385cd1f1a1 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 22:37:35 -0400 Subject: Fix VRAM Issue by only loading in hypernetwork when selected in settings --- modules/hypernetwork.py | 23 +++++++++++++++-------- modules/sd_hijack_optimizations.py | 6 +++--- modules/shared.py | 7 ++----- webui.py | 3 +++ 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 7f062242..19f1c227 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -40,18 +40,25 @@ class Hypernetwork: self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) -def load_hypernetworks(path): +def list_hypernetworks(path): res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + print(f"Loading hypernetwork {filename}") + path = shared.hypernetworks.get(filename, None) + if (path is not None): try: - hn = Hypernetwork(filename) - res[hn.name] = hn + shared.loaded_hypernetwork = Hypernetwork(path) except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - - return res + else: + shared.loaded_hypernetwork = None def attention_CrossAttention_forward(self, x, context=None, mask=None): @@ -60,7 +67,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c4396bb9..634fb4b2 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -28,7 +28,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: @@ -68,7 +68,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: @@ -132,7 +132,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: k_in = self.to_k(hypernetwork_layers[0](context)) diff --git a/modules/shared.py b/modules/shared.py index b2c76a32..9dce6cb7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -79,11 +79,8 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) - - -def selected_hypernetwork(): - return hypernetworks.get(opts.sd_hypernetwork, None) +hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +loaded_hypernetwork = None class State: diff --git a/webui.py b/webui.py index 18de8e16..270584f7 100644 --- a/webui.py +++ b/webui.py @@ -82,6 +82,9 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) +loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) + def webui(): # make the program just exit at ctrl+c without waiting for anything -- cgit v1.2.1 From 03e570886f430f39020e504aba057a95f2e62484 Mon Sep 17 00:00:00 2001 From: frostydad <64224601+Cyberes@users.noreply.github.com> Date: Sat, 8 Oct 2022 18:13:13 -0600 Subject: Fix incorrect sampler name in output --- modules/processing.py | 9 ++++++++- scripts/xy_grid.py | 16 +++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 4fea6d56..6b8664a0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,3 +1,4 @@ + import json import math import os @@ -46,6 +47,12 @@ def apply_color_correction(correction, image): return image +def get_correct_sampler(p): + if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img): + return sd_samplers.samplers + elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): + return sd_samplers.samplers_for_img2img + class StableDiffusionProcessing: def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): self.sd_model = sd_model @@ -272,7 +279,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = { "Steps": p.steps, - "Sampler": sd_samplers.samplers[p.sampler_index].name, + "Sampler": get_correct_sampler(p)[p.sampler_index].name, "CFG scale": p.cfg_scale, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index c0c364df..26ae2199 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,7 @@ import modules.scripts as scripts import gradio as gr from modules import images -from modules.processing import process_images, Processed +from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.sd_samplers @@ -56,15 +56,17 @@ def apply_order(p, x, xs): p.prompt = prompt_tmp + p.prompt -samplers_dict = {} -for i, sampler in enumerate(modules.sd_samplers.samplers): - samplers_dict[sampler.name.lower()] = i - for alias in sampler.aliases: - samplers_dict[alias.lower()] = i +def build_samplers_dict(p): + samplers_dict = {} + for i, sampler in enumerate(get_correct_sampler(p)): + samplers_dict[sampler.name.lower()] = i + for alias in sampler.aliases: + samplers_dict[alias.lower()] = i + return samplers_dict def apply_sampler(p, x, xs): - sampler_index = samplers_dict.get(x.lower(), None) + sampler_index = build_samplers_dict(p).get(x.lower(), None) if sampler_index is None: raise RuntimeError(f"Unknown sampler: {x}") -- cgit v1.2.1 From ef93acdc731b7a2b3c13651b6de1bce58af989d4 Mon Sep 17 00:00:00 2001 From: frostydad <64224601+Cyberes@users.noreply.github.com> Date: Sat, 8 Oct 2022 18:15:35 -0600 Subject: remove line break --- modules/processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 6b8664a0..7fa1144e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,3 @@ - import json import math import os -- cgit v1.2.1 From 1ffeb42d38d9276dc28918189d32f60d593a162c Mon Sep 17 00:00:00 2001 From: Nicolas Noullet Date: Sun, 9 Oct 2022 00:18:45 +0200 Subject: Fix typo --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 9dce6cb7..dffa0094 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,7 +238,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), + "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), -- cgit v1.2.1 From e2930f9821c197da94e208b5ae73711002844efc Mon Sep 17 00:00:00 2001 From: Tony Beeman Date: Fri, 7 Oct 2022 17:46:39 -0700 Subject: Fix for Prompts_from_file showing extra textbox. --- modules/scripts.py | 30 ++++++++++++++++++++++++++---- scripts/prompts_from_file.py | 4 ++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 45230f9a..d8f87927 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,4 +1,5 @@ import os +from pydoc import visiblename import sys import traceback @@ -31,6 +32,15 @@ class Script: def show(self, is_img2img): return True + + # Called when the ui for this script has been shown. + # Useful for hiding some controls, since the scripts module sets visibility to + # everything to true. The parameters will be the parameters returned by the ui method + # The return value should be gradio updates, similar to what you would return + # from a Gradio event handler. + def on_show(self, *args): + return [ui.gr_show(True)] * len(args) + # This is where the additional processing is implemented. The parameters include # self, the model object "p" (a StableDiffusionProcessing class, see # processing.py), and the parameters returned by the ui method. @@ -125,20 +135,32 @@ class ScriptRunner: inputs += controls script.args_to = len(inputs) - def select_script(script_index): + def select_script(*args): + script_index = args[0] + on_show_updates = [] if 0 < script_index <= len(self.scripts): script = self.scripts[script_index-1] args_from = script.args_from args_to = script.args_to + script_args = args[args_from:args_to] + on_show_updates = wrap_call(script.on_show, script.filename, "on_show", *script_args) else: args_from = 0 args_to = 0 - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + ret = [ ui.gr_show(True)] # always show the dropdown + for i in range(1, len(inputs)): + if (args_from <= i < args_to): + ret.append( on_show_updates[i - args_from] ) + else: + ret.append(ui.gr_show(False)) + return ret + + # return [ui.gr_show(True if (i == 0) else on_show_updates[i - args_from] if args_from <= i < args_to else False) for i in range(len(inputs))] dropdown.change( fn=select_script, - inputs=[dropdown], + inputs=inputs, outputs=inputs ) @@ -198,4 +220,4 @@ def reload_scripts(basedir): load_scripts(basedir) scripts_txt2img = ScriptRunner() - scripts_img2img = ScriptRunner() + scripts_img2img = ScriptRunner() \ No newline at end of file diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 513d9a1c..110889a6 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -10,6 +10,7 @@ from modules.processing import Processed, process_images from PIL import Image from modules.shared import opts, cmd_opts, state +g_txt_mode = False class Script(scripts.Script): def title(self): @@ -29,6 +30,9 @@ class Script(scripts.Script): checkbox_txt.change(fn=lambda x: [gr.File.update(visible = not x), gr.TextArea.update(visible = x)], inputs=[checkbox_txt], outputs=[file, prompt_txt]) return [checkbox_txt, file, prompt_txt] + def on_show(self, checkbox_txt, file, prompt_txt): + return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ] + def run(self, p, checkbox_txt, data: bytes, prompt_txt: str): if (checkbox_txt): lines = [x.strip() for x in prompt_txt.splitlines()] -- cgit v1.2.1 From 86cb16886f8f48169cee4658ad0c5e5443beed2a Mon Sep 17 00:00:00 2001 From: Tony Beeman Date: Fri, 7 Oct 2022 23:51:50 -0700 Subject: Pull Request Code Review Fixes --- modules/scripts.py | 1 - scripts/prompts_from_file.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index d8f87927..8dfd4de9 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,5 +1,4 @@ import os -from pydoc import visiblename import sys import traceback diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 110889a6..b24f1a80 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -10,8 +10,6 @@ from modules.processing import Processed, process_images from PIL import Image from modules.shared import opts, cmd_opts, state -g_txt_mode = False - class Script(scripts.Script): def title(self): return "Prompts from file or textbox" -- cgit v1.2.1 From cbf6dad02d04d98e5a2d5e870777ab99b5796b2d Mon Sep 17 00:00:00 2001 From: Tony Beeman Date: Sat, 8 Oct 2022 10:40:30 -0700 Subject: Handle case where on_show returns the wrong number of arguments --- modules/scripts.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 8dfd4de9..7d89979d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -143,6 +143,8 @@ class ScriptRunner: args_to = script.args_to script_args = args[args_from:args_to] on_show_updates = wrap_call(script.on_show, script.filename, "on_show", *script_args) + if (len(on_show_updates) != (args_to - args_from)): + print("Error in custom script (" + script.filename + "): on_show() method should return the same number of arguments as ui().", file=sys.stderr) else: args_from = 0 args_to = 0 @@ -150,13 +152,14 @@ class ScriptRunner: ret = [ ui.gr_show(True)] # always show the dropdown for i in range(1, len(inputs)): if (args_from <= i < args_to): - ret.append( on_show_updates[i - args_from] ) + if (i - args_from) < len(on_show_updates): + ret.append( on_show_updates[i - args_from] ) + else: + ret.append(ui.gr_show(True)) else: ret.append(ui.gr_show(False)) return ret - # return [ui.gr_show(True if (i == 0) else on_show_updates[i - args_from] if args_from <= i < args_to else False) for i in range(len(inputs))] - dropdown.change( fn=select_script, inputs=inputs, -- cgit v1.2.1 From ab4fe4f44c3d2675a351269fe2ff1ddeac557aa6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 11:59:41 +0300 Subject: hide filenames for save button by default --- modules/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 8071b1cb..e1ab2665 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -162,7 +162,7 @@ def save_files(js_data, images, do_make_zip, index): zip_file.writestr(filenames[i], f.read()) fullfns.insert(0, zip_filepath) - return fullfns, '', '', plaintext_to_html(f"Saved: {filenames[0]}") + return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") def wrap_gradio_call(func, extra_outputs=None): @@ -553,7 +553,7 @@ def create_ui(wrap_gradio_gpu_call): do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False) + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) with gr.Group(): html_info = gr.HTML() @@ -741,7 +741,7 @@ def create_ui(wrap_gradio_gpu_call): do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False) + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) with gr.Group(): html_info = gr.HTML() -- cgit v1.2.1 From 0241d811d23427b99f6b1eda1540bdf8d87963d5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 12:04:44 +0300 Subject: Revert "Fix for Prompts_from_file showing extra textbox." This reverts commit e2930f9821c197da94e208b5ae73711002844efc. --- modules/scripts.py | 32 ++++---------------------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 7d89979d..45230f9a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -31,15 +31,6 @@ class Script: def show(self, is_img2img): return True - - # Called when the ui for this script has been shown. - # Useful for hiding some controls, since the scripts module sets visibility to - # everything to true. The parameters will be the parameters returned by the ui method - # The return value should be gradio updates, similar to what you would return - # from a Gradio event handler. - def on_show(self, *args): - return [ui.gr_show(True)] * len(args) - # This is where the additional processing is implemented. The parameters include # self, the model object "p" (a StableDiffusionProcessing class, see # processing.py), and the parameters returned by the ui method. @@ -134,35 +125,20 @@ class ScriptRunner: inputs += controls script.args_to = len(inputs) - def select_script(*args): - script_index = args[0] - on_show_updates = [] + def select_script(script_index): if 0 < script_index <= len(self.scripts): script = self.scripts[script_index-1] args_from = script.args_from args_to = script.args_to - script_args = args[args_from:args_to] - on_show_updates = wrap_call(script.on_show, script.filename, "on_show", *script_args) - if (len(on_show_updates) != (args_to - args_from)): - print("Error in custom script (" + script.filename + "): on_show() method should return the same number of arguments as ui().", file=sys.stderr) else: args_from = 0 args_to = 0 - ret = [ ui.gr_show(True)] # always show the dropdown - for i in range(1, len(inputs)): - if (args_from <= i < args_to): - if (i - args_from) < len(on_show_updates): - ret.append( on_show_updates[i - args_from] ) - else: - ret.append(ui.gr_show(True)) - else: - ret.append(ui.gr_show(False)) - return ret + return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] dropdown.change( fn=select_script, - inputs=inputs, + inputs=[dropdown], outputs=inputs ) @@ -222,4 +198,4 @@ def reload_scripts(basedir): load_scripts(basedir) scripts_txt2img = ScriptRunner() - scripts_img2img = ScriptRunner() \ No newline at end of file + scripts_img2img = ScriptRunner() -- cgit v1.2.1 From 6f6798ddabe10d320fe8ea05edf0fdcef0c51a8e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 12:33:37 +0300 Subject: prevent a possible code execution error (thanks, RyotaK) --- modules/ui.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index e1ab2665..dad509f3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1153,6 +1153,15 @@ def create_ui(wrap_gradio_gpu_call): component_dict = {} def open_folder(f): + if not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + if not shared.cmd_opts.hide_ui_dir_config: path = os.path.normpath(f) if platform.system() == "Windows": -- cgit v1.2.1 From d74c38108f95e44d83a1706ee5ab218124972868 Mon Sep 17 00:00:00 2001 From: Jesse Williams <33797815+xram64@users.noreply.github.com> Date: Sat, 8 Oct 2022 01:30:49 -0400 Subject: Confirm that options are valid before starting When using the 'Sampler' or 'Checkpoint' options, if one of the entered names has a typo, an error will only be thrown once the `draw_xy_grid` loop reaches that name. This can waste a lot of time for large grids with a typo near the end of a list, since the script needs to start over and re-generate any earlier images to finish making the grid. Also fixing typo in variable name in `draw_xy_grid`. --- scripts/xy_grid.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 26ae2199..07040886 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -145,7 +145,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - first_pocessed = None + first_processed = None state.job_count = len(xs) * len(ys) * p.n_iter @@ -154,8 +154,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" processed = cell(x, y) - if first_pocessed is None: - first_pocessed = processed + if first_processed is None: + first_processed = processed try: res.append(processed.images[0]) @@ -166,9 +166,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): if draw_legend: grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) - first_pocessed.images = [grid] + first_processed.images = [grid] - return first_pocessed + return first_processed re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") @@ -216,7 +216,6 @@ class Script(scripts.Script): m = re_range.fullmatch(val) mc = re_range_count.fullmatch(val) if m is not None: - start = int(m.group(1)) end = int(m.group(2))+1 step = int(m.group(3)) if m.group(3) is not None else 1 @@ -258,6 +257,16 @@ class Script(scripts.Script): valslist = list(permutations(valslist)) valslist = [opt.type(x) for x in valslist] + + # Confirm options are valid before starting + if opt.label == "Sampler": + for sampler_val in valslist: + if sampler_val.lower() not in samplers_dict.keys(): + raise RuntimeError(f"Unknown sampler: {sampler_val}") + elif opt.label == "Checkpoint name": + for ckpt_val in valslist: + if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None: + raise RuntimeError(f"Checkpoint for {ckpt_val} not found") return valslist -- cgit v1.2.1 From a65a45272e8f26ee3bc52a5300b396266508a9a5 Mon Sep 17 00:00:00 2001 From: Brendan Byrd Date: Thu, 6 Oct 2022 19:31:36 -0400 Subject: Don't change the seed initially if "Keep -1 for seeds" is checked Fixes #1049 --- scripts/xy_grid.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 07040886..a8f53bef 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -198,7 +198,9 @@ class Script(scripts.Script): return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds): - modules.processing.fix_seed(p) + if not no_fixed_seeds: + modules.processing.fix_seed(p) + p.batch_size = 1 initial_hn = opts.sd_hypernetwork -- cgit v1.2.1 From 0609ce06c0778536cb368ac3867292f87c6d9fc7 Mon Sep 17 00:00:00 2001 From: Milly Date: Fri, 7 Oct 2022 03:36:08 +0900 Subject: Removed duplicate definition model_path --- modules/bsrgan_model.py | 2 -- modules/esrgan_model.py | 2 -- modules/ldsr_model.py | 2 -- modules/realesrgan_model.py | 2 -- modules/scunet_model.py | 2 -- modules/swinir_model.py | 2 -- modules/upscaler.py | 7 ++++--- 7 files changed, 4 insertions(+), 15 deletions(-) diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py index 3bd80791..737e1a76 100644 --- a/modules/bsrgan_model.py +++ b/modules/bsrgan_model.py @@ -10,13 +10,11 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader from modules.bsrgan_model_arch import RRDBNet -from modules.paths import models_path class UpscalerBSRGAN(modules.upscaler.Upscaler): def __init__(self, dirname): self.name = "BSRGAN" - self.model_path = os.path.join(models_path, self.name) self.model_name = "BSRGAN 4x" self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" self.user_path = dirname diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 28548124..3970e6e4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -7,7 +7,6 @@ from basicsr.utils.download_util import load_file_from_url import modules.esrgam_model_arch as arch from modules import shared, modelloader, images, devices -from modules.paths import models_path from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts @@ -76,7 +75,6 @@ class UpscalerESRGAN(Upscaler): self.model_name = "ESRGAN_4x" self.scalers = [] self.user_path = dirname - self.model_path = os.path.join(models_path, self.name) super().__init__() model_paths = self.find_models(ext_filter=[".pt", ".pth"]) scalers = [] diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py index 1c1070fc..8c4db44a 100644 --- a/modules/ldsr_model.py +++ b/modules/ldsr_model.py @@ -7,13 +7,11 @@ from basicsr.utils.download_util import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from modules.ldsr_model_arch import LDSR from modules import shared -from modules.paths import models_path class UpscalerLDSR(Upscaler): def __init__(self, user_path): self.name = "LDSR" - self.model_path = os.path.join(models_path, self.name) self.user_path = user_path self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index dc0123e0..3ac0b97a 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -8,14 +8,12 @@ from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer from modules.upscaler import Upscaler, UpscalerData -from modules.paths import models_path from modules.shared import cmd_opts, opts class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" - self.model_path = os.path.join(models_path, self.name) self.user_path = path super().__init__() try: diff --git a/modules/scunet_model.py b/modules/scunet_model.py index fb64b740..36a996bf 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -9,14 +9,12 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader -from modules.paths import models_path from modules.scunet_model_arch import SCUNet as net class UpscalerScuNET(modules.upscaler.Upscaler): def __init__(self, dirname): self.name = "ScuNET" - self.model_path = os.path.join(models_path, self.name) self.model_name = "ScuNET GAN" self.model_name2 = "ScuNET PSNR" self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth" diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 9bd454c6..fbd11f84 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -8,7 +8,6 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader -from modules.paths import models_path from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net from modules.upscaler import Upscaler, UpscalerData @@ -25,7 +24,6 @@ class UpscalerSwinIR(Upscaler): "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ "-L_x4_GAN.pth " self.model_name = "SwinIR 4x" - self.model_path = os.path.join(models_path, self.name) self.user_path = dirname super().__init__() scalers = [] diff --git a/modules/upscaler.py b/modules/upscaler.py index d9d7c5e2..34672be7 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -36,10 +36,11 @@ class Upscaler: self.half = not modules.shared.cmd_opts.no_half self.pre_pad = 0 self.mod_scale = None - if self.name is not None and create_dirs: + + if self.model_path is not None and self.name: self.model_path = os.path.join(models_path, self.name) - if not os.path.exists(self.model_path): - os.makedirs(self.model_path) + if self.model_path and create_dirs: + os.makedirs(self.model_path, exist_ok=True) try: import cv2 -- cgit v1.2.1 From bd833409ac7b8337040d521f6b65ced51e1b2ea8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 13:10:15 +0300 Subject: additional changes for saving pnginfo for #1803 --- modules/extras.py | 4 ++++ modules/processing.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index ef6e6de7..39dd3806 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -98,6 +98,10 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=image_name if opts.use_original_name_batch else None) + if opts.enable_pnginfo: + image.info = existing_pnginfo + image.info["extras"] = info + outputs.append(image) devices.torch_gc() diff --git a/modules/processing.py b/modules/processing.py index 7fa1144e..2c991317 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -451,7 +451,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: text = infotext(n, i) infotexts.append(text) - image.info["parameters"] = text + if opts.enable_pnginfo: + image.info["parameters"] = text output_images.append(image) del x_samples_ddim @@ -470,7 +471,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if opts.return_grid: text = infotext() infotexts.insert(0, text) - grid.info["parameters"] = text + if opts.enable_pnginfo: + grid.info["parameters"] = text output_images.insert(0, grid) index_of_first_image = 1 -- cgit v1.2.1 From f4578b343ded3b8ccd1879ea0c0b3cdadfcc3a5f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 13:23:30 +0300 Subject: fix model switching not working properly if there is a different yaml config --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2101b18d..d0c74dd8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -196,7 +196,8 @@ def reload_model_weights(sd_model, info=None): return if sd_model.sd_checkpoint_info.config != checkpoint_info.config: - return load_model() + shared.sd_model = load_model() + return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() -- cgit v1.2.1 From 77a719648db515f10136e8b8483d5b16bda2eaeb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 13:48:04 +0300 Subject: fix logic error in #1832 --- modules/upscaler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler.py b/modules/upscaler.py index 34672be7..6ab2fb40 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -37,7 +37,7 @@ class Upscaler: self.pre_pad = 0 self.mod_scale = None - if self.model_path is not None and self.name: + if self.model_path is None and self.name: self.model_path = os.path.join(models_path, self.name) if self.model_path and create_dirs: os.makedirs(self.model_path, exist_ok=True) -- cgit v1.2.1 From 542a3d3a4a00c1383fbdaf938ceefef87cf834bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:33:22 +0300 Subject: fix btoken hypernetworks in XY plot --- modules/hypernetwork.py | 7 +++++-- scripts/xy_grid.py | 9 +++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 19f1c227..498bc9d8 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -49,15 +49,18 @@ def list_hypernetworks(path): def load_hypernetwork(filename): - print(f"Loading hypernetwork {filename}") path = shared.hypernetworks.get(filename, None) - if (path is not None): + if path is not None: + print(f"Loading hypernetwork {filename}") try: shared.loaded_hypernetwork = Hypernetwork(path) except Exception: print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + shared.loaded_hypernetwork = None diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index a8f53bef..fe949067 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images +from modules import images, hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs): def apply_hypernetwork(p, x, xs): - hn = shared.hypernetworks.get(x, None) - opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + hypernetwork.load_hypernetwork(x) def format_value_add_label(p, opt, x): @@ -203,8 +202,6 @@ class Script(scripts.Script): p.batch_size = 1 - initial_hn = opts.sd_hypernetwork - def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -321,6 +318,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) - opts.data["sd_hypernetwork"] = initial_hn + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) return processed -- cgit v1.2.1 From d6d10a37bfd21568e74efb46137f906da96d5fdb Mon Sep 17 00:00:00 2001 From: William Moorehouse Date: Sun, 9 Oct 2022 04:58:40 -0400 Subject: Added extended model details to infotext --- modules/processing.py | 3 +++ modules/sd_models.py | 3 ++- modules/shared.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 2c991317..d1bcee4a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,6 +284,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), + "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_name else shared.sd_model.sd_model_name), + "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name), + "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), diff --git a/modules/sd_models.py b/modules/sd_models.py index d0c74dd8..3fa42329 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,7 @@ import sys from collections import namedtuple import torch from omegaconf import OmegaConf - +from pathlib import Path from ldm.util import instantiate_from_config @@ -158,6 +158,7 @@ def load_model_weights(model, checkpoint_info): vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) + model.sd_model_vae_name = Path(vae_file).stem model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index dffa0094..ca63f7d8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -242,6 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), + "add_extended_model_details_to_info": OptionInfo(False, "Add extended model details to generation information (model name, VAE, hypernetwork)"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.1 From 006791c13d70e582eee766b7d0499e9821a86bf9 Mon Sep 17 00:00:00 2001 From: William Moorehouse Date: Sun, 9 Oct 2022 05:09:18 -0400 Subject: Fix grabbing the model name for infotext --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index d1bcee4a..c035c990 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,7 +284,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_name else shared.sd_model.sd_model_name), + "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name), "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name), "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork), "Batch size": (None if p.batch_size < 2 else p.batch_size), -- cgit v1.2.1 From 594cbfd8fbe4078b43ceccf01509eeef3d6790c6 Mon Sep 17 00:00:00 2001 From: William Moorehouse Date: Sun, 9 Oct 2022 07:27:11 -0400 Subject: Sanitize infotext output (for now) --- modules/processing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c035c990..049f3769 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,9 +284,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name), - "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name), - "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork), + "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), + "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name.replace(',', '').replace(':', '')), + "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork.replace(',', '').replace(':', '')), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.1 From e6e8cabe0c9c335e0d72345602c069b198558b53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:57:48 +0300 Subject: change up #2056 to make it work how i want it to plus make xy plot write correct values to images --- modules/processing.py | 5 ++--- modules/sd_models.py | 2 -- modules/shared.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 049f3769..04aed989 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,9 +284,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name.replace(',', '').replace(':', '')), - "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork.replace(',', '').replace(':', '')), + "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), + "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), diff --git a/modules/sd_models.py b/modules/sd_models.py index 3fa42329..e63d3c29 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,6 @@ import sys from collections import namedtuple import torch from omegaconf import OmegaConf -from pathlib import Path from ldm.util import instantiate_from_config @@ -158,7 +157,6 @@ def load_model_weights(model, checkpoint_info): vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) - model.sd_model_vae_name = Path(vae_file).stem model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index ca63f7d8..6ecc2503 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -242,7 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_extended_model_details_to_info": OptionInfo(False, "Add extended model details to generation information (model name, VAE, hypernetwork)"), + "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.1 From 2c52f4da7ff80a3ec277105f4db6146c6379898a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 15:01:42 +0300 Subject: fix broken samplers in XY plot --- scripts/xy_grid.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index fe949067..c89ca1a9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -259,6 +259,7 @@ class Script(scripts.Script): # Confirm options are valid before starting if opt.label == "Sampler": + samplers_dict = build_samplers_dict(p) for sampler_val in valslist: if sampler_val.lower() not in samplers_dict.keys(): raise RuntimeError(f"Unknown sampler: {sampler_val}") -- cgit v1.2.1 From 9d1138e2940c4ddcd2685bcba12c7d407e9e0ec5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 15:08:10 +0300 Subject: fix typo in filename for ESRGAN arch --- modules/esrgam_model_arch.py | 80 -------------------------------------------- modules/esrgan_model.py | 2 +- modules/esrgan_model_arch.py | 80 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 81 deletions(-) delete mode 100644 modules/esrgam_model_arch.py create mode 100644 modules/esrgan_model_arch.py diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py deleted file mode 100644 index e413d36e..00000000 --- a/modules/esrgam_model_arch.py +++ /dev/null @@ -1,80 +0,0 @@ -# this file is taken from https://github.com/xinntao/ESRGAN - -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 3970e6e4..46ad0da3 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -5,7 +5,7 @@ import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url -import modules.esrgam_model_arch as arch +import modules.esrgan_model_arch as arch from modules import shared, modelloader, images, devices from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py new file mode 100644 index 00000000..e413d36e --- /dev/null +++ b/modules/esrgan_model_arch.py @@ -0,0 +1,80 @@ +# this file is taken from https://github.com/xinntao/ESRGAN + +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out -- cgit v1.2.1 From f8197976ef5f0523faffb2b237e9166fb2bedecd Mon Sep 17 00:00:00 2001 From: Greendayle Date: Sun, 9 Oct 2022 13:44:13 +0200 Subject: Shielded launch enviroment creation stuff from multiprocessing --- launch.py | 174 ++++++++++++++++++++++++++++++-------------------------------- 1 file changed, 85 insertions(+), 89 deletions(-) diff --git a/launch.py b/launch.py index b0a59b6a..d1a4fd6a 100644 --- a/launch.py +++ b/launch.py @@ -6,40 +6,11 @@ import importlib.util import shlex import platform -dir_repos = "repositories" -dir_tmp = "tmp" - -python = sys.executable -git = os.environ.get('GIT', "git") -torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") -requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") -commandline_args = os.environ.get('COMMANDLINE_ARGS', "") - -gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") -clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") - -stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") -taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") -k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878") -codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") -blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") - -args = shlex.split(commandline_args) - def extract_arg(args, name): return [x for x in args if x != name], name in args -args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') -xformers = '--xformers' in args -deepdanbooru = '--deepdanbooru' in args - - -def repo_dir(name): - return os.path.join(dir_repos, name) - - def run(command, desc=None, errdesc=None): if desc is not None: print(desc) @@ -59,23 +30,11 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st return result.stdout.decode(encoding="utf8", errors="ignore") -def run_python(code, desc=None, errdesc=None): - return run(f'"{python}" -c "{code}"', desc, errdesc) - - -def run_pip(args, desc=None): - return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") - - def check_run(command): result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) return result.returncode == 0 -def check_run_python(code): - return check_run(f'"{python}" -c "{code}"') - - def is_installed(package): try: spec = importlib.util.find_spec(package) @@ -85,80 +44,117 @@ def is_installed(package): return spec is not None -def git_clone(url, dir, name, commithash=None): - # TODO clone into temporary dir and move if successful +def prepare_enviroment(): + dir_repos = "repositories" - if os.path.exists(dir): - if commithash is None: - return + python = sys.executable + git = os.environ.get('GIT', "git") + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") + commandline_args = os.environ.get('COMMANDLINE_ARGS', "") - current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() - if current_hash == commithash: - return + gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") + clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") + + stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") + taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") + k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878") + codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") + blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") + + args = shlex.split(commandline_args) + + args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') + xformers = '--xformers' in args + deepdanbooru = '--deepdanbooru' in args + + def repo_dir(name): + return os.path.join(dir_repos, name) + + def run_python(code, desc=None, errdesc=None): + return run(f'"{python}" -c "{code}"', desc, errdesc) - run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") - run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") - return + def run_pip(args, desc=None): + return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") - run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") + def check_run_python(code): + return check_run(f'"{python}" -c "{code}"') - if commithash is not None: - run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") + def git_clone(url, dir, name, commithash=None): + # TODO clone into temporary dir and move if successful + if os.path.exists(dir): + if commithash is None: + return -try: - commit = run(f"{git} rev-parse HEAD").strip() -except Exception: - commit = "" + current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() + if current_hash == commithash: + return -print(f"Python {sys.version}") -print(f"Commit hash: {commit}") + run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") + run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") + return + + run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") + if commithash is not None: + run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") + + try: + commit = run(f"{git} rev-parse HEAD").strip() + except Exception: + commit = "" -if not is_installed("torch") or not is_installed("torchvision"): - run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") + print(f"Python {sys.version}") + print(f"Commit hash: {commit}") -if not skip_torch_cuda_test: - run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") + if not is_installed("torch") or not is_installed("torchvision"): + run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") -if not is_installed("gfpgan"): - run_pip(f"install {gfpgan_package}", "gfpgan") + if not skip_torch_cuda_test: + run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") -if not is_installed("clip"): - run_pip(f"install {clip_package}", "clip") + if not is_installed("gfpgan"): + run_pip(f"install {gfpgan_package}", "gfpgan") -if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"): - if platform.system() == "Windows": - run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") - elif platform.system() == "Linux": - run_pip("install xformers", "xformers") + if not is_installed("clip"): + run_pip(f"install {clip_package}", "clip") -if not is_installed("deepdanbooru") and deepdanbooru: - run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru") + if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"): + if platform.system() == "Windows": + run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") + elif platform.system() == "Linux": + run_pip("install xformers", "xformers") -os.makedirs(dir_repos, exist_ok=True) + if not is_installed("deepdanbooru") and deepdanbooru: + run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru") -git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) -git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) -git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) -git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) -git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) + os.makedirs(dir_repos, exist_ok=True) -if not is_installed("lpips"): - run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") + git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) + git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) + git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) + git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) -run_pip(f"install -r {requirements_file}", "requirements for Web UI") + if not is_installed("lpips"): + run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") -sys.argv += args + run_pip(f"install -r {requirements_file}", "requirements for Web UI") + + sys.argv += args + + if "--exit" in args: + print("Exiting because of --exit argument") + exit(0) -if "--exit" in args: - print("Exiting because of --exit argument") - exit(0) def start_webui(): print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}") import webui webui.webui() + if __name__ == "__main__": + prepare_enviroment() start_webui() -- cgit v1.2.1 From bba2ac8324ccd1a67c78e5f59babae8323ec7dc6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 15:22:51 +0300 Subject: reshuffle the code a bit in launcher to keep functions in one place for #2069 --- launch.py | 77 ++++++++++++++++++++++++++++++++++----------------------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/launch.py b/launch.py index d1a4fd6a..f42f557d 100644 --- a/launch.py +++ b/launch.py @@ -6,6 +6,10 @@ import importlib.util import shlex import platform +dir_repos = "repositories" +python = sys.executable +git = os.environ.get('GIT', "git") + def extract_arg(args, name): return [x for x in args if x != name], name in args @@ -44,11 +48,44 @@ def is_installed(package): return spec is not None -def prepare_enviroment(): - dir_repos = "repositories" +def repo_dir(name): + return os.path.join(dir_repos, name) + + +def run_python(code, desc=None, errdesc=None): + return run(f'"{python}" -c "{code}"', desc, errdesc) + + +def run_pip(args, desc=None): + return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") + + +def check_run_python(code): + return check_run(f'"{python}" -c "{code}"') + + +def git_clone(url, dir, name, commithash=None): + # TODO clone into temporary dir and move if successful + + if os.path.exists(dir): + if commithash is None: + return + + current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() + if current_hash == commithash: + return + + run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") + run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") + return - python = sys.executable - git = os.environ.get('GIT', "git") + run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") + + if commithash is not None: + run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") + + +def prepare_enviroment(): torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -68,38 +105,6 @@ def prepare_enviroment(): xformers = '--xformers' in args deepdanbooru = '--deepdanbooru' in args - def repo_dir(name): - return os.path.join(dir_repos, name) - - def run_python(code, desc=None, errdesc=None): - return run(f'"{python}" -c "{code}"', desc, errdesc) - - def run_pip(args, desc=None): - return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") - - def check_run_python(code): - return check_run(f'"{python}" -c "{code}"') - - def git_clone(url, dir, name, commithash=None): - # TODO clone into temporary dir and move if successful - - if os.path.exists(dir): - if commithash is None: - return - - current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() - if current_hash == commithash: - return - - run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") - run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") - return - - run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") - - if commithash is not None: - run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") - try: commit = run(f"{git} rev-parse HEAD").strip() except Exception: -- cgit v1.2.1 From 875ddfeecfaffad9eee24813301637cba310337d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 17:58:43 +0300 Subject: added guard for torch.load to prevent loading pickles with unknown content --- modules/paths.py | 1 + modules/safe.py | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 91 insertions(+) create mode 100644 modules/safe.py diff --git a/modules/paths.py b/modules/paths.py index 0519caa0..1e7a2fbc 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -1,6 +1,7 @@ import argparse import os import sys +import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) models_path = os.path.join(script_path, "models") diff --git a/modules/safe.py b/modules/safe.py new file mode 100644 index 00000000..2d2c1371 --- /dev/null +++ b/modules/safe.py @@ -0,0 +1,89 @@ +# this code is adapted from the script contributed by anon from /h/ + +import io +import pickle +import collections +import sys +import traceback + +import torch +import numpy +import _codecs +import zipfile + + +def encode(*args): + out = _codecs.encode(*args) + return out + + +class RestrictedUnpickler(pickle.Unpickler): + def persistent_load(self, saved_id): + assert saved_id[0] == 'storage' + return torch.storage._TypedStorage() + + def find_class(self, module, name): + if module == 'collections' and name == 'OrderedDict': + return getattr(collections, name) + if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: + return getattr(torch._utils, name) + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage']: + return getattr(torch, name) + if module == 'torch.nn.modules.container' and name in ['ParameterDict']: + return getattr(torch.nn.modules.container, name) + if module == 'numpy.core.multiarray' and name == 'scalar': + return numpy.core.multiarray.scalar + if module == 'numpy' and name == 'dtype': + return numpy.dtype + if module == '_codecs' and name == 'encode': + return encode + if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': + import pytorch_lightning.callbacks + return pytorch_lightning.callbacks.model_checkpoint + if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': + import pytorch_lightning.callbacks.model_checkpoint + return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint + if module == "__builtin__" and name == 'set': + return set + + # Forbid everything else. + raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") + + +def check_pt(filename): + try: + + # new pytorch format is a zip file + with zipfile.ZipFile(filename) as z: + with z.open('archive/data.pkl') as file: + unpickler = RestrictedUnpickler(file) + unpickler.load() + + except zipfile.BadZipfile: + + # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle + with open(filename, "rb") as file: + unpickler = RestrictedUnpickler(file) + for i in range(5): + unpickler.load() + + +def load(filename, *args, **kwargs): + from modules import shared + + try: + if not shared.cmd_opts.disable_safe_unpickle: + check_pt(filename) + + except Exception: + print(f"Error verifying pickled file from {filename}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) + print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr) + return None + + return unsafe_torch_load(filename, *args, **kwargs) + + +unsafe_torch_load = torch.load +torch.load = load diff --git a/modules/shared.py b/modules/shared.py index 6ecc2503..3d7f08e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -65,6 +65,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) cmd_opts = parser.parse_args() -- cgit v1.2.1 From d3cd46b0388918128af203fda37fa63461c46611 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 9 Oct 2022 16:19:33 +0100 Subject: Update lightbox to change displayed image as soon as generation is complete (#1933) * add updateOnBackgroundChange * typo fixes. * reindent to 4 spaces --- javascript/imageviewer.js | 174 ++++++++++++++++++++++++++-------------------- 1 file changed, 99 insertions(+), 75 deletions(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 6a00c0da..65a33dd7 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -1,72 +1,97 @@ // A full size 'lightbox' preview modal shown when left clicking on gallery previews - function closeModal() { - gradioApp().getElementById("lightboxModal").style.display = "none"; + gradioApp().getElementById("lightboxModal").style.display = "none"; } function showModal(event) { - const source = event.target || event.srcElement; - const modalImage = gradioApp().getElementById("modalImage") - const lb = gradioApp().getElementById("lightboxModal") - modalImage.src = source.src - if (modalImage.style.display === 'none') { - lb.style.setProperty('background-image', 'url(' + source.src + ')'); - } - lb.style.display = "block"; - lb.focus() - event.stopPropagation() + const source = event.target || event.srcElement; + const modalImage = gradioApp().getElementById("modalImage") + const lb = gradioApp().getElementById("lightboxModal") + modalImage.src = source.src + if (modalImage.style.display === 'none') { + lb.style.setProperty('background-image', 'url(' + source.src + ')'); + } + lb.style.display = "block"; + lb.focus() + event.stopPropagation() } function negmod(n, m) { - return ((n % m) + m) % m; + return ((n % m) + m) % m; } -function modalImageSwitch(offset){ - var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all") - var galleryButtons = [] - allgalleryButtons.forEach(function(elem){ - if(elem.parentElement.offsetParent){ - galleryButtons.push(elem); +function updateOnBackgroundChange() { + const modalImage = gradioApp().getElementById("modalImage") + if (modalImage && modalImage.offsetParent) { + let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2") + let currentButton = null + allcurrentButtons.forEach(function(elem) { + if (elem.parentElement.offsetParent) { + currentButton = elem; + } + }) + + if (modalImage.src != currentButton.children[0].src) { + modalImage.src = currentButton.children[0].src; + if (modalImage.style.display === 'none') { + modal.style.setProperty('background-image', `url(${modalImage.src})`) + } + } } - }) - - if(galleryButtons.length>1){ - var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2") - var currentButton = null - allcurrentButtons.forEach(function(elem){ - if(elem.parentElement.offsetParent){ - currentButton = elem; +} + +function modalImageSwitch(offset) { + var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all") + var galleryButtons = [] + allgalleryButtons.forEach(function(elem) { + if (elem.parentElement.offsetParent) { + galleryButtons.push(elem); } - }) - - var result = -1 - galleryButtons.forEach(function(v, i){ if(v==currentButton) { result = i } }) - - if(result != -1){ - nextButton = galleryButtons[negmod((result+offset),galleryButtons.length)] - nextButton.click() - const modalImage = gradioApp().getElementById("modalImage"); - const modal = gradioApp().getElementById("lightboxModal"); - modalImage.src = nextButton.children[0].src; - if (modalImage.style.display === 'none') { - modal.style.setProperty('background-image', `url(${modalImage.src})`) + }) + + if (galleryButtons.length > 1) { + var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2") + var currentButton = null + allcurrentButtons.forEach(function(elem) { + if (elem.parentElement.offsetParent) { + currentButton = elem; + } + }) + + var result = -1 + galleryButtons.forEach(function(v, i) { + if (v == currentButton) { + result = i + } + }) + + if (result != -1) { + nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)] + nextButton.click() + const modalImage = gradioApp().getElementById("modalImage"); + const modal = gradioApp().getElementById("lightboxModal"); + modalImage.src = nextButton.children[0].src; + if (modalImage.style.display === 'none') { + modal.style.setProperty('background-image', `url(${modalImage.src})`) + } + setTimeout(function() { + modal.focus() + }, 10) } - setTimeout( function(){modal.focus()},10) - } - } + } } -function modalNextImage(event){ - modalImageSwitch(1) - event.stopPropagation() +function modalNextImage(event) { + modalImageSwitch(1) + event.stopPropagation() } -function modalPrevImage(event){ - modalImageSwitch(-1) - event.stopPropagation() +function modalPrevImage(event) { + modalImageSwitch(-1) + event.stopPropagation() } -function modalKeyHandler(event){ +function modalKeyHandler(event) { switch (event.key) { case "ArrowLeft": modalPrevImage(event) @@ -80,24 +105,22 @@ function modalKeyHandler(event){ } } -function showGalleryImage(){ +function showGalleryImage() { setTimeout(function() { fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain') - - if(fullImg_preview != null){ + + if (fullImg_preview != null) { fullImg_preview.forEach(function function_name(e) { if (e.dataset.modded) return; e.dataset.modded = true; if(e && e.parentElement.tagName == 'DIV'){ - e.style.cursor='pointer' - e.addEventListener('click', function (evt) { if(!opts.js_modal_lightbox) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) showModal(evt) - },true); + }, true); } }); } @@ -105,21 +128,21 @@ function showGalleryImage(){ }, 100); } -function modalZoomSet(modalImage, enable){ - if( enable ){ +function modalZoomSet(modalImage, enable) { + if (enable) { modalImage.classList.add('modalImageFullscreen'); - } else{ + } else { modalImage.classList.remove('modalImageFullscreen'); } } -function modalZoomToggle(event){ +function modalZoomToggle(event) { modalImage = gradioApp().getElementById("modalImage"); modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen')) event.stopPropagation() } -function modalTileImageToggle(event){ +function modalTileImageToggle(event) { const modalImage = gradioApp().getElementById("modalImage"); const modal = gradioApp().getElementById("lightboxModal"); const isTiling = modalImage.style.display === 'none'; @@ -134,17 +157,18 @@ function modalTileImageToggle(event){ event.stopPropagation() } -function galleryImageHandler(e){ - if(e && e.parentElement.tagName == 'BUTTON'){ +function galleryImageHandler(e) { + if (e && e.parentElement.tagName == 'BUTTON') { e.onclick = showGalleryImage; } } -onUiUpdate(function(){ +onUiUpdate(function() { fullImg_preview = gradioApp().querySelectorAll('img.w-full') - if(fullImg_preview != null){ - fullImg_preview.forEach(galleryImageHandler); + if (fullImg_preview != null) { + fullImg_preview.forEach(galleryImageHandler); } + updateOnBackgroundChange(); }) document.addEventListener("DOMContentLoaded", function() { @@ -152,13 +176,13 @@ document.addEventListener("DOMContentLoaded", function() { const modal = document.createElement('div') modal.onclick = closeModal; modal.id = "lightboxModal"; - modal.tabIndex=0 + modal.tabIndex = 0 modal.addEventListener('keydown', modalKeyHandler, true) const modalControls = document.createElement('div') modalControls.className = 'modalControls gradio-container'; modal.append(modalControls); - + const modalZoom = document.createElement('span') modalZoom.className = 'modalZoom cursor'; modalZoom.innerHTML = '⤡' @@ -183,30 +207,30 @@ document.addEventListener("DOMContentLoaded", function() { const modalImage = document.createElement('img') modalImage.id = 'modalImage'; modalImage.onclick = closeModal; - modalImage.tabIndex=0 + modalImage.tabIndex = 0 modalImage.addEventListener('keydown', modalKeyHandler, true) modal.appendChild(modalImage) const modalPrev = document.createElement('a') modalPrev.className = 'modalPrev'; modalPrev.innerHTML = '❮' - modalPrev.tabIndex=0 - modalPrev.addEventListener('click',modalPrevImage,true); + modalPrev.tabIndex = 0 + modalPrev.addEventListener('click', modalPrevImage, true); modalPrev.addEventListener('keydown', modalKeyHandler, true) modal.appendChild(modalPrev) const modalNext = document.createElement('a') modalNext.className = 'modalNext'; modalNext.innerHTML = '❯' - modalNext.tabIndex=0 - modalNext.addEventListener('click',modalNextImage,true); + modalNext.tabIndex = 0 + modalNext.addEventListener('click', modalNextImage, true); modalNext.addEventListener('keydown', modalKeyHandler, true) modal.appendChild(modalNext) gradioApp().getRootNode().appendChild(modal) - + document.body.appendChild(modalFragment); - + }); -- cgit v1.2.1 From 9ecea0a8d6bdc434755e11128487fd62f1ff130f Mon Sep 17 00:00:00 2001 From: Artem Zagidulin Date: Sun, 9 Oct 2022 16:14:56 +0300 Subject: fix missing png info when Extras Batch Process --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/extras.py b/modules/extras.py index 39dd3806..41e8612c 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -29,7 +29,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v if extras_mode == 1: #convert file to pillow image for img in image_folder: - image = Image.fromarray(np.array(Image.open(img))) + image = Image.open(img) imageArr.append(image) imageNameArr.append(os.path.splitext(img.orig_name)[0]) else: -- cgit v1.2.1 From a2d70f25bf51264d8d68f4f36937b390f79334a7 Mon Sep 17 00:00:00 2001 From: supersteve3d <39339941+supersteve3d@users.noreply.github.com> Date: Sun, 9 Oct 2022 23:40:18 +0800 Subject: Add files via upload Updated txt2img screenshot (UI as of Oct 9th) for github webui / README.md --- txt2img_Screenshot.png | Bin 539132 -> 337094 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/txt2img_Screenshot.png b/txt2img_Screenshot.png index fedd538e..6e2759a4 100644 Binary files a/txt2img_Screenshot.png and b/txt2img_Screenshot.png differ -- cgit v1.2.1 From 45bf9a6264b3507473e02cc3f9aa36559f24aca2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 18:58:55 +0300 Subject: added clip skip to XY plot --- scripts/xy_grid.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index c89ca1a9..7b0d9083 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -83,6 +83,10 @@ def apply_hypernetwork(p, x, xs): hypernetwork.load_hypernetwork(x) +def apply_clip_skip(p, x, xs): + opts.data["CLIP_ignore_last_layers"] = x + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -134,6 +138,7 @@ axis_options = [ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label), + AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -201,6 +206,7 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 + CLIP_ignore_last_layers = opts.CLIP_ignore_last_layers def process_axis(opt, vals): if opt.label == 'Nothing': @@ -321,4 +327,6 @@ class Script(scripts.Script): hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + opts.data["CLIP_ignore_last_layers"] = CLIP_ignore_last_layers + return processed -- cgit v1.2.1 From 6c383d2e82045fc4475d665f83bdeeac8fd844d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 22:24:07 +0300 Subject: show model selection setting on top of page --- modules/shared.py | 5 +++-- modules/ui.py | 54 +++++++++++++++++++++++++++++++++++++++++++++--------- style.css | 9 +++++++++ 3 files changed, 57 insertions(+), 11 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 3d7f08e1..270fa402 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -131,13 +131,14 @@ def realesrgan_models_names(): class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange self.section = None + self.show_on_main_page = show_on_main_page def options_section(section_identifier, options_dict): @@ -214,7 +215,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), diff --git a/modules/ui.py b/modules/ui.py index dad509f3..2231a8ed 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1175,10 +1175,13 @@ Requested path was: {f} changed = 0 for key, value, comp in zip(opts.data_labels.keys(), args, components): - if not opts.same_type(value, opts.data_labels[key].default): - return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): + return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson() for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + comp_args = opts.data_labels[key].component_args if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: continue @@ -1196,6 +1199,21 @@ Requested path was: {f} return f'{changed} settings changed.', opts.dumpjson() + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + oldval = opts.data.get(key, None) + opts.data[key] = value + + if oldval != value: + if opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + with gr.Blocks(analytics_enabled=False) as settings_interface: settings_submit = gr.Button(value="Apply settings", variant='primary') result = gr.HTML() @@ -1203,6 +1221,8 @@ Requested path was: {f} settings_cols = 3 items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + quicksettings_list = [] + cols_displayed = 0 items_displayed = 0 previous_section = None @@ -1225,10 +1245,14 @@ Requested path was: {f} gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - items_displayed += 1 + if item.show_on_main_page: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + items_displayed += 1 request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications.click( @@ -1242,7 +1266,6 @@ Requested path was: {f} reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1289,7 +1312,11 @@ Requested path was: {f} css += css_hide_progressbar with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - + with gr.Row(elem_id="quicksettings"): + for i, k, item in quicksettings_list: + component = create_setting_component(k) + component_dict[k] = component + settings_interface.gradio_ref = demo with gr.Tabs() as tabs: @@ -1306,7 +1333,16 @@ Requested path was: {f} inputs=components, outputs=[result, text_settings], ) - + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + def modelmerger(*args): try: results = modules.extras.run_modelmerger(*args) diff --git a/style.css b/style.css index 101d2052..28160bdf 100644 --- a/style.css +++ b/style.css @@ -453,3 +453,12 @@ input[type="range"]{ .context-menu-items a:hover{ background: #a55000; } + +#quicksettings > div{ + border: none; +} + +#quicksettings > div > div{ + max-width: 32em; + padding: 0; +} -- cgit v1.2.1 From e59c66c0088422b27f64b401ef42c242f836725a Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 16:32:05 -0400 Subject: Optimized code for Ignoring last CLIP layers --- modules/sd_hijack.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f12a9696..4a2d2153 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -282,14 +282,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - tmp = -opts.CLIP_ignore_last_layers - if (opts.CLIP_ignore_last_layers == 0): - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) - z = outputs.last_hidden_state - else: - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - z = outputs.hidden_states[tmp] - z = self.wrapped.transformer.text_model.final_layer_norm(z) + tmp = -opts.CLIP_stop_at_last_layers + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] -- cgit v1.2.1 From a14f7bf113a2af9e06a1c4d06c2efa244f9c5730 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 16:33:06 -0400 Subject: Corrected CLIP Layer Ignore description and updated its range to the max possible --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 270fa402..1995a99a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -225,7 +225,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - 'CLIP_ignore_last_layers': OptionInfo(0, "Ignore last layers of CLIP model", gr.Slider, {"minimum": 0, "maximum": 5, "step": 1}), + 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.1 From ec2bd9be75865c9f3a8c898163ab381688c03b6e Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 17:28:42 -0400 Subject: Fix issues with CLIP ignore option name change --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 04aed989..92a105a2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -129,7 +129,7 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp - self.clip_skip = opts.CLIP_ignore_last_layers + self.clip_skip = opts.CLIP_stop_at_last_layers self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -274,7 +274,7 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size - clip_skip = getattr(p, 'clip_skip', opts.CLIP_ignore_last_layers) + clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) generation_params = { "Steps": p.steps, -- cgit v1.2.1 From ad3ae441081155dcd4fde805279e5082ca264695 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 04:32:40 -0400 Subject: Updated code for legibility --- modules/sd_hijack.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4a2d2153..7793d25b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -284,8 +284,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): tmp = -opts.CLIP_stop_at_last_layers outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - z = outputs.hidden_states[tmp] - z = self.wrapped.transformer.text_model.final_layer_norm(z) + if tmp < -1: + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) + else: + z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] -- cgit v1.2.1 From 1824e9ee3ab4f94aee8908a62ea2569a01aeb3d7 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 14:15:43 -0400 Subject: Removed unnecessary tmp variable --- modules/sd_hijack.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7793d25b..437acce4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -282,10 +282,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - tmp = -opts.CLIP_stop_at_last_layers - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - if tmp < -1: - z = outputs.hidden_states[tmp] + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] z = self.wrapped.transformer.text_model.final_layer_norm(z) else: z = outputs.last_hidden_state -- cgit v1.2.1 From 84ddd44113b36062e8ba6cb2e5db0fce4f48efb8 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 14:57:17 -0400 Subject: Clip skip variable name change breaks x/y plot script. This fixes that --- scripts/xy_grid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 7b0d9083..771eb8e4 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -84,7 +84,7 @@ def apply_hypernetwork(p, x, xs): def apply_clip_skip(p, x, xs): - opts.data["CLIP_ignore_last_layers"] = x + opts.data["CLIP_stop_at_last_layers"] = x def format_value_add_label(p, opt, x): @@ -206,7 +206,7 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 - CLIP_ignore_last_layers = opts.CLIP_ignore_last_layers + CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers def process_axis(opt, vals): if opt.label == 'Nothing': @@ -327,6 +327,6 @@ class Script(scripts.Script): hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - opts.data["CLIP_ignore_last_layers"] = CLIP_ignore_last_layers + opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers return processed -- cgit v1.2.1 From 8d340cfb884e1dbff5b6f477f4ecf7d104279115 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 22:30:59 +0300 Subject: do not add clip skip to parameters if it's 1 or 0 --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 92a105a2..94d2dd62 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -293,7 +293,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), - "Clip skip": None if clip_skip==0 else clip_skip, + "Clip skip": None if clip_skip <= 1 else clip_skip, } generation_params.update(p.extra_generation_params) -- cgit v1.2.1 From a65476718f08a35f527b973ef731e6f488bace5e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 23:38:49 +0300 Subject: add DoubleStorage to list of allowed classes for pickle --- modules/safe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/safe.py b/modules/safe.py index 2d2c1371..4d06f2a5 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -27,7 +27,7 @@ class RestrictedUnpickler(pickle.Unpickler): return getattr(collections, name) if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: return getattr(torch._utils, name) - if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage']: + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']: return getattr(torch, name) if module == 'torch.nn.modules.container' and name in ['ParameterDict']: return getattr(torch.nn.modules.container, name) -- cgit v1.2.1 From 45fbd1c5fec887988ab555aac75a999d4f3aff40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 00:42:18 +0300 Subject: remove background for quicksettings row (for dark theme) --- style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/style.css b/style.css index 28160bdf..c0c3f2bb 100644 --- a/style.css +++ b/style.css @@ -456,6 +456,7 @@ input[type="range"]{ #quicksettings > div{ border: none; + background: none; } #quicksettings > div > div{ -- cgit v1.2.1 From 8acc901ba3a252dc6ab4fabcb41644cf64d1774c Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 00:38:55 -0400 Subject: Newer versions of PyTorch use TypedStorage instead Pytorch 1.13 and later will rename _TypedStorage to TypedStorage, so check for TypedStorage and use _TypedStorage if it is not available. Currently this is needed so that nightly builds of PyTorch work correctly. --- modules/safe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/safe.py b/modules/safe.py index 4d06f2a5..05917463 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -12,6 +12,10 @@ import _codecs import zipfile +# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage + + def encode(*args): out = _codecs.encode(*args) return out @@ -20,7 +24,7 @@ def encode(*args): class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id): assert saved_id[0] == 'storage' - return torch.storage._TypedStorage() + return TypedStorage() def find_class(self, module, name): if module == 'collections' and name == 'OrderedDict': -- cgit v1.2.1 From a3578233395e585e68c2118d3630cb2a961d4a36 Mon Sep 17 00:00:00 2001 From: Bepis <36346617+bbepis@users.noreply.github.com> Date: Mon, 10 Oct 2022 23:12:29 +1100 Subject: Add a pull request template --- .../PULL_REQUEST_TEMPLATE/pull_request_template.md | 28 ++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE/pull_request_template.md diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md new file mode 100644 index 00000000..86009613 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md @@ -0,0 +1,28 @@ +# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request! + +If you have a large change, pay special attention to this paragraph: + +> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature. + +Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on. + +**Describe what this pull request is trying to achieve.** + +A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code. + +**Additional notes and description of your changes** + +More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of. + +**Environment this was tested in** + +List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box. + - OS: [e.g. Windows, Linux] + - Browser [e.g. chrome, safari] + - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB] + +**Screenshots or videos of your changes** + +If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made. + +This is **required** for anything that touches the user interface. \ No newline at end of file -- cgit v1.2.1