aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py29
-rw-r--r--modules/errors.py25
-rw-r--r--modules/extras.py31
-rw-r--r--modules/generation_parameters_copypaste.py5
-rw-r--r--modules/hypernetworks/hypernetwork.py1
-rw-r--r--modules/interrogate.py4
-rw-r--r--modules/modelloader.py20
-rw-r--r--modules/processing.py6
-rw-r--r--modules/sd_hijack_inpainting.py237
-rw-r--r--modules/sd_models.py61
-rw-r--r--modules/shared.py17
-rw-r--r--modules/textual_inversion/preprocess.py1
-rw-r--r--modules/textual_inversion/textual_inversion.py1
-rw-r--r--modules/ui.py16
14 files changed, 176 insertions, 278 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 9c670f00..a6c1d6ed 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,11 +1,12 @@
import base64
import io
import time
+import datetime
import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
@@ -18,7 +19,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List
@@ -67,6 +68,27 @@ def encode_pil_to_base64(image):
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
+def api_middleware(app: FastAPI):
+ @app.middleware("http")
+ async def log_and_time(req: Request, call_next):
+ ts = time.time()
+ res: Response = await call_next(req)
+ duration = str(round(time.time() - ts, 4))
+ res.headers["X-Process-Time"] = duration
+ endpoint = req.scope.get('path', 'err')
+ if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+ print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
+ t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code = res.status_code,
+ ver = req.scope.get('http_version', '0.0'),
+ cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot = req.scope.get('scheme', 'err'),
+ method = req.scope.get('method', 'err'),
+ endpoint = endpoint,
+ duration = duration,
+ ))
+ return res
+
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
@@ -78,6 +100,7 @@ class Api:
self.router = APIRouter()
self.app = app
+ init_api_middleware(self.app)
self.queue_lock = queue_lock
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
@@ -303,7 +326,7 @@ class Api:
return upscalers
def get_sd_models(self):
- return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/errors.py b/modules/errors.py
index 372dc51a..a668c014 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -2,9 +2,30 @@ import sys
import traceback
+def print_error_explanation(message):
+ lines = message.strip().split("\n")
+ max_len = max([len(x) for x in lines])
+
+ print('=' * max_len, file=sys.stderr)
+ for line in lines:
+ print(line, file=sys.stderr)
+ print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ message = str(e)
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+ print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+ """)
+
+
def run(code, task):
try:
code()
except Exception as e:
- print(f"{task}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ display(task, e)
diff --git a/modules/extras.py b/modules/extras.py
index 5e270250..d665440a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -58,6 +58,9 @@ cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc()
+ shared.state.begin()
+ shared.state.job = 'extras'
+
imageArr = []
# Also keep track of original file names
imageNameArr = []
@@ -94,6 +97,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Extra operation definitions
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-gfpgan'
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
@@ -104,6 +108,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-codeformer'
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)
@@ -114,6 +119,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ shared.state.job = 'extras-upscale'
upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
@@ -180,6 +186,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
+
+ shared.state.textinfo = f'Processing image {image_name}'
+
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -193,6 +202,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
basename = ''
+ if opts.enable_pnginfo: # append info before save
+ image.info = existing_pnginfo
+ image.info["extras"] = info
+
if save_output:
# Add upscaler name as a suffix.
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
@@ -203,10 +216,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
- if opts.enable_pnginfo:
- image.info = existing_pnginfo
- image.info["extras"] = info
-
if extras_mode != 2 or show_extras_results :
outputs.append(image)
@@ -242,6 +251,9 @@ def run_pnginfo(image):
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+ shared.state.begin()
+ shared.state.job = 'model-merge'
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -263,8 +275,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_func1, theta_func2 = theta_funcs[interp_method]
if theta_func1 and not tertiary_model_info:
+ shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
+ shared.state.end()
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
@@ -281,6 +296,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -291,6 +307,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
a = theta_0[key]
b = theta_1[key]
+ shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -303,8 +320,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
- assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
-
theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half:
@@ -332,6 +347,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename)
+ shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -343,4 +359,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models()
print("Checkpoint saved.")
+ shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ shared.state.end()
+
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d94f11a3..4baf4d9a 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -37,7 +37,10 @@ def quote(text):
def image_from_url_text(filedata):
- if type(filedata) == dict and filedata["is_file"]:
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ filedata = filedata[0]
+
+ if type(filedata) == dict and filedata.get("is_file", False):
filename = filedata["name"]
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 109e8078..450fecac 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -417,6 +417,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
+ shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 6f761c5a..738d8ff7 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -136,7 +136,8 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
-
+ shared.state.begin()
+ shared.state.job = 'interrogate'
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -177,5 +178,6 @@ class InterrogateModels:
res += "<error>"
self.unload()
+ shared.state.end()
return res
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e647f6fa..6a1a7ac8 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+ load_upscalers()
+
+ builtin_upscaler_classes.clear()
+ builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+ for cls in Upscaler.__subclasses__():
+ if cls not in builtin_upscaler_classes:
+ forbidden_upscaler_classes.add(cls)
+
+
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -139,6 +156,9 @@ def load_upscalers():
datas = []
commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
+ if cls in forbidden_upscaler_classes:
+ continue
+
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))
diff --git a/modules/processing.py b/modules/processing.py
index 4654570c..fd7c7015 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -685,7 +685,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
@@ -705,7 +705,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return
if not isinstance(image, Image.Image):
- image = sd_samplers.sample_to_image(image, index)
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
@@ -713,7 +713,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 06b75772..31d2c898 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -12,191 +12,6 @@ from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
-# =================================================================================================
-# Monkey patch DDIMSampler methods from RunwayML repo directly.
-# Adapted from:
-# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
-# =================================================================================================
-@torch.no_grad()
-def sample_ddim(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list):
- ctmp = ctmp[0]
- cbs = ctmp.shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
-
- samples, intermediates = self.ddim_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
-@torch.no_grad()
-def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None):
- b, *_, device = *x.shape, x.device
-
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- if isinstance(c, dict):
- assert isinstance(unconditional_conditioning, dict)
- c_in = dict()
- for k in c:
- if isinstance(c[k], list):
- c_in[k] = [
- torch.cat([unconditional_conditioning[k][i], c[k][i]])
- for i in range(len(c[k]))
- ]
- else:
- c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
- else:
- c_in = torch.cat([unconditional_conditioning, c])
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
-
-# =================================================================================================
-# Monkey patch PLMSSampler methods.
-# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
-# Adapted from:
-# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
-# =================================================================================================
-@torch.no_grad()
-def sample_plms(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list):
- ctmp = ctmp[0]
- cbs = ctmp.shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- # print(f'Data shape for PLMS sampling is {size}') # remove unnecessary message
-
- samples, intermediates = self.plms_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -280,61 +95,17 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0, e_t
-# =================================================================================================
-# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
-# Adapted from:
-# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
-# =================================================================================================
-
-@torch.no_grad()
-def get_unconditional_conditioning(self, batch_size, null_label=None):
- if null_label is not None:
- xc = null_label
- if isinstance(xc, ListConfig):
- xc = list(xc)
- if isinstance(xc, dict) or isinstance(xc, list):
- c = self.get_learned_conditioning(xc)
- else:
- if hasattr(xc, "to"):
- xc = xc.to(self.device)
- c = self.get_learned_conditioning(xc)
- else:
- # todo: get null label from cond_stage_model
- raise NotImplementedError()
- c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
- return c
-
-
-class LatentInpaintDiffusion(LatentDiffusion):
- def __init__(
- self,
- concat_keys=("mask", "masked_image"),
- masked_image_key="masked_image",
- *args,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
- self.masked_image_key = masked_image_key
- assert self.masked_image_key in concat_keys
- self.concat_keys = concat_keys
-
def should_hijack_inpainting(checkpoint_info):
+ from modules import sd_models
+
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
- cfg_basename = os.path.basename(checkpoint_info.config).lower()
+ cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
+
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
def do_inpainting_hijack():
- # most of this stuff seems to no longer be needed because it is already included into SD2.0
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
- # this file should be cleaned up later if everything turns out to work fine
-
- # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
- # ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
-
- # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
- # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
- # ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
diff --git a/modules/sd_models.py b/modules/sd_models.py
index bff8d6c9..6dca4ddf 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
@@ -48,6 +48,14 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def find_checkpoint_config(info):
+ config = os.path.splitext(info.filename)[0] + ".yaml"
+ if os.path.exists(config):
+ return config
+
+ return shared.cmd_opts.config
+
+
def list_models():
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@@ -73,7 +81,7 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -81,12 +89,7 @@ def list_models():
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
- basename, _ = os.path.splitext(filename)
- config = basename + ".yaml"
- if not os.path.exists(config):
- config = shared.cmd_opts.config
-
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
def get_closet_checkpoint_match(searchString):
@@ -278,12 +281,14 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
+ checkpoint_config = find_checkpoint_config(checkpoint_info)
- if checkpoint_info.config != shared.cmd_opts.config:
- print(f"Loading config from: {checkpoint_info.config}")
+ if checkpoint_config != shared.cmd_opts.config:
+ print(f"Loading config from: {checkpoint_config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -291,7 +296,7 @@ def load_model(checkpoint_info=None):
gc.collect()
devices.torch_gc()
- sd_config = OmegaConf.load(checkpoint_info.config)
+ sd_config = OmegaConf.load(checkpoint_config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
@@ -301,7 +306,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.finetune_keys = None
# Create a "fake" config with a different name so that we know to unload it when switching models.
- checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+ checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
@@ -312,6 +317,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model)
+
load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -324,26 +330,29 @@ def load_model(checkpoint_info=None):
sd_model.eval()
shared.sd_model = sd_model
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
+
script_callbacks.model_loaded_callback(sd_model)
print("Model loaded.")
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model
-
return sd_model
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
-
+
if not sd_model:
sd_model = shared.sd_model
+ current_checkpoint_info = sd_model.sd_checkpoint_info
+ checkpoint_config = find_checkpoint_config(current_checkpoint_info)
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
@@ -356,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info)
-
- sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
-
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
+ try:
+ load_model_weights(sd_model, checkpoint_info)
+ except Exception as e:
+ print("Failed to load checkpoint, restoring previous")
+ load_model_weights(sd_model, current_checkpoint_info)
+ raise
+ finally:
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
+
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
print("Weights loaded.")
+
return sd_model
diff --git a/modules/shared.py b/modules/shared.py
index c541d18c..9c9fd857 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading
+from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path
@@ -82,6 +82,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
@@ -494,7 +495,12 @@ class Options:
return False
if self.data_labels[key].onchange is not None:
- self.data_labels[key].onchange()
+ try:
+ self.data_labels[key].onchange()
+ except Exception as e:
+ errors.display(e, f"changing setting {key} to {value}")
+ setattr(self, key, oldval)
+ return False
return True
@@ -559,8 +565,11 @@ if os.path.exists(config_filename):
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
- "Latent": "bilinear",
- "Latent (nearest)": "nearest",
+ "Latent": {"mode": "bilinear", "antialias": False},
+ "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
+ "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
+ "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
+ "Latent (nearest)": {"mode": "nearest", "antialias": False},
}
sd_upscalers = []
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 56b9b2eb..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
files = listfiles(src)
+ shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 16176e90..214db01c 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -256,6 +256,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
diff --git a/modules/ui.py b/modules/ui.py
index f2e7c0d6..bfc93634 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -635,6 +635,7 @@ def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ sampler_index.save_to_config = True
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
@@ -1529,8 +1530,10 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
- settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
- restart_gradio = gr.Button(value='Restart UI', variant='primary', elem_id="settings_restart_gradio")
+ with gr.Column(scale=6):
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ with gr.Column():
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
result = gr.HTML(elem_id="settings_result")
@@ -1574,6 +1577,11 @@ def create_ui():
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ if os.path.exists("html/licenses.html"):
+ with open("html/licenses.html", encoding="utf8") as file:
+ with gr.TabItem("Licenses"):
+ gr.HTML(file.read(), elem_id="licenses")
+
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
request_notifications.click(
@@ -1659,6 +1667,10 @@ def create_ui():
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+ if os.path.exists("html/footer.html"):
+ with open("html/footer.html", encoding="utf8") as file:
+ gr.HTML(file.read(), elem_id="footer")
+
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),