Diffstat (limited to 'modules')
-rw-r--r--  modules/api/api.py                    16
-rw-r--r--  modules/api/models.py                  1
-rw-r--r--  modules/devices.py                    32
-rw-r--r--  modules/extensions.py                  2
-rw-r--r--  modules/ngrok.py                      13
-rw-r--r--  modules/script_loading.py             34
-rw-r--r--  modules/scripts.py                    48
-rw-r--r--  modules/sd_hijack.py                  22
-rw-r--r--  modules/sd_hijack_inpainting.py        3
-rw-r--r--  modules/sd_models.py                  29
-rw-r--r--  modules/shared.py                      8
-rw-r--r--  modules/textual_inversion/dataset.py   7
-rw-r--r--  modules/ui.py                         22
-rw-r--r--  modules/ui_extensions.py               3
14 files changed, 183 insertions, 57 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 688469ad..596a6616 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -15,6 +15,9 @@ from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List
+if shared.cmd_opts.deepdanbooru:
+ from modules.deepbooru import get_deepbooru_tags
+
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
@@ -220,11 +223,20 @@ class Api:
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
- img = self.__base64_to_image(image_b64)
+ img = decode_base64_to_image(image_b64)
+ img = img.convert('RGB')
# Override object param
with self.queue_lock:
- processed = shared.interrogator.interrogate(img)
+ if interrogatereq.model == "clip":
+ processed = shared.interrogator.interrogate(img)
+ elif interrogatereq.model == "deepdanbooru":
+ if shared.cmd_opts.deepdanbooru:
+ processed = get_deepbooru_tags(img)
+ else:
+ raise HTTPException(status_code=404, detail="Model not found. Add --deepdanbooru when launching for using the model.")
+ else:
+ raise HTTPException(status_code=404, detail="Model not found")
return InterrogateResponse(caption=processed)
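The new model field can be exercised end-to-end against the interrogate endpoint. A minimal client sketch, assuming the route /sdapi/v1/interrogate on the default port 7860 (both assumptions, not shown in this diff):

    import base64
    import requests

    # Hypothetical client call; URL, port, and route are assumptions.
    with open("sample.png", "rb") as f:
        b64_image = base64.b64encode(f.read()).decode()

    resp = requests.post(
        "http://127.0.0.1:7860/sdapi/v1/interrogate",
        json={"image": b64_image, "model": "deepdanbooru"},  # or "clip"
    )
    resp.raise_for_status()
    print(resp.json()["caption"])

The server must have been launched with --deepdanbooru for the second model, or the request returns 404.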
diff --git a/modules/api/models.py b/modules/api/models.py
index 34dbfa16..f9cd929e 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -170,6 +170,7 @@ class ProgressResponse(BaseModel):
class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
+ model: str = Field(default="clip", title="Model", description="The interrogate model used.")
class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
diff --git a/modules/devices.py b/modules/devices.py
index 7511e1dc..67165bf6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,16 +3,27 @@ import contextlib
import torch
from modules import errors
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
-has_mps = getattr(torch, 'has_mps', False)
-cpu = torch.device("cpu")
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check with `getattr`, then try allocating an MPS tensor, for compatibility
+def has_mps() -> bool:
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+
def extract_device_id(args, name):
for x in range(len(args)):
- if name in args[x]: return args[x+1]
+ if name in args[x]:
+ return args[x + 1]
+
return None
+
def get_optimal_device():
if torch.cuda.is_available():
from modules import shared
@@ -25,7 +36,7 @@ def get_optimal_device():
else:
return torch.device("cuda")
- if has_mps:
+ if has_mps():
return torch.device("mps")
return cpu
@@ -45,10 +56,12 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
+cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
+
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps':
@@ -82,6 +95,11 @@ def autocast(disable=False):
return torch.autocast("cuda")
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
-def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
+def mps_contiguous(input_tensor, device):
+ return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+
+
+def mps_contiguous_to(input_tensor, device):
+ return mps_contiguous(input_tensor, device).to(device)
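Turning has_mps from a module-level flag into a runtime probe matters because torch.has_mps can report True on builds where the backend still fails at tensor creation. The same probe-then-fallback pattern in isolation, as a sketch:

    import torch

    def pick_device() -> torch.device:
        # Prefer CUDA, then a *working* MPS backend, then CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if getattr(torch, "has_mps", False):
            try:
                torch.zeros(1).to(torch.device("mps"))  # probe; may raise
                return torch.device("mps")
            except Exception:
                pass
        return torch.device("cpu")

    print(pick_device())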
diff --git a/modules/extensions.py b/modules/extensions.py
index 8e0977fd..94ce479a 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -6,7 +6,6 @@ import git
from modules import paths, shared
-
extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
@@ -84,3 +83,4 @@ def list_extensions():
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension)
+
diff --git a/modules/ngrok.py b/modules/ngrok.py
index 5c5f349a..10d2179f 100644
--- a/modules/ngrok.py
+++ b/modules/ngrok.py
@@ -1,14 +1,23 @@
from pyngrok import ngrok, conf, exception
-
def connect(token, port, region):
+ account = None
if token == None:
token = 'None'
+ else:
+ if ':' in token:
+ # token = authtoken:username:password
+ account = token.split(':')[1] + ':' + token.split(':')[-1]
+ token = token.split(':')[0]
+
config = conf.PyngrokConfig(
auth_token=token, region=region
)
try:
- public_url = ngrok.connect(port, pyngrok_config=config).public_url
+ if account == None:
+ public_url = ngrok.connect(port, pyngrok_config=config).public_url
+ else:
+ public_url = ngrok.connect(port, pyngrok_config=config, auth=account).public_url
except exception.PyngrokNgrokError:
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
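With this change the ngrok token accepts two shapes: a bare authtoken, or authtoken:username:password to additionally protect the tunnel with HTTP basic auth. Hypothetical direct calls to the connect() above (token values are placeholders):

    from modules.ngrok import connect

    connect("2AbCdEfToken", 7860, "us")              # plain public tunnel
    connect("2AbCdEfToken:user:secret", 7860, "eu")  # tunnel behind basic auth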
diff --git a/modules/script_loading.py b/modules/script_loading.py
new file mode 100644
index 00000000..f93f0951
--- /dev/null
+++ b/modules/script_loading.py
@@ -0,0 +1,34 @@
+import os
+import sys
+import traceback
+from types import ModuleType
+
+
+def load_module(path):
+ with open(path, "r", encoding="utf8") as file:
+ text = file.read()
+
+ compiled = compile(text, path, 'exec')
+ module = ModuleType(os.path.basename(path))
+ exec(compiled, module.__dict__)
+
+ return module
+
+
+def preload_extensions(extensions_dir, parser):
+ if not os.path.isdir(extensions_dir):
+ return
+
+ for dirname in sorted(os.listdir(extensions_dir)):
+ preload_script = os.path.join(extensions_dir, dirname, "preload.py")
+ if not os.path.isfile(preload_script):
+ continue
+
+ try:
+ module = load_module(preload_script)
+ if hasattr(module, 'preload'):
+ module.preload(parser)
+
+ except Exception:
+ print(f"Error running preload() for {preload_script}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
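An extension opts into this hook by shipping a preload.py in its root directory; preload(parser) receives the argparse parser before cmd_opts is parsed, so the extension can register its own command-line flags. A minimal hypothetical extensions/my-extension/preload.py:

    # extensions/my-extension/preload.py (hypothetical example)
    def preload(parser):
        parser.add_argument(
            "--my-extension-model-dir",
            type=str,
            default=None,
            help="Path to this extension's model files.",
        )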
diff --git a/modules/scripts.py b/modules/scripts.py
index 637b2329..986b1914 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -6,7 +6,7 @@ from collections import namedtuple
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks, extensions
+from modules import shared, paths, script_callbacks, extensions, script_loading
AlwaysVisible = object()
@@ -140,7 +140,7 @@ def list_files_with_name(filename):
continue
path = os.path.join(dirpath, filename)
- if os.path.isfile(filename):
+ if os.path.isfile(path):
res.append(path)
return res
@@ -161,13 +161,7 @@ def load_scripts():
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
- with open(scriptfile.path, "r", encoding="utf8") as file:
- text = file.read()
-
- from types import ModuleType
- compiled = compile(text, scriptfile.path, 'exec')
- module = ModuleType(scriptfile.filename)
- exec(compiled, module.__dict__)
+ module = script_loading.load_module(scriptfile.path)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
@@ -328,27 +322,21 @@ class ScriptRunner:
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
- with open(script.filename, "r", encoding="utf8") as file:
- args_from = script.args_from
- args_to = script.args_to
- filename = script.filename
- text = file.read()
-
- from types import ModuleType
-
- module = cache.get(filename, None)
- if module is None:
- compiled = compile(text, filename, 'exec')
- module = ModuleType(script.filename)
- exec(compiled, module.__dict__)
- cache[filename] = module
-
- for key, script_class in module.__dict__.items():
- if type(script_class) == type and issubclass(script_class, Script):
- self.scripts[si] = script_class()
- self.scripts[si].filename = filename
- self.scripts[si].args_from = args_from
- self.scripts[si].args_to = args_to
+ args_from = script.args_from
+ args_to = script.args_to
+ filename = script.filename
+
+ module = cache.get(filename, None)
+ if module is None:
+ module = script_loading.load_module(script.filename)
+ cache[filename] = module
+
+ for key, script_class in module.__dict__.items():
+ if type(script_class) == type and issubclass(script_class, Script):
+ self.scripts[si] = script_class()
+ self.scripts[si].filename = filename
+ self.scripts[si].args_from = args_from
+ self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index bc49d235..97979d05 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -14,6 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -406,3 +408,23 @@ def add_circular_option_to_conv_2d():
model_hijack = StableDiffusionModelHijack()
+
+
+def register_buffer(self, name, attr):
+ """
+ Fix register_buffer bug for macOS (the upstream version hard-codes cuda; MPS also lacks float64).
+ """
+
+ if type(attr) == torch.Tensor:
+ if attr.device != devices.device:
+
+ if devices.has_mps():
+ attr = attr.to(device="mps", dtype=torch.float32)
+ else:
+ attr = attr.to(devices.device)
+
+ setattr(self, name, attr)
+
+
+ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
+ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
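The patched register_buffer replaces the upstream ldm version, which unconditionally moves tensors to a hard-coded cuda device and therefore fails on Mac; on MPS it additionally casts to float32, since the MPS backend does not support float64. A quick probe of that failure mode, guarded so it only runs on an MPS machine (a sketch, not part of the patch):

    import torch

    if getattr(torch, "has_mps", False):
        try:
            torch.zeros(1, dtype=torch.float64, device="mps")  # MPS lacks float64
        except Exception as e:
            print("float64 on MPS failed:", e)
        # what the patch does instead: cast while moving to the device
        buf = torch.zeros(1, dtype=torch.float64).to(device="mps", dtype=torch.float32)
        print(buf.dtype)  # torch.float32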
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index fd92a335..46714a4f 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -328,4 +328,5 @@ def do_inpainting_hijack():
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
- ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file
+ ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
+
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 34c57bfa..80addf03 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
- if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+ cache_enabled = shared.opts.sd_checkpoint_cache > 0
+
+ if cache_enabled:
sd_vae.restore_base_vae(model)
- checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
- if checkpoint_info not in checkpoints_loaded:
+ if cache_enabled and checkpoint_info in checkpoints_loaded:
+ # use checkpoint cache
+ vae_name = sd_vae.get_filename(vae_file) if vae_file else None
+ vae_message = f" with {vae_name} VAE" if vae_name else ""
+ print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+ model.load_state_dict(checkpoints_loaded[checkpoint_info])
+ else:
+ # load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
del pl_sd
model.load_state_dict(sd, strict=False)
del sd
+
+ if cache_enabled:
+ # cache newly loaded model
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
@@ -199,14 +211,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.first_stage_model.to(devices.dtype_vae)
- else:
- vae_name = sd_vae.get_filename(vae_file) if vae_file else None
- vae_message = f" with {vae_name} VAE" if vae_name else ""
- print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
- model.load_state_dict(checkpoints_loaded[checkpoint_info])
-
- if shared.opts.sd_checkpoint_cache > 0:
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ # clean up cache if limit is reached
+ if cache_enabled:
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
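checkpoints_loaded acts as an LRU cache keyed by checkpoint_info: popitem(last=False) evicts the oldest entry, and the limit is sd_checkpoint_cache + 1 because the just-loaded current model's weights also sit in the dict and should not count against the cache. The same pattern in isolation, as a sketch:

    from collections import OrderedDict

    cache_limit = 2        # stands in for opts.sd_checkpoint_cache
    cache = OrderedDict()  # stands in for checkpoints_loaded

    for ckpt in ["a.ckpt", "b.ckpt", "c.ckpt", "d.ckpt"]:
        cache[ckpt] = f"state_dict of {ckpt}"  # cache newly loaded model
        while len(cache) > cache_limit + 1:    # +1 counts the current model
            cache.popitem(last=False)          # evict oldest (LRU)

    print(list(cache))  # ['b.ckpt', 'c.ckpt', 'd.ckpt']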
diff --git a/modules/shared.py b/modules/shared.py
index e8bacd3c..6936cbe0 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -3,7 +3,6 @@ import datetime
import json
import os
import sys
-from collections import OrderedDict
import time
import gradio as gr
@@ -15,7 +14,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
-from modules import sd_samplers, sd_models, localization, sd_vae
+from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@@ -91,7 +90,10 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
+script_loading.preload_extensions(extensions.extensions_dir, parser)
+
cmd_opts = parser.parse_args()
+
restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
@@ -319,6 +321,8 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
+ "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
+ "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index ad726577..eb75c376 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -98,7 +98,12 @@ class PersonalizedBase(Dataset):
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", filename_text)
+ tags = filename_text.split(',')
+ if shared.opts.tag_drop_out != 0:
+ tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
+ if shared.opts.shuffle_tags:
+ random.shuffle(tags)
+ text = text.replace("[filewords]", ','.join(tags))
return text
def __len__(self):
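create_text now splits [filewords] on commas, drops each tag independently with probability tag_drop_out, and optionally shuffles the survivors. The same transformation standalone, with illustrative option values:

    import random

    filename_text = "1girl, long hair, smile, outdoors"
    tag_drop_out = 0.3   # illustrative stand-ins for the two new options
    shuffle_tags = True

    tags = filename_text.split(',')
    if tag_drop_out != 0:
        tags = [t for t in tags if random.random() > tag_drop_out]
    if shuffle_tags:
        random.shuffle(tags)
    print(','.join(tags))  # e.g. " smile,1girl, outdoors"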
diff --git a/modules/ui.py b/modules/ui.py
index 7ea1177f..5dce7f3b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -566,6 +566,19 @@ def apply_setting(key, value):
return value
+def update_generation_info(args):
+ generation_info, html_info, img_index = args
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info
+ return plaintext_to_html(generation_info["infotexts"][img_index])
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info
+
+
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
@@ -638,6 +651,15 @@ Requested path was: {f}
with gr.Group():
html_info = gr.HTML()
generation_info = gr.Textbox(visible=False)
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="(x, y) => [x, y, selected_gallery_index()]",
+ inputs=[generation_info, html_info],
+ outputs=[html_info],
+ preprocess=False
+ )
save.click(
fn=wrap_gradio_call(save_files),
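update_generation_info parses the JSON the UI already keeps in the hidden generation_info textbox and picks the infotext matching the gallery selection. A sketch of the shape it expects (only the infotexts key appears in this diff; the strings are illustrative):

    import json

    generation_info = json.dumps({
        "infotexts": [
            "prompt A, Steps: 20, Sampler: Euler a, Seed: 1",
            "prompt B, Steps: 20, Sampler: Euler a, Seed: 2",
        ],
    })
    img_index = 1  # supplied by selected_gallery_index() in the browser
    print(json.loads(generation_info)["infotexts"][img_index])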
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 02ab9643..6671cb60 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -134,6 +134,9 @@ def install_extension_from_url(dirname, url):
os.rename(tmpdir, target_dir)
+ import launch
+ launch.run_extension_installer(target_dir)
+
extensions.list_extensions()
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
finally:
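launch.run_extension_installer gives a freshly cloned extension the chance to install its own dependencies right after download. By convention it executes an install.py at the extension root; a minimal hypothetical one (is_installed and run_pip are helpers from the webui's launch.py, an assumption not shown in this diff):

    # extensions/my-extension/install.py (hypothetical example)
    import launch

    if not launch.is_installed("timm"):
        launch.run_pip("install timm", "timm for my-extension")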