aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorcluder <1590330+cluder@users.noreply.github.com>2022-11-09 05:50:43 +0100
committercluder <1590330+cluder@users.noreply.github.com>2022-11-09 05:50:43 +0100
commitf37cce0e3d154f300d4ec7ed8ef6a32d1c613e50 (patch)
tree59884560de9a6462122dcab7d1a70a64c7e5ac11 /modules
parent3b51d239ac9201228c6032fc109111e347e8e6b0 (diff)
parentac085628540d0ec6a988fad93f5b8f2154209571 (diff)
Merge branch 'master' of https://github.com/cluder/stable-diffusion-webui into 4448_fix_ckpt_cache
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py14
-rw-r--r--modules/api/models.py8
-rw-r--r--modules/ldsr_model_arch.py14
-rw-r--r--modules/localization.py6
-rw-r--r--modules/safe.py40
-rw-r--r--modules/script_callbacks.py27
-rw-r--r--modules/scripts.py1
-rw-r--r--modules/shared.py2
-rw-r--r--modules/textual_inversion/preprocess.py162
-rw-r--r--modules/ui.py61
-rw-r--r--modules/ui_extensions.py55
11 files changed, 277 insertions, 113 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 112000b8..688469ad 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -63,6 +63,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
@@ -214,6 +215,19 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
+ def interrogateapi(self, interrogatereq: InterrogateRequest):
+ image_b64 = interrogatereq.image
+ if image_b64 is None:
+ raise HTTPException(status_code=404, detail="Image not found")
+
+ img = self.__base64_to_image(image_b64)
+
+ # Override object param
+ with self.queue_lock:
+ processed = shared.interrogator.interrogate(img)
+
+ return InterrogateResponse(caption=processed)
+
def interruptapi(self):
shared.state.interrupt()
diff --git a/modules/api/models.py b/modules/api/models.py
index f89da1ff..34dbfa16 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -65,6 +65,7 @@ class PydanticModelGenerator:
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
+
self._model_def = [
ModelDef(
field=underscore(k),
@@ -167,6 +168,12 @@ class ProgressResponse(BaseModel):
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
+class InterrogateRequest(BaseModel):
+ image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+
+class InterrogateResponse(BaseModel):
+ caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+
fields = {}
for key, value in opts.data.items():
metadata = opts.data_labels.get(key)
@@ -231,3 +238,4 @@ class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
category: str = Field(title="Category")
+
diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py
index 14db5076..90e0a2f0 100644
--- a/modules/ldsr_model_arch.py
+++ b/modules/ldsr_model_arch.py
@@ -101,8 +101,8 @@ class LDSR:
down_sample_rate = target_scale / 4
wd = width_og * down_sample_rate
hd = height_og * down_sample_rate
- width_downsampled_pre = int(wd)
- height_downsampled_pre = int(hd)
+ width_downsampled_pre = int(np.ceil(wd))
+ height_downsampled_pre = int(np.ceil(hd))
if down_sample_rate != 1:
print(
@@ -110,7 +110,12 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
- logs = self.run(model["model"], im_og, diffusion_steps, eta)
+
+ # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
+ pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
+ im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
+
+ logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"]
sample = sample.detach().cpu()
@@ -120,6 +125,9 @@ class LDSR:
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
+ # remove padding
+ a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
+
del model
gc.collect()
torch.cuda.empty_cache()
diff --git a/modules/localization.py b/modules/localization.py
index b1810cda..f6a6f2fb 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -3,6 +3,7 @@ import os
import sys
import traceback
+
localizations = {}
@@ -16,6 +17,11 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file)
+ from modules import scripts
+ for file in scripts.list_scripts("localizations", ".json"):
+ fn, ext = os.path.splitext(file.filename)
+ localizations[fn] = file.path
+
def localization_js(current_localization_name):
fn = localizations.get(current_localization_name, None)
diff --git a/modules/safe.py b/modules/safe.py
index 348a24fc..a9209e38 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -23,11 +23,18 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
+ extra_handler = None
+
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
def find_class(self, module, name):
+ if self.extra_handler is not None:
+ res = self.extra_handler(module, name)
+ if res is not None:
+ return res
+
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
@@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return set
# Forbid everything else.
- raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+ raise Exception(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"]
@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
raise Exception(f"bad file inside {filename}: {name}")
-def check_pt(filename):
+def check_pt(filename, extra_handler):
try:
# new pytorch format is a zip file
@@ -78,6 +85,7 @@ def check_pt(filename):
with z.open('archive/data.pkl') as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
@@ -85,16 +93,42 @@ def check_pt(filename):
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
for i in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
+ return load_with_extra(filename, *args, **kwargs)
+
+
+def load_with_extra(filename, extra_handler=None, *args, **kwargs):
+ """
+ this functon is intended to be used by extensions that want to load models with
+ some extra classes in them that the usual unpickler would find suspicious.
+
+ Use the extra_handler argument to specify a function that takes module and field name as text,
+ and returns that field's value:
+
+ ```python
+ def extra(module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return collections.OrderedDict
+
+ return None
+
+ safe.load_with_extra('model.pt', extra_handler=extra)
+ ```
+
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
+ definitely unsafe.
+ """
+
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
- check_pt(filename)
+ check_pt(filename, extra_handler)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 74dfb880..f19e164c 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -7,6 +7,7 @@ from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
+
def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -45,15 +46,21 @@ class CFGDenoiserParams:
"""Total number of sampling steps planned"""
+class UiTrainTabParams:
+ def __init__(self, txt2img_preview_params):
+ self.txt2img_preview_params = txt2img_preview_params
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
callbacks_ui_tabs=[],
+ callbacks_ui_train_tabs=[],
callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
- callbacks_cfg_denoiser=[]
+ callbacks_cfg_denoiser=[],
)
@@ -61,6 +68,7 @@ def clear_callbacks():
for callback_list in callback_map.values():
callback_list.clear()
+
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
try:
@@ -79,7 +87,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
-
+
for c in callback_map['callbacks_ui_tabs']:
try:
res += c.callback() or []
@@ -89,6 +97,14 @@ def ui_tabs_callback():
return res
+def ui_train_tabs_callback(params: UiTrainTabParams):
+ for c in callback_map['callbacks_ui_train_tabs']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'callbacks_ui_train_tabs')
+
+
def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']:
try:
@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
add_callback(callback_map['callbacks_ui_tabs'], callback)
+def on_ui_train_tabs(callback):
+ """register a function to be called when the UI is creating new tabs for the train tab.
+ Create your new tabs with gr.Tab.
+ """
+ add_callback(callback_map['callbacks_ui_train_tabs'], callback)
+
+
def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """
diff --git a/modules/scripts.py b/modules/scripts.py
index 366c90d7..637b2329 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -3,7 +3,6 @@ import sys
import traceback
from collections import namedtuple
-import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
diff --git a/modules/shared.py b/modules/shared.py
index 70b998ff..e8bacd3c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -221,8 +221,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
-localization.list_localizations(cmd_opts.localizations_dir)
-
def realesrgan_models_names():
import modules.realesrgan_model
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index e13b1894..488aa5b5 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -35,6 +35,84 @@ def preprocess(process_src, process_dst, process_width, process_height, preproce
deepbooru.release_process()
+def listfiles(dirname):
+ return os.listdir(dirname)
+
+
+class PreprocessParams:
+ src = None
+ dstdir = None
+ subindex = 0
+ flip = False
+ process_caption = False
+ process_caption_deepbooru = False
+ preprocess_txt_action = None
+
+
+def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
+ caption = ""
+
+ if params.process_caption:
+ caption += shared.interrogator.generate_caption(image)
+
+ if params.process_caption_deepbooru:
+ if len(caption) > 0:
+ caption += ", "
+ caption += deepbooru.get_tags_from_process(image)
+
+ filename_part = params.src
+ filename_part = os.path.splitext(filename_part)[0]
+ filename_part = os.path.basename(filename_part)
+
+ basename = f"{index:05}-{params.subindex}-{filename_part}"
+ image.save(os.path.join(params.dstdir, f"{basename}.png"))
+
+ if params.preprocess_txt_action == 'prepend' and existing_caption:
+ caption = existing_caption + ' ' + caption
+ elif params.preprocess_txt_action == 'append' and existing_caption:
+ caption = caption + ' ' + existing_caption
+ elif params.preprocess_txt_action == 'copy' and existing_caption:
+ caption = existing_caption
+
+ caption = caption.strip()
+
+ if len(caption) > 0:
+ with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
+ file.write(caption)
+
+ params.subindex += 1
+
+
+def save_pic(image, index, params, existing_caption=None):
+ save_pic_with_caption(image, index, params, existing_caption=existing_caption)
+
+ if params.flip:
+ save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+ if inverse_xy:
+ from_w, from_h = image.height, image.width
+ to_w, to_h = height, width
+ else:
+ from_w, from_h = image.width, image.height
+ to_w, to_h = width, height
+ h = from_h * to_w // from_w
+ if inverse_xy:
+ image = image.resize((h, to_w))
+ else:
+ image = image.resize((to_w, h))
+
+ split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+ y_step = (h - to_h) / (split_count - 1)
+ for i in range(split_count):
+ y = int(y_step * i)
+ if inverse_xy:
+ splitted = image.crop((y, 0, y + to_h, to_w))
+ else:
+ splitted = image.crop((0, y, to_w, y + to_h))
+ yield splitted
+
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
width = process_width
@@ -48,82 +126,28 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
os.makedirs(dst, exist_ok=True)
- files = os.listdir(src)
+ files = listfiles(src)
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
- def save_pic_with_caption(image, index, existing_caption=None):
- caption = ""
-
- if process_caption:
- caption += shared.interrogator.generate_caption(image)
-
- if process_caption_deepbooru:
- if len(caption) > 0:
- caption += ", "
- caption += deepbooru.get_tags_from_process(image)
-
- filename_part = filename
- filename_part = os.path.splitext(filename_part)[0]
- filename_part = os.path.basename(filename_part)
-
- basename = f"{index:05}-{subindex[0]}-{filename_part}"
- image.save(os.path.join(dst, f"{basename}.png"))
-
- if preprocess_txt_action == 'prepend' and existing_caption:
- caption = existing_caption + ' ' + caption
- elif preprocess_txt_action == 'append' and existing_caption:
- caption = caption + ' ' + existing_caption
- elif preprocess_txt_action == 'copy' and existing_caption:
- caption = existing_caption
-
- caption = caption.strip()
-
- if len(caption) > 0:
- with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
- file.write(caption)
-
- subindex[0] += 1
-
- def save_pic(image, index, existing_caption=None):
- save_pic_with_caption(image, index, existing_caption=existing_caption)
-
- if process_flip:
- save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
-
- def split_pic(image, inverse_xy):
- if inverse_xy:
- from_w, from_h = image.height, image.width
- to_w, to_h = height, width
- else:
- from_w, from_h = image.width, image.height
- to_w, to_h = width, height
- h = from_h * to_w // from_w
- if inverse_xy:
- image = image.resize((h, to_w))
- else:
- image = image.resize((to_w, h))
-
- split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
- y_step = (h - to_h) / (split_count - 1)
- for i in range(split_count):
- y = int(y_step * i)
- if inverse_xy:
- splitted = image.crop((y, 0, y + to_h, to_w))
- else:
- splitted = image.crop((0, y, to_w, y + to_h))
- yield splitted
-
+ params = PreprocessParams()
+ params.dstdir = dst
+ params.flip = process_flip
+ params.process_caption = process_caption
+ params.process_caption_deepbooru = process_caption_deepbooru
+ params.preprocess_txt_action = preprocess_txt_action
for index, imagefile in enumerate(tqdm.tqdm(files)):
- subindex = [0]
+ params.subindex = 0
filename = os.path.join(src, imagefile)
try:
img = Image.open(filename).convert("RGB")
except Exception:
continue
+ params.src = filename
+
existing_caption = None
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
if os.path.exists(existing_caption_filename):
@@ -143,8 +167,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
process_default_resize = True
if process_split and ratio < 1.0 and ratio <= split_threshold:
- for splitted in split_pic(img, inverse_xy):
- save_pic(splitted, index, existing_caption=existing_caption)
+ for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
+ save_pic(splitted, index, params, existing_caption=existing_caption)
process_default_resize = False
if process_focal_crop and img.height != img.width:
@@ -165,11 +189,11 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = dnn_model_path,
)
for focal in autocrop.crop_image(img, autocrop_settings):
- save_pic(focal, index, existing_caption=existing_caption)
+ save_pic(focal, index, params, existing_caption=existing_caption)
process_default_resize = False
if process_default_resize:
img = images.resize_image(1, img, width, height)
- save_pic(img, index, existing_caption=existing_caption)
+ save_pic(img, index, params, existing_caption=existing_caption)
- shared.state.nextjob() \ No newline at end of file
+ shared.state.nextjob()
diff --git a/modules/ui.py b/modules/ui.py
index 76ca9b07..7ea1177f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None):
gr.processing_utils.save_pil_to_file = save_pil_to_file
-def wrap_gradio_call(func, extra_outputs=None):
+def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
- run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
+ run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
@@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None):
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
+ shared.state.skipped = False
+ shared.state.interrupted = False
+ shared.state.job_count = 0
+
+ if not add_stats:
+ return tuple(res)
+
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
- if (elapsed_m > 0):
+ if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
@@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None):
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
- shared.state.skipped = False
- shared.state.interrupted = False
- shared.state.job_count = 0
-
return tuple(res)
return f
@@ -1138,7 +1141,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[html, generation_info, html2],
)
- with gr.Blocks() as modelmerger_interface:
+ with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
@@ -1158,7 +1161,7 @@ def create_ui(wrap_gradio_gpu_call):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
- with gr.Blocks() as train_interface:
+ with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
@@ -1267,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
train_embedding = gr.Button(value="Train Embedding", variant='primary')
+ params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
+
+ script_callbacks.ui_train_tabs_callback(params)
+
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
@@ -1417,15 +1424,14 @@ def create_ui(wrap_gradio_gpu_call):
if info.refresh is not None:
if is_quicksettings:
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
with gr.Row(variant="compact"):
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
- res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {}))
-
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
return res
@@ -1436,7 +1442,7 @@ def create_ui(wrap_gradio_gpu_call):
opts.reorder()
def run_settings(*args):
- changed = 0
+ changed = []
for key, value, comp in zip(opts.data_labels.keys(), args, components):
assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
@@ -1454,12 +1460,12 @@ def create_ui(wrap_gradio_gpu_call):
if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
- changed += 1
+ changed.append(key)
try:
opts.save(shared.config_filename)
except RuntimeError:
- return opts.dumpjson(), f'{changed} settings changed without save.'
- return opts.dumpjson(), f'{changed} settings changed.'
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.'
def run_settings_single(value, key):
if not opts.same_type(value, opts.data_labels[key].default):
@@ -1563,11 +1569,10 @@ def create_ui(wrap_gradio_gpu_call):
shared.state.need_restart = True
restart_gradio.click(
-
fn=request_restart,
+ _js='restart_reload',
inputs=[],
outputs=[],
- _js='restart_reload'
)
if column is not None:
@@ -1637,6 +1642,17 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[component, text_settings],
)
+ component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
+
+ def get_settings_values():
+ return [getattr(opts, key) for key in component_keys]
+
+ demo.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[component_dict[k] for k in component_keys],
+ )
+
def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
@@ -1740,7 +1756,7 @@ def create_ui(wrap_gradio_gpu_call):
return demo
-def load_javascript(raw_response):
+def reload_javascript():
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>'
@@ -1756,7 +1772,7 @@ def load_javascript(raw_response):
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
def template_response(*args, **kwargs):
- res = raw_response(*args, **kwargs)
+ res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(
b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers()
@@ -1765,4 +1781,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
-reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+ shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 8e0d41d5..02ab9643 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -140,13 +140,15 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
-def install_extension_from_index(url):
+def install_extension_from_index(url, hide_tags):
ext_table, message = install_extension_from_url(None, url)
- return refresh_available_extensions_from_data(), ext_table, message
+ code, _ = refresh_available_extensions_from_data(hide_tags)
+ return code, ext_table, message
-def refresh_available_extensions(url):
+
+def refresh_available_extensions(url, hide_tags):
global available_extensions
import urllib.request
@@ -155,13 +157,25 @@ def refresh_available_extensions(url):
available_extensions = json.loads(text)
- return url, refresh_available_extensions_from_data(), ''
+ code, tags = refresh_available_extensions_from_data(hide_tags)
+
+ return url, code, gr.CheckboxGroup.update(choices=tags), ''
+
+
+def refresh_available_extensions_for_tags(hide_tags):
+ code, _ = refresh_available_extensions_from_data(hide_tags)
+ return code, ''
-def refresh_available_extensions_from_data():
+
+def refresh_available_extensions_from_data(hide_tags):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
+ tags = available_extensions.get("tags", {})
+ tags_to_hide = set(hide_tags)
+ hidden = 0
+
code = f"""<!-- {time.time()} -->
<table id="available_extensions">
<thead>
@@ -178,17 +192,24 @@ def refresh_available_extensions_from_data():
name = ext.get("name", "noname")
url = ext.get("url", None)
description = ext.get("description", "")
+ extension_tags = ext.get("tags", [])
if url is None:
continue
+ if len([x for x in extension_tags if x in tags_to_hide]) > 0:
+ hidden += 1
+ continue
+
existing = installed_extension_urls.get(normalize_git_url(url), None)
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
+ tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
+
code += f"""
<tr>
- <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a></td>
+ <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
<td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>
@@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
</table>
"""
- return code
+ if hidden > 0:
+ code += f"<p>Extension hidden: {hidden}</p>"
+
+ return code, list(tags)
def create_ui():
@@ -238,21 +262,30 @@ def create_ui():
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
+ with gr.Row():
+ hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
+
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
- inputs=[available_extensions_index],
- outputs=[available_extensions_index, available_extensions_table, install_result],
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ inputs=[available_extensions_index, hide_tags],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
- inputs=[extension_to_install],
+ inputs=[extension_to_install, hide_tags],
outputs=[available_extensions_table, extensions_table, install_result],
)
+ hide_tags.change(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+ inputs=[hide_tags],
+ outputs=[available_extensions_table, install_result]
+ )
+
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")