aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/Lora/lora.py5
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py7
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py8
-rw-r--r--html/extra-networks-card.html4
-rw-r--r--launch.py25
-rw-r--r--modules/api/api.py23
-rw-r--r--modules/api/models.py1
-rw-r--r--modules/interrogate.py49
-rw-r--r--modules/paths.py14
-rw-r--r--modules/shared.py1
-rw-r--r--modules/ui_extra_networks.py6
-rw-r--r--requirements.txt1
-rw-r--r--requirements_versions.txt1
-rw-r--r--webui.bat21
-rw-r--r--webui.py4
15 files changed, 123 insertions, 47 deletions
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 137e58f7..cb8f1d36 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -166,7 +166,10 @@ def lora_forward(module, input, res):
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+ res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ else:
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 60b9eb64..544b228d 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -3,7 +3,7 @@ import torch
import lora
import extra_networks_lora
import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
@@ -28,3 +28,8 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
+
+
+shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
+ "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+}))
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 9a74b253..e8783bca 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
+ if state.interrupted or state.skipped:
+ break
+
for w_idx in w_idx_list:
+ if state.interrupted or state.skipped:
+ break
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index 1bdf1d27..aa9fca87 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -1,8 +1,8 @@
-<div class='card' {preview_html} onclick='return cardClicked({tabname}, {prompt}, {allow_negative_prompt})'>
+<div class='card' {preview_html} onclick={card_clicked}>
<div class='actions'>
<div class='additional'>
<ul>
- <a href="#" title="replace preview image with currently selected in gallery" onclick='return saveCardPreview(event, {tabname}, {local_preview})'>replace preview</a>
+ <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul>
</div>
<span class='name'>{name}</span>
diff --git a/launch.py b/launch.py
index e7a0b50c..9d6f4a8c 100644
--- a/launch.py
+++ b/launch.py
@@ -48,10 +48,19 @@ def extract_opt(args, name):
return args, is_present, opt
-def run(command, desc=None, errdesc=None, custom_env=None):
+def run(command, desc=None, errdesc=None, custom_env=None, live=False):
if desc is not None:
print(desc)
+ if live:
+ result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
+ if result.returncode != 0:
+ raise RuntimeError(f"""{errdesc or 'Error running command'}.
+Command: {command}
+Error code: {result.returncode}""")
+
+ return ""
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0:
@@ -108,18 +117,18 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None:
return
- current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
if current_hash == commithash:
return
- run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+ run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
if commithash is not None:
- run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
def version_check(commit):
@@ -219,9 +228,9 @@ def prepare_environment():
print(f"Python {sys.version}")
print(f"Commit hash: {commit}")
-
+
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
- run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
if not skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
@@ -245,7 +254,7 @@ def prepare_environment():
if not is_installed("xformers"):
exit(0)
elif platform.system() == "Linux":
- run_pip("install xformers", "xformers")
+ run_pip("install xformers==0.0.16rc425", "xformers")
if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok")
diff --git a/modules/api/api.py b/modules/api/api.py
index b1dd14cc..25c65e57 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -53,7 +53,11 @@ def setUpscalers(req: dict):
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
- return Image.open(BytesIO(base64.b64decode(encoding)))
+ try:
+ image = Image.open(BytesIO(base64.b64decode(encoding)))
+ return image
+ except Exception as err:
+ raise HTTPException(status_code=500, detail="Invalid encoded image")
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
@@ -371,13 +375,16 @@ class Api:
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
def get_upscalers(self):
- upscalers = []
-
- for upscaler in shared.sd_upscalers:
- u = upscaler.scaler
- upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
-
- return upscalers
+ return [
+ {
+ "name": upscaler.name,
+ "model_name": upscaler.scaler.model_name,
+ "model_path": upscaler.data_path,
+ "model_url": None,
+ "scale": upscaler.scale,
+ }
+ for upscaler in shared.sd_upscalers
+ ]
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
diff --git a/modules/api/models.py b/modules/api/models.py
index 1eb1fcf1..805bd8f7 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -220,6 +220,7 @@ class UpscalerItem(BaseModel):
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
model_url: Optional[str] = Field(title="URL")
+ scale: Optional[float] = Field(title="Scale")
class SDModelItem(BaseModel):
title: str = Field(title="Title")
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 19938cbb..c72ff694 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
from collections import namedtuple
+from pathlib import Path
import re
import torch
@@ -20,19 +21,20 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
+def category_types():
+ return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
+
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp"
+ category_types = ["artists", "flavors", "mediums", "movements"]
+
try:
os.makedirs(tmpdir)
-
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
- torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
-
+ for category_type in category_types:
+ torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
except Exception as e:
@@ -51,31 +53,44 @@ class InterrogateModels:
def __init__(self, content_dir):
self.loaded_categories = None
+ self.skip_categories = []
self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
def categories(self):
- if self.loaded_categories is not None:
- return self.loaded_categories
-
- self.loaded_categories = []
-
if not os.path.exists(self.content_dir):
download_default_clip_interrogate_categories(self.content_dir)
+ if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
+ return self.loaded_categories
+
+ self.loaded_categories = []
+
if os.path.exists(self.content_dir):
- for filename in os.listdir(self.content_dir):
- m = re_topn.search(filename)
+ self.skip_categories = shared.opts.interrogate_clip_skip_categories
+ category_types = []
+ for filename in Path(self.content_dir).glob('*.txt'):
+ category_types.append(filename.stem)
+ if filename.stem in self.skip_categories:
+ continue
+ m = re_topn.search(filename.stem)
topn = 1 if m is None else int(m.group(1))
-
- with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
+ with open(filename, "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
- self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+ self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
return self.loaded_categories
+ def create_fake_fairscale(self):
+ class FakeFairscale:
+ def checkpoint_wrapper(self):
+ pass
+
+ sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale
+
def load_blip_model(self):
+ self.create_fake_fairscale()
import models.blip
files = modelloader.load_models(
@@ -139,6 +154,8 @@ class InterrogateModels:
def rank(self, image_features, text_array, top_count=1):
import clip
+ devices.torch_gc()
+
if shared.opts.interrogate_clip_dict_limit != 0:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
diff --git a/modules/paths.py b/modules/paths.py
index 4dd03a35..20b3e4d8 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,3 +38,17 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d
+
+
+class Prioritize:
+ def __init__(self, name):
+ self.name = name
+ self.path = None
+
+ def __enter__(self):
+ self.path = sys.path.copy()
+ sys.path = [paths[self.name]] + sys.path
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ sys.path = self.path
+ self.path = None
diff --git a/modules/shared.py b/modules/shared.py
index e17b4561..5f713bee 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 2ddac3d8..8b4f97f8 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -3,6 +3,7 @@ import os.path
from modules import shared
import gradio as gr
import json
+import html
from modules.generation_parameters_copypaste import image_from_url_text
@@ -54,12 +55,13 @@ class ExtraNetworksPage:
preview = item.get("preview", None)
args = {
- "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
+ "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"prompt": item["prompt"],
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
- "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
+ "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
}
return self.card_page.format(**args)
diff --git a/requirements.txt b/requirements.txt
index ef5e3472..a4be1ec3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
blendmodes
accelerate
basicsr
-fairscale==0.4.4
fonts
font-roboto
gfpgan
diff --git a/requirements_versions.txt b/requirements_versions.txt
index f97ad765..135908be 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -14,7 +14,6 @@ scikit-image==0.19.2
fonts
font-roboto
timm==0.6.7
-fairscale==0.4.9
piexif==1.1.3
einops==0.4.1
jsonmerge==1.8.0
diff --git a/webui.bat b/webui.bat
index 3165b94d..209d972b 100644
--- a/webui.bat
+++ b/webui.bat
@@ -3,17 +3,28 @@
if not defined PYTHON (set PYTHON=python)
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
+
set ERROR_REPORTING=FALSE
mkdir tmp 2>NUL
%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
-if %ERRORLEVEL% == 0 goto :start_venv
+if %ERRORLEVEL% == 0 goto :check_pip
echo Couldn't launch python
goto :show_stdout_stderr
+:check_pip
+%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr
+%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+echo Couldn't install pip
+goto :show_stdout_stderr
+
:start_venv
if ["%VENV_DIR%"] == ["-"] goto :skip_venv
+if ["%SKIP_VENV%"] == ["1"] goto :skip_venv
dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :activate_venv
@@ -28,13 +39,13 @@ goto :show_stdout_stderr
:activate_venv
set PYTHON="%VENV_DIR%\Scripts\Python.exe"
echo venv %PYTHON%
-if [%ACCELERATE%] == ["True"] goto :accelerate
-goto :launch
:skip_venv
+if [%ACCELERATE%] == ["True"] goto :accelerate
+goto :launch
:accelerate
-echo "Checking for accelerate"
+echo Checking for accelerate
set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
if EXIST %ACCELERATE% goto :accelerate_launch
@@ -44,7 +55,7 @@ pause
exit /b
:accelerate_launch
-echo "Accelerating"
+echo Accelerating
%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
pause
exit /b
diff --git a/webui.py b/webui.py
index bc2baeab..e1565a8d 100644
--- a/webui.py
+++ b/webui.py
@@ -1,6 +1,5 @@
import os
import sys
-import threading
import time
import importlib
import signal
@@ -10,6 +9,9 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
+import logging
+logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
+
from modules import import_hook, errors, extra_networks
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call