aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-01-04 23:10:31 -0500
committerbrkirch <brkirch@users.noreply.github.com>2023-01-06 00:15:22 -0500
commit3bfe2bb5498241c4873cdd71b3f0a5bac5f64d7f (patch)
treec296183ec50286cf3552add8c1c967c2f843ec47 /modules
parentf6ab5a39d762a7791573d1c52ae5a3024b10e8ed (diff)
parent5f4fa942b8ec3ed3b15a352903489d6f9e6eb46e (diff)
Merge remote-tracking branch 'upstream/master' into sub-quad_attn_opt
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py29
-rw-r--r--modules/errors.py25
-rw-r--r--modules/extras.py33
-rw-r--r--modules/generation_parameters_copypaste.py9
-rw-r--r--modules/hypernetworks/hypernetwork.py24
-rw-r--r--modules/interrogate.py4
-rw-r--r--modules/processing.py118
-rw-r--r--modules/sd_hijack.py12
-rw-r--r--modules/sd_hijack_inpainting.py5
-rw-r--r--modules/sd_models.py63
-rw-r--r--modules/sd_samplers.py5
-rw-r--r--modules/shared.py30
-rw-r--r--modules/textual_inversion/learn_schedule.py11
-rw-r--r--modules/textual_inversion/preprocess.py1
-rw-r--r--modules/textual_inversion/textual_inversion.py72
-rw-r--r--modules/txt2img.py5
-rw-r--r--modules/ui.py79
17 files changed, 390 insertions, 135 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 9c670f00..48a70a44 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,11 +1,12 @@
import base64
import io
import time
+import datetime
import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
@@ -18,7 +19,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List
@@ -67,6 +68,27 @@ def encode_pil_to_base64(image):
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
+def api_middleware(app: FastAPI):
+ @app.middleware("http")
+ async def log_and_time(req: Request, call_next):
+ ts = time.time()
+ res: Response = await call_next(req)
+ duration = str(round(time.time() - ts, 4))
+ res.headers["X-Process-Time"] = duration
+ endpoint = req.scope.get('path', 'err')
+ if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+ print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
+ t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code = res.status_code,
+ ver = req.scope.get('http_version', '0.0'),
+ cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot = req.scope.get('scheme', 'err'),
+ method = req.scope.get('method', 'err'),
+ endpoint = endpoint,
+ duration = duration,
+ ))
+ return res
+
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
@@ -79,6 +101,7 @@ class Api:
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
+ api_middleware(self.app)
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
@@ -303,7 +326,7 @@ class Api:
return upscalers
def get_sd_models(self):
- return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/errors.py b/modules/errors.py
index 372dc51a..a668c014 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -2,9 +2,30 @@ import sys
import traceback
+def print_error_explanation(message):
+ lines = message.strip().split("\n")
+ max_len = max([len(x) for x in lines])
+
+ print('=' * max_len, file=sys.stderr)
+ for line in lines:
+ print(line, file=sys.stderr)
+ print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ message = str(e)
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+ print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+ """)
+
+
def run(code, task):
try:
code()
except Exception as e:
- print(f"{task}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ display(task, e)
diff --git a/modules/extras.py b/modules/extras.py
index 5e270250..7407bfe3 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -19,8 +19,6 @@ from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
-import piexif
-import piexif.helper
import gradio as gr
import safetensors.torch
@@ -58,6 +56,9 @@ cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc()
+ shared.state.begin()
+ shared.state.job = 'extras'
+
imageArr = []
# Also keep track of original file names
imageNameArr = []
@@ -94,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Extra operation definitions
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-gfpgan'
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
@@ -104,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-codeformer'
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)
@@ -114,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ shared.state.job = 'extras-upscale'
upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
@@ -180,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
+
+ shared.state.textinfo = f'Processing image {image_name}'
+
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -193,6 +200,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else:
basename = ''
+ if opts.enable_pnginfo: # append info before save
+ image.info = existing_pnginfo
+ image.info["extras"] = info
+
if save_output:
# Add upscaler name as a suffix.
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
@@ -203,10 +214,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
- if opts.enable_pnginfo:
- image.info = existing_pnginfo
- image.info["extras"] = info
-
if extras_mode != 2 or show_extras_results :
outputs.append(image)
@@ -242,6 +249,9 @@ def run_pnginfo(image):
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+ shared.state.begin()
+ shared.state.job = 'model-merge'
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -263,8 +273,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_func1, theta_func2 = theta_funcs[interp_method]
if theta_func1 and not tertiary_model_info:
+ shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
+ shared.state.end()
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
@@ -281,6 +294,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -291,6 +305,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
a = theta_0[key]
b = theta_1[key]
+ shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -303,8 +318,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
- assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
-
theta_0[key] = theta_func2(a, b, multiplier)
if save_as_half:
@@ -332,6 +345,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename)
+ shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -343,4 +357,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models()
print("Checkpoint saved.")
+ shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ shared.state.end()
+
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 4baf4d9a..12a9de3d 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -212,11 +212,10 @@ def restore_old_hires_fix_params(res):
firstpass_width = math.ceil(scale * width / 64) * 64
firstpass_height = math.ceil(scale * height / 64) * 64
- hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height
-
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
- res['Hires upscale'] = hr_scale
+ res['Hires resize-1'] = width
+ res['Hires resize-2'] = height
def parse_generation_parameters(x: str):
@@ -276,6 +275,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+ if "Hires resize-1" not in res:
+ res["Hires resize-1"] = 0
+ res["Hires resize-2"] = 0
+
restore_old_hires_fix_params(res)
return res
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 109e8078..6a9b1398 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
- return fn
-
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -417,6 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
+ shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
@@ -447,6 +446,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -465,7 +468,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
-
+
weights = hypernetwork.weights()
hypernetwork.train_mode()
@@ -524,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.state.interrupted:
break
+ if clip_grad:
+ clip_grad_sched.step(hypernetwork.step)
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
@@ -538,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
_loss_step += loss.item()
scaler.scale(loss).backward()
+
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
- # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
- # scaler.unscale_(optimizer)
- # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
- # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
- # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+
+ if clip_grad:
+ clip_grad(weights, clip_grad_sched.learn_rate)
+
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 6f761c5a..738d8ff7 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -136,7 +136,8 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
-
+ shared.state.begin()
+ shared.state.job = 'interrogate'
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -177,5 +178,6 @@ class InterrogateModels:
res += "<error>"
self.unload()
+ shared.state.end()
return res
diff --git a/modules/processing.py b/modules/processing.py
index a172af0b..7e853287 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
@@ -136,28 +154,12 @@ class StableDiffusionProcessing():
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
+ self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1)
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- self.is_using_inpainting_conditioning = True
-
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
@@ -420,7 +422,7 @@ def fix_seed(p):
p.subseed = get_fixed_seed(p.subseed)
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -544,6 +546,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ p.iteration = n
+
if state.skipped:
state.skipped = False
@@ -658,12 +662,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
self.hr_scale = hr_scale
self.hr_upscaler = hr_upscaler
+ self.hr_second_pass_steps = hr_second_pass_steps
+ self.hr_resize_x = hr_resize_x
+ self.hr_resize_y = hr_resize_y
+ self.hr_upscale_to_x = hr_resize_x
+ self.hr_upscale_to_y = hr_resize_y
if firstphase_width != 0 or firstphase_height != 0:
print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
@@ -671,14 +680,60 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.width = firstphase_width
self.height = firstphase_height
+ self.truncate_x = 0
+ self.truncate_y = 0
+
+
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if state.job_count == -1:
- state.job_count = self.n_iter * 2
+ if self.hr_resize_x == 0 and self.hr_resize_y == 0:
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
+ self.hr_upscale_to_x = int(self.width * self.hr_scale)
+ self.hr_upscale_to_y = int(self.height * self.hr_scale)
else:
+ self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
+
+ if self.hr_resize_y == 0:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ elif self.hr_resize_x == 0:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+ else:
+ target_w = self.hr_resize_x
+ target_h = self.hr_resize_y
+ src_ratio = self.width / self.height
+ dst_ratio = self.hr_resize_x / self.hr_resize_y
+
+ if src_ratio < dst_ratio:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ else:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+
+ self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
+ self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
+
+ # special case: the user has chosen to do nothing
+ if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
+ self.enable_hr = False
+ self.denoising_strength = None
+ self.extra_generation_params.pop("Hires upscale", None)
+ self.extra_generation_params.pop("Hires resize", None)
+ return
+
+ if not state.processing_has_refined_job_count:
+ if state.job_count == -1:
+ state.job_count = self.n_iter
+
+ shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
state.job_count = state.job_count * 2
+ state.processing_has_refined_job_count = True
+
+ if self.hr_second_pass_steps:
+ self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
- self.extra_generation_params["Hires upscale"] = self.hr_scale
if self.hr_upscaler is not None:
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
@@ -695,8 +750,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr:
return samples
- target_width = int(self.width * self.hr_scale)
- target_height = int(self.height * self.hr_scale)
+ target_width = self.hr_upscale_to_x
+ target_height = self.hr_upscale_to_y
def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
@@ -705,15 +760,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return
if not isinstance(image, Image.Image):
- image = sd_samplers.sample_to_image(image, index)
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
if latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode)
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
@@ -750,13 +806,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
+
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
return samples
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 55a684cc..ef25dadb 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -33,25 +33,34 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
+ optimization_method = None
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ optimization_method = 'xformers'
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+ optimization_method = 'sub-quadratic'
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
+ optimization_method = 'V1'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
print("Applying cross attention optimization (InvokeAI).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
+ optimization_method = 'Doggettx'
+
+ return optimization_method
def undo_optimizations():
@@ -72,6 +81,7 @@ class StableDiffusionModelHijack:
layers = None
circular_enabled = False
clip = None
+ optimization_method = None
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
@@ -91,7 +101,7 @@ class StableDiffusionModelHijack:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
- apply_optimizations()
+ self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 3c214a35..31d2c898 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
def should_hijack_inpainting(checkpoint_info):
+ from modules import sd_models
+
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
- cfg_basename = os.path.basename(checkpoint_info.config).lower()
+ cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
+
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b98b05fc..76a89e88 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
@@ -48,6 +48,14 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def find_checkpoint_config(info):
+ config = os.path.splitext(info.filename)[0] + ".yaml"
+ if os.path.exists(config):
+ return config
+
+ return shared.cmd_opts.config
+
+
def list_models():
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@@ -73,7 +81,7 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -81,12 +89,7 @@ def list_models():
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
- basename, _ = os.path.splitext(filename)
- config = basename + ".yaml"
- if not os.path.exists(config):
- config = shared.cmd_opts.config
-
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
def get_closet_checkpoint_match(searchString):
@@ -168,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
+ device = map_location or shared.weight_load_location
+ if device is None:
+ device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -278,12 +284,14 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
+ checkpoint_config = find_checkpoint_config(checkpoint_info)
- if checkpoint_info.config != shared.cmd_opts.config:
- print(f"Loading config from: {checkpoint_info.config}")
+ if checkpoint_config != shared.cmd_opts.config:
+ print(f"Loading config from: {checkpoint_config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -291,7 +299,7 @@ def load_model(checkpoint_info=None):
gc.collect()
devices.torch_gc()
- sd_config = OmegaConf.load(checkpoint_info.config)
+ sd_config = OmegaConf.load(checkpoint_config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
@@ -300,9 +308,6 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.in_channels = 9
sd_config.model.params.finetune_keys = None
- # Create a "fake" config with a different name so that we know to unload it when switching models.
- checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
-
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
@@ -312,6 +317,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model)
+
load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -336,14 +342,17 @@ def load_model(checkpoint_info=None):
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
-
+
if not sd_model:
sd_model = shared.sd_model
+ current_checkpoint_info = sd_model.sd_checkpoint_info
+ checkpoint_config = find_checkpoint_config(current_checkpoint_info)
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
@@ -356,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info)
-
- sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
-
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
+ try:
+ load_model_weights(sd_model, checkpoint_info)
+ except Exception as e:
+ print("Failed to load checkpoint, restoring previous")
+ load_model_weights(sd_model, current_checkpoint_info)
+ raise
+ finally:
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
+
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
print("Weights loaded.")
+
return sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index e904d860..3851a77f 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -97,8 +97,9 @@ sampler_extra_params = {
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
- steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
- t_enc = p.steps - 1
+ requested_steps = (steps or p.steps)
+ steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+ t_enc = requested_steps - 1
else:
steps = p.steps
t_enc = int(min(p.denoising_strength, 0.999) * steps)
diff --git a/modules/shared.py b/modules/shared.py
index 373551da..cb1dc312 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading
+from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path
@@ -86,6 +86,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
@@ -156,6 +157,7 @@ class State:
job = ""
job_no = 0
job_count = 0
+ processing_has_refined_job_count = False
job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
@@ -186,6 +188,7 @@ class State:
"interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
+ "job_timestamp": self.job_timestamp,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
@@ -196,6 +199,7 @@ class State:
def begin(self):
self.sampling_step = 0
self.job_count = -1
+ self.processing_has_refined_job_count = False
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.current_latent = None
@@ -216,12 +220,13 @@ class State:
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ if not parallel_processing_allowed:
+ return
+
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
self.do_set_current_image()
def do_set_current_image(self):
- if not parallel_processing_allowed:
- return
if self.current_latent is None:
return
@@ -233,6 +238,7 @@ class State:
self.current_image_sampling_step = self.sampling_step
+
state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -359,7 +365,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -498,7 +504,12 @@ class Options:
return False
if self.data_labels[key].onchange is not None:
- self.data_labels[key].onchange()
+ try:
+ self.data_labels[key].onchange()
+ except Exception as e:
+ errors.display(e, f"changing setting {key} to {value}")
+ setattr(self, key, oldval)
+ return False
return True
@@ -563,8 +574,11 @@ if os.path.exists(config_filename):
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
- "Latent": "bilinear",
- "Latent (nearest)": "nearest",
+ "Latent": {"mode": "bilinear", "antialias": False},
+ "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
+ "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
+ "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
+ "Latent (nearest)": {"mode": "nearest", "antialias": False},
}
sd_upscalers = []
@@ -600,7 +614,7 @@ class TotalTQDM:
return
if self._tqdm is None:
self.reset()
- self._tqdm.total=new_total
+ self._tqdm.total = new_total
def clear(self):
if self._tqdm is not None:
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index dd0c0ad1..f63fc72f 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -58,14 +58,19 @@ class LearnRateScheduler:
self.finished = False
- def apply(self, optimizer, step_number):
+ def step(self, step_number):
if step_number < self.end_step:
- return
+ return False
try:
(self.learn_rate, self.end_step) = next(self.schedules)
- except Exception:
+ except StopIteration:
self.finished = True
+ return False
+ return True
+
+ def apply(self, optimizer, step_number):
+ if not self.step(step_number):
return
if self.verbose:
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 56b9b2eb..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
files = listfiles(src)
+ shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fd253477..71e07bcc 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -28,6 +28,7 @@ class Embedding:
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:
torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
- name = os.path.splitext(filename)[0]
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
- else:
+ elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ else:
+ return
# textual inversion embeddings
if 'string_to_param' in data:
@@ -240,11 +251,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -282,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -300,6 +317,19 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(filename + '.optim'):
+ optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -315,6 +345,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ img_c = None
+
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
@@ -332,14 +365,22 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if shared.state.interrupted:
break
+ if clip_grad:
+ clip_grad_sched.step(embedding.step)
+
with devices.autocast():
- # c = stack_conds(batch.cond).to(devices.device)
- # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
- # print(mask)
- # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
- loss = shared.sd_model(x, c)[0] / gradient_step
+
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
+
+ loss = shared.sd_model(x, cond)[0] / gradient_step
del x
_loss_step += loss.item()
@@ -348,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
+
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
+
scaler.step(optimizer)
scaler.update()
embedding.step += 1
@@ -366,9 +411,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
- #if shared.opts.save_optimizer_state:
- #embedding.optimizer_state_dict = optimizer.state_dict()
- save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -458,7 +501,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
@@ -470,7 +513,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return embedding, filename
-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -481,6 +524,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum:
embedding.cached_checksum = None
embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename)
except:
embedding.sd_checkpoint = old_sd_checkpoint
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e189a899..38b5f591 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -35,6 +35,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
denoising_strength=denoising_strength if enable_hr else None,
hr_scale=hr_scale,
hr_upscaler=hr_upscaler,
+ hr_second_pass_steps=hr_second_pass_steps,
+ hr_resize_x=hr_resize_x,
+ hr_resize_y=hr_resize_y,
)
p.scripts = modules.scripts.scripts_txt2img
diff --git a/modules/ui.py b/modules/ui.py
index d941cb5f..04091e67 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -162,16 +162,14 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-
-
-def calc_time_left(progress, threshold, label, force_display):
+def calc_time_left(progress, threshold, label, force_display, show_eta):
if progress == 0:
return ""
else:
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if (eta_relative > threshold and progress > 0.02) or force_display:
+ if (eta_relative > threshold and show_eta) or force_display:
if eta_relative > 3600:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
elif eta_relative > 60:
@@ -193,7 +191,10 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
- time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
+ # Show progress percentage and time left at the same moment, and base it also on steps done
+ show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
+
+ time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
if time_left != "":
shared.state.time_left_force_display = True
@@ -201,7 +202,7 @@ def check_progress_call(id_part):
progressbar = ""
if opts.show_progressbar:
- progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
+ progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
image = gr_show(False)
preview_visibility = gr_show(False)
@@ -635,10 +636,11 @@ def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+ sampler_index.save_to_config = True
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
- steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20)
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
return steps, sampler_index
@@ -707,10 +709,16 @@ def create_ui():
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
elif category == "hires_fix":
- with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
- hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+ with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+ with FormRow(elem_id="txt2img_hires_fix_row1"):
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+ with FormRow(elem_id="txt2img_hires_fix_row2"):
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
elif category == "batch":
if not opts.dimensions_and_batch_together:
@@ -751,6 +759,9 @@ def create_ui():
denoising_strength,
hr_scale,
hr_upscaler,
+ hr_second_pass_steps,
+ hr_resize_x,
+ hr_resize_y,
] + custom_inputs,
outputs=[
@@ -802,6 +813,9 @@ def create_ui():
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(hr_scale, "Hires upscale"),
(hr_upscaler, "Hires upscaler"),
+ (hr_second_pass_steps, "Hires steps"),
+ (hr_resize_x, "Hires resize-1"),
+ (hr_resize_y, "Hires resize-2"),
*modules.scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
@@ -1279,38 +1293,48 @@ def create_ui():
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
- with gr.Row():
+ with FormRow():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
- with gr.Row():
+
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
- with gr.Row():
+
+ with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
+
+ with FormRow():
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
+
+ with FormRow():
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
- batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
- gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
- create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
- save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+
+ with FormRow():
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
- with gr.Row():
- shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
- tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
- with gr.Row():
- latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
+
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
+
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
with gr.Row():
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
- train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
@@ -1400,6 +1424,8 @@ def create_ui():
training_width,
training_height,
steps,
+ clip_grad_mode,
+ clip_grad_value,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
@@ -1429,6 +1455,8 @@ def create_ui():
training_width,
training_height,
steps,
+ clip_grad_mode,
+ clip_grad_value,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
@@ -1793,6 +1821,7 @@ def create_ui():
visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger")
+ visit(train_interface, loadsave, "train")
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file: