aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py138
-rw-r--r--modules/api/models.py4
-rw-r--r--modules/call_queue.py1
-rw-r--r--modules/codeformer/codeformer_arch.py276
-rw-r--r--modules/codeformer/vqgan_arch.py435
-rw-r--r--modules/codeformer_model.py158
-rw-r--r--modules/devices.py61
-rw-r--r--modules/errors.py4
-rw-r--r--modules/esrgan_model.py199
-rw-r--r--modules/esrgan_model_arch.py465
-rw-r--r--modules/face_restoration_utils.py180
-rw-r--r--modules/gfpgan_model.py166
-rw-r--r--modules/hat_model.py43
-rw-r--r--modules/images.py14
-rw-r--r--modules/img2img.py4
-rw-r--r--modules/infotext.py (renamed from modules/generation_parameters_copypaste.py)98
-rw-r--r--modules/infotext_versions.py39
-rw-r--r--modules/initialize.py3
-rw-r--r--modules/initialize_util.py2
-rw-r--r--modules/interrogate.py4
-rw-r--r--modules/launch_utils.py13
-rw-r--r--modules/modelloader.py84
-rw-r--r--modules/paths.py1
-rw-r--r--modules/paths_internal.py1
-rw-r--r--modules/postprocessing.py4
-rw-r--r--modules/processing.py180
-rw-r--r--modules/processing_scripts/refiner.py7
-rw-r--r--modules/processing_scripts/seed.py13
-rw-r--r--modules/progress.py22
-rw-r--r--modules/realesrgan_model.py158
-rw-r--r--modules/scripts.py81
-rw-r--r--modules/sd_disable_initialization.py2
-rw-r--r--modules/sd_models.py55
-rw-r--r--modules/sd_models_config.py6
-rw-r--r--modules/sd_models_xl.py11
-rw-r--r--modules/sd_samplers_cfg_denoiser.py21
-rw-r--r--modules/sd_samplers_timesteps.py3
-rw-r--r--modules/shared_items.py4
-rw-r--r--modules/shared_options.py28
-rw-r--r--modules/shared_state.py7
-rw-r--r--modules/styles.py78
-rw-r--r--modules/sysinfo.py2
-rw-r--r--modules/textual_inversion/textual_inversion.py10
-rw-r--r--modules/torch_utils.py17
-rw-r--r--modules/txt2img.py2
-rw-r--r--modules/ui.py53
-rw-r--r--modules/ui_common.py4
-rw-r--r--modules/ui_extra_networks.py7
-rw-r--r--modules/ui_extra_networks_user_metadata.py4
-rw-r--r--modules/ui_gradio_extensions.py11
-rw-r--r--modules/ui_loadsave.py2
-rw-r--r--modules/ui_postprocessing.py2
-rw-r--r--modules/ui_toprow.py12
-rw-r--r--modules/upscaler.py3
-rw-r--r--modules/upscaler_utils.py140
-rw-r--r--modules/util.py12
-rw-r--r--modules/xlmr.py5
-rw-r--r--modules/xlmr_m18.py4
-rw-r--r--modules/xpu_specific.py74
59 files changed, 1466 insertions, 1971 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index b3d74e51..0e2807de 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -31,7 +31,7 @@ from typing import Any
import piexif
import piexif.helper
from contextlib import closing
-
+from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
def script_name_to_index(name, scripts):
try:
@@ -251,6 +251,24 @@ class Api:
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
+ txt2img_script_runner = scripts.scripts_txt2img
+ img2img_script_runner = scripts.scripts_img2img
+
+ if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
+ ui.create_ui()
+
+ if not txt2img_script_runner.scripts:
+ txt2img_script_runner.initialize_scripts(False)
+ if not self.default_script_arg_txt2img:
+ self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
+
+ if not img2img_script_runner.scripts:
+ img2img_script_runner.initialize_scripts(True)
+ if not self.default_script_arg_img2img:
+ self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
+
+
+
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@@ -312,8 +330,13 @@ class Api:
script_args[script.args_from:script.args_to] = ui_default_values
return script_args
- def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
+ def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
script_args = default_script_args.copy()
+
+ if input_script_args is not None:
+ for index, value in input_script_args.items():
+ script_args[index] = value
+
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
@@ -335,13 +358,83 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
+ def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
+ """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.
+
+ If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
+
+ Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
+ """
+
+ if not request.infotext:
+ return {}
+
+ possible_fields = infotext.paste_fields[tabname]["fields"]
+ set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have different names for this
+ params = infotext.parse_generation_parameters(request.infotext)
+
+ def get_field_value(field, params):
+ value = field.function(params) if field.function else params.get(field.label)
+ if value is None:
+ return None
+
+ if field.api in request.__fields__:
+ target_type = request.__fields__[field.api].type_
+ else:
+ target_type = type(field.component.value)
+
+ if target_type == type(None):
+ return None
+
+ if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
+ value = value.get('value')
+
+ if value is not None and not isinstance(value, target_type):
+ value = target_type(value)
+
+ return value
+
+ for field in possible_fields:
+ if not field.api:
+ continue
+
+ if field.api in set_fields:
+ continue
+
+ value = get_field_value(field, params)
+ if value is not None:
+ setattr(request, field.api, value)
+
+ if request.override_settings is None:
+ request.override_settings = {}
+
+ overriden_settings = infotext.get_override_settings(params)
+ for _, setting_name, value in overriden_settings:
+ if setting_name not in request.override_settings:
+ request.override_settings[setting_name] = value
+
+ if script_runner is not None and mentioned_script_args is not None:
+ indexes = {v: i for i, v in enumerate(script_runner.inputs)}
+ script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
+
+ for field, index in script_fields:
+ value = get_field_value(field, params)
+
+ if value is None:
+ continue
+
+ mentioned_script_args[index] = value
+
+ return params
+
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
+ task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
+
script_runner = scripts.scripts_txt2img
- if not script_runner.scripts:
- script_runner.initialize_scripts(False)
- ui.create_ui()
- if not self.default_script_arg_txt2img:
- self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
+
+ infotext_script_args = {}
+ self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
+
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
@@ -356,12 +449,15 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
+ args.pop('infotext', None)
- script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
+ script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
+ add_task_to_queue(task_id)
+
with self.queue_lock:
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.is_api = True
@@ -371,12 +467,14 @@ class Api:
try:
shared.state.begin(job="scripts_txt2img")
+ start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
+ finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
@@ -386,6 +484,8 @@ class Api:
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
+ task_id = img2imgreq.force_task_id or create_task_id("img2img")
+
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -395,11 +495,10 @@ class Api:
mask = decode_base64_to_image(mask)
script_runner = scripts.scripts_img2img
- if not script_runner.scripts:
- script_runner.initialize_scripts(True)
- ui.create_ui()
- if not self.default_script_arg_img2img:
- self.default_script_arg_img2img = self.init_default_script_args(script_runner)
+
+ infotext_script_args = {}
+ self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
+
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params
@@ -416,12 +515,15 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
+ args.pop('infotext', None)
- script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
+ script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
+ add_task_to_queue(task_id)
+
with self.queue_lock:
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images]
@@ -432,12 +534,14 @@ class Api:
try:
shared.state.begin(job="scripts_img2img")
+ start_task(task_id)
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
+ finish_task(task_id)
finally:
shared.state.end()
shared.total_tqdm.clear()
@@ -480,7 +584,7 @@ class Api:
if geninfo is None:
geninfo = ""
- params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
+ params = infotext.parse_generation_parameters(geninfo)
script_callbacks.infotext_pasted_callback(geninfo, params)
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
@@ -511,7 +615,7 @@ class Api:
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
- return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image
diff --git a/modules/api/models.py b/modules/api/models.py
index 33894b3e..16edf11c 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
+ {"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
@@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "send_images", "type": bool, "default": True},
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
+ {"key": "force_task_id", "type": str, "default": None},
+ {"key": "infotext", "type": str, "default": None},
]
).generate_model()
diff --git a/modules/call_queue.py b/modules/call_queue.py
index ddf0d573..bcd7c546 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
shared.state.skipped = False
shared.state.interrupted = False
+ shared.state.stopping_generation = False
shared.state.job_count = 0
if not add_stats:
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
deleted file mode 100644
index 12db6814..00000000
--- a/modules/codeformer/codeformer_arch.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
-
-import math
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-from typing import Optional
-
-from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
-from basicsr.utils.registry import ARCH_REGISTRY
-
-def calc_mean_std(feat, eps=1e-5):
- """Calculate mean and std for adaptive_instance_normalization.
-
- Args:
- feat (Tensor): 4D tensor.
- eps (float): A small value added to the variance to avoid
- divide-by-zero. Default: 1e-5.
- """
- size = feat.size()
- assert len(size) == 4, 'The input feature should be 4D tensor.'
- b, c = size[:2]
- feat_var = feat.view(b, c, -1).var(dim=2) + eps
- feat_std = feat_var.sqrt().view(b, c, 1, 1)
- feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
- return feat_mean, feat_std
-
-
-def adaptive_instance_normalization(content_feat, style_feat):
- """Adaptive instance normalization.
-
- Adjust the reference features to have the similar color and illuminations
- as those in the degradate features.
-
- Args:
- content_feat (Tensor): The reference feature.
- style_feat (Tensor): The degradate features.
- """
- size = content_feat.size()
- style_mean, style_std = calc_mean_std(style_feat)
- content_mean, content_std = calc_mean_std(content_feat)
- normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
- return normalized_feat * style_std.expand(size) + style_mean.expand(size)
-
-
-class PositionEmbeddingSine(nn.Module):
- """
- This is a more standard version of the position embedding, very similar to the one
- used by the Attention is all you need paper, generalized to work on images.
- """
-
- def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
- super().__init__()
- self.num_pos_feats = num_pos_feats
- self.temperature = temperature
- self.normalize = normalize
- if scale is not None and normalize is False:
- raise ValueError("normalize should be True if scale is passed")
- if scale is None:
- scale = 2 * math.pi
- self.scale = scale
-
- def forward(self, x, mask=None):
- if mask is None:
- mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
- not_mask = ~mask
- y_embed = not_mask.cumsum(1, dtype=torch.float32)
- x_embed = not_mask.cumsum(2, dtype=torch.float32)
- if self.normalize:
- eps = 1e-6
- y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
- x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
-
- dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
-
- pos_x = x_embed[:, :, :, None] / dim_t
- pos_y = y_embed[:, :, :, None] / dim_t
- pos_x = torch.stack(
- (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos_y = torch.stack(
- (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
- ).flatten(3)
- pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
- return pos
-
-def _get_activation_fn(activation):
- """Return an activation function given a string"""
- if activation == "relu":
- return F.relu
- if activation == "gelu":
- return F.gelu
- if activation == "glu":
- return F.glu
- raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
-
-
-class TransformerSALayer(nn.Module):
- def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
- super().__init__()
- self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
- # Implementation of Feedforward model - MLP
- self.linear1 = nn.Linear(embed_dim, dim_mlp)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_mlp, embed_dim)
-
- self.norm1 = nn.LayerNorm(embed_dim)
- self.norm2 = nn.LayerNorm(embed_dim)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
-
- self.activation = _get_activation_fn(activation)
-
- def with_pos_embed(self, tensor, pos: Optional[Tensor]):
- return tensor if pos is None else tensor + pos
-
- def forward(self, tgt,
- tgt_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- query_pos: Optional[Tensor] = None):
-
- # self attention
- tgt2 = self.norm1(tgt)
- q = k = self.with_pos_embed(tgt2, query_pos)
- tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
- key_padding_mask=tgt_key_padding_mask)[0]
- tgt = tgt + self.dropout1(tgt2)
-
- # ffn
- tgt2 = self.norm2(tgt)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
- tgt = tgt + self.dropout2(tgt2)
- return tgt
-
-class Fuse_sft_block(nn.Module):
- def __init__(self, in_ch, out_ch):
- super().__init__()
- self.encode_enc = ResBlock(2*in_ch, out_ch)
-
- self.scale = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
-
- self.shift = nn.Sequential(
- nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
- nn.LeakyReLU(0.2, True),
- nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
-
- def forward(self, enc_feat, dec_feat, w=1):
- enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
- scale = self.scale(enc_feat)
- shift = self.shift(enc_feat)
- residual = w * (dec_feat * scale + shift)
- out = dec_feat + residual
- return out
-
-
-@ARCH_REGISTRY.register()
-class CodeFormer(VQAutoEncoder):
- def __init__(self, dim_embd=512, n_head=8, n_layers=9,
- codebook_size=1024, latent_size=256,
- connect_list=('32', '64', '128', '256'),
- fix_modules=('quantize', 'generator')):
- super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
-
- if fix_modules is not None:
- for module in fix_modules:
- for param in getattr(self, module).parameters():
- param.requires_grad = False
-
- self.connect_list = connect_list
- self.n_layers = n_layers
- self.dim_embd = dim_embd
- self.dim_mlp = dim_embd*2
-
- self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
- self.feat_emb = nn.Linear(256, self.dim_embd)
-
- # transformer
- self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
- for _ in range(self.n_layers)])
-
- # logits_predict head
- self.idx_pred_layer = nn.Sequential(
- nn.LayerNorm(dim_embd),
- nn.Linear(dim_embd, codebook_size, bias=False))
-
- self.channels = {
- '16': 512,
- '32': 256,
- '64': 256,
- '128': 128,
- '256': 128,
- '512': 64,
- }
-
- # after second residual block for > 16, before attn layer for ==16
- self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
- # after first residual block for > 16, before attn layer for ==16
- self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
-
- # fuse_convs_dict
- self.fuse_convs_dict = nn.ModuleDict()
- for f_size in self.connect_list:
- in_ch = self.channels[f_size]
- self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
-
- def _init_weights(self, module):
- if isinstance(module, (nn.Linear, nn.Embedding)):
- module.weight.data.normal_(mean=0.0, std=0.02)
- if isinstance(module, nn.Linear) and module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.LayerNorm):
- module.bias.data.zero_()
- module.weight.data.fill_(1.0)
-
- def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
- # ################### Encoder #####################
- enc_feat_dict = {}
- out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
- for i, block in enumerate(self.encoder.blocks):
- x = block(x)
- if i in out_list:
- enc_feat_dict[str(x.shape[-1])] = x.clone()
-
- lq_feat = x
- # ################# Transformer ###################
- # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
- pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
- # BCHW -> BC(HW) -> (HW)BC
- feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
- query_emb = feat_emb
- # Transformer encoder
- for layer in self.ft_layers:
- query_emb = layer(query_emb, query_pos=pos_emb)
-
- # output logits
- logits = self.idx_pred_layer(query_emb) # (hw)bn
- logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
-
- if code_only: # for training stage II
- # logits doesn't need softmax before cross_entropy loss
- return logits, lq_feat
-
- # ################# Quantization ###################
- # if self.training:
- # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
- # # b(hw)c -> bc(hw) -> bchw
- # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
- # ------------
- soft_one_hot = F.softmax(logits, dim=2)
- _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
- quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
- # preserve gradients
- # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
-
- if detach_16:
- quant_feat = quant_feat.detach() # for training stage III
- if adain:
- quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
-
- # ################## Generator ####################
- x = quant_feat
- fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
-
- for i, block in enumerate(self.generator.blocks):
- x = block(x)
- if i in fuse_list: # fuse after i-th block
- f_size = str(x.shape[-1])
- if w>0:
- x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
- out = x
- # logits doesn't need softmax before cross_entropy loss
- return out, logits, lq_feat
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
deleted file mode 100644
index 09ee6660..00000000
--- a/modules/codeformer/vqgan_arch.py
+++ /dev/null
@@ -1,435 +0,0 @@
-# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
-
-'''
-VQGAN code, adapted from the original created by the Unleashing Transformers authors:
-https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
-
-'''
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from basicsr.utils import get_root_logger
-from basicsr.utils.registry import ARCH_REGISTRY
-
-def normalize(in_channels):
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-@torch.jit.script
-def swish(x):
- return x*torch.sigmoid(x)
-
-
-# Define VQVAE classes
-class VectorQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, beta):
- super(VectorQuantizer, self).__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
- self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
- self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
-
- def forward(self, z):
- # reshape z -> (batch, height, width, channel) and flatten
- z = z.permute(0, 2, 3, 1).contiguous()
- z_flattened = z.view(-1, self.emb_dim)
-
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
- d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
- 2 * torch.matmul(z_flattened, self.embedding.weight.t())
-
- mean_distance = torch.mean(d)
- # find closest encodings
- # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
- min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
- # [0-1], higher score, higher confidence
- min_encoding_scores = torch.exp(-min_encoding_scores/10)
-
- min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
- min_encodings.scatter_(1, min_encoding_indices, 1)
-
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
- # compute loss for embedding
- loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
- # preserve gradients
- z_q = z + (z_q - z).detach()
-
- # perplexity
- e_mean = torch.mean(min_encodings, dim=0)
- perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
- # reshape back to match original input shape
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
-
- return z_q, loss, {
- "perplexity": perplexity,
- "min_encodings": min_encodings,
- "min_encoding_indices": min_encoding_indices,
- "min_encoding_scores": min_encoding_scores,
- "mean_distance": mean_distance
- }
-
- def get_codebook_feat(self, indices, shape):
- # input indices: batch*token_num -> (batch*token_num)*1
- # shape: batch, height, width, channel
- indices = indices.view(-1,1)
- min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
- min_encodings.scatter_(1, indices, 1)
- # get quantized latent vectors
- z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
-
- if shape is not None: # reshape back to match original input shape
- z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
-
- return z_q
-
-
-class GumbelQuantizer(nn.Module):
- def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
- super().__init__()
- self.codebook_size = codebook_size # number of embeddings
- self.emb_dim = emb_dim # dimension of embedding
- self.straight_through = straight_through
- self.temperature = temp_init
- self.kl_weight = kl_weight
- self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
- self.embed = nn.Embedding(codebook_size, emb_dim)
-
- def forward(self, z):
- hard = self.straight_through if self.training else True
-
- logits = self.proj(z)
-
- soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
-
- z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
-
- # + kl divergence to the prior loss
- qy = F.softmax(logits, dim=1)
- diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
- min_encoding_indices = soft_one_hot.argmax(dim=1)
-
- return z_q, diff, {
- "min_encoding_indices": min_encoding_indices
- }
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
-
- def forward(self, x):
- pad = (0, 1, 0, 1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- return x
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
-
- def forward(self, x):
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
- x = self.conv(x)
-
- return x
-
-
-class ResBlock(nn.Module):
- def __init__(self, in_channels, out_channels=None):
- super(ResBlock, self).__init__()
- self.in_channels = in_channels
- self.out_channels = in_channels if out_channels is None else out_channels
- self.norm1 = normalize(in_channels)
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
- self.norm2 = normalize(out_channels)
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
- if self.in_channels != self.out_channels:
- self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
- def forward(self, x_in):
- x = x_in
- x = self.norm1(x)
- x = swish(x)
- x = self.conv1(x)
- x = self.norm2(x)
- x = swish(x)
- x = self.conv2(x)
- if self.in_channels != self.out_channels:
- x_in = self.conv_out(x_in)
-
- return x + x_in
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = normalize(in_channels)
- self.q = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.k = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.v = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
- self.proj_out = torch.nn.Conv2d(
- in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0
- )
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b, c, h, w = q.shape
- q = q.reshape(b, c, h*w)
- q = q.permute(0, 2, 1)
- k = k.reshape(b, c, h*w)
- w_ = torch.bmm(q, k)
- w_ = w_ * (int(c)**(-0.5))
- w_ = F.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h*w)
- w_ = w_.permute(0, 2, 1)
- h_ = torch.bmm(v, w_)
- h_ = h_.reshape(b, c, h, w)
-
- h_ = self.proj_out(h_)
-
- return x+h_
-
-
-class Encoder(nn.Module):
- def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
- super().__init__()
- self.nf = nf
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.attn_resolutions = attn_resolutions
-
- curr_res = self.resolution
- in_ch_mult = (1,)+tuple(ch_mult)
-
- blocks = []
- # initial convultion
- blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
-
- # residual and downsampling blocks, with attention on smaller res (16x16)
- for i in range(self.num_resolutions):
- block_in_ch = nf * in_ch_mult[i]
- block_out_ch = nf * ch_mult[i]
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
- if curr_res in attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != self.num_resolutions - 1:
- blocks.append(Downsample(block_in_ch))
- curr_res = curr_res // 2
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- # normalise and convert to latent size
- blocks.append(normalize(block_in_ch))
- blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
- self.blocks = nn.ModuleList(blocks)
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-class Generator(nn.Module):
- def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
- super().__init__()
- self.nf = nf
- self.ch_mult = ch_mult
- self.num_resolutions = len(self.ch_mult)
- self.num_res_blocks = res_blocks
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions
- self.in_channels = emb_dim
- self.out_channels = 3
- block_in_ch = self.nf * self.ch_mult[-1]
- curr_res = self.resolution // 2 ** (self.num_resolutions-1)
-
- blocks = []
- # initial conv
- blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
-
- # non-local attention block
- blocks.append(ResBlock(block_in_ch, block_in_ch))
- blocks.append(AttnBlock(block_in_ch))
- blocks.append(ResBlock(block_in_ch, block_in_ch))
-
- for i in reversed(range(self.num_resolutions)):
- block_out_ch = self.nf * self.ch_mult[i]
-
- for _ in range(self.num_res_blocks):
- blocks.append(ResBlock(block_in_ch, block_out_ch))
- block_in_ch = block_out_ch
-
- if curr_res in self.attn_resolutions:
- blocks.append(AttnBlock(block_in_ch))
-
- if i != 0:
- blocks.append(Upsample(block_in_ch))
- curr_res = curr_res * 2
-
- blocks.append(normalize(block_in_ch))
- blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
-
- self.blocks = nn.ModuleList(blocks)
-
-
- def forward(self, x):
- for block in self.blocks:
- x = block(x)
-
- return x
-
-
-@ARCH_REGISTRY.register()
-class VQAutoEncoder(nn.Module):
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
- beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
- super().__init__()
- logger = get_root_logger()
- self.in_channels = 3
- self.nf = nf
- self.n_blocks = res_blocks
- self.codebook_size = codebook_size
- self.embed_dim = emb_dim
- self.ch_mult = ch_mult
- self.resolution = img_size
- self.attn_resolutions = attn_resolutions or [16]
- self.quantizer_type = quantizer
- self.encoder = Encoder(
- self.in_channels,
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions
- )
- if self.quantizer_type == "nearest":
- self.beta = beta #0.25
- self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
- elif self.quantizer_type == "gumbel":
- self.gumbel_num_hiddens = emb_dim
- self.straight_through = gumbel_straight_through
- self.kl_weight = gumbel_kl_weight
- self.quantize = GumbelQuantizer(
- self.codebook_size,
- self.embed_dim,
- self.gumbel_num_hiddens,
- self.straight_through,
- self.kl_weight
- )
- self.generator = Generator(
- self.nf,
- self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
- self.attn_resolutions
- )
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location='cpu')
- if 'params_ema' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
- logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
- elif 'params' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
- logger.info(f'vqgan is loaded from: {model_path} [params]')
- else:
- raise ValueError('Wrong params!')
-
-
- def forward(self, x):
- x = self.encoder(x)
- quant, codebook_loss, quant_stats = self.quantize(x)
- x = self.generator(quant)
- return x, codebook_loss, quant_stats
-
-
-
-# patch based discriminator
-@ARCH_REGISTRY.register()
-class VQGANDiscriminator(nn.Module):
- def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
- super().__init__()
-
- layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
- ndf_mult = 1
- ndf_mult_prev = 1
- for n in range(1, n_layers): # gradually increase the number of filters
- ndf_mult_prev = ndf_mult
- ndf_mult = min(2 ** n, 8)
- layers += [
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
- nn.BatchNorm2d(ndf * ndf_mult),
- nn.LeakyReLU(0.2, True)
- ]
-
- ndf_mult_prev = ndf_mult
- ndf_mult = min(2 ** n_layers, 8)
-
- layers += [
- nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
- nn.BatchNorm2d(ndf * ndf_mult),
- nn.LeakyReLU(0.2, True)
- ]
-
- layers += [
- nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
- self.main = nn.Sequential(*layers)
-
- if model_path is not None:
- chkpt = torch.load(model_path, map_location='cpu')
- if 'params_d' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
- elif 'params' in chkpt:
- self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
- else:
- raise ValueError('Wrong params!')
-
- def forward(self, x):
- return self.main(x)
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index da42b5e9..44b84618 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,132 +1,64 @@
-import os
+from __future__ import annotations
-import cv2
-import torch
-
-import modules.face_restoration
-import modules.shared
-from modules import shared, devices, modelloader, errors
-from modules.paths import models_path
-
-# codeformer people made a choice to include modified basicsr library to their project which makes
-# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
-# I am making a choice to include some files from codeformer to work around this issue.
-model_dir = "Codeformer"
-model_path = os.path.join(models_path, model_dir)
-model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-
-codeformer = None
-
-
-def setup_model(dirname):
- os.makedirs(model_path, exist_ok=True)
-
- path = modules.paths.paths.get("CodeFormer", None)
- if path is None:
- return
-
- try:
- from torchvision.transforms.functional import normalize
- from modules.codeformer.codeformer_arch import CodeFormer
- from basicsr.utils import img2tensor, tensor2img
- from facelib.utils.face_restoration_helper import FaceRestoreHelper
- from facelib.detection.retinaface import retinaface
-
- net_class = CodeFormer
-
- class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
- def name(self):
- return "CodeFormer"
-
- def __init__(self, dirname):
- self.net = None
- self.face_helper = None
- self.cmd_dir = dirname
+import logging
- def create_models(self):
-
- if self.net is not None and self.face_helper is not None:
- self.net.to(devices.device_codeformer)
- return self.net, self.face_helper
- model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
- if len(model_paths) != 0:
- ckpt_path = model_paths[0]
- else:
- print("Unable to load codeformer model.")
- return None, None
- net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
- checkpoint = torch.load(ckpt_path)['params_ema']
- net.load_state_dict(checkpoint)
- net.eval()
-
- if hasattr(retinaface, 'device'):
- retinaface.device = devices.device_codeformer
- face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
-
- self.net = net
- self.face_helper = face_helper
-
- return net, face_helper
-
- def send_model_to(self, device):
- self.net.to(device)
- self.face_helper.face_det.to(device)
- self.face_helper.face_parse.to(device)
-
- def restore(self, np_image, w=None):
- np_image = np_image[:, :, ::-1]
-
- original_resolution = np_image.shape[0:2]
+import torch
- self.create_models()
- if self.net is None or self.face_helper is None:
- return np_image
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
- self.send_model_to(devices.device_codeformer)
+logger = logging.getLogger(__name__)
- self.face_helper.clean_all()
- self.face_helper.read_image(np_image)
- self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
- self.face_helper.align_warp_face()
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_download_name = 'codeformer-v0.1.0.pth'
- for cropped_face in self.face_helper.cropped_faces:
- cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
- normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
- cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+# used by e.g. postprocessing_codeformer.py
+codeformer: face_restoration.FaceRestoration | None = None
- try:
- with torch.no_grad():
- output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
- restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
- del output
- devices.torch_gc()
- except Exception:
- errors.report('Failed inference for CodeFormer', exc_info=True)
- restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
- restored_face = restored_face.astype('uint8')
- self.face_helper.add_restored_face(restored_face)
+class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
+ def name(self):
+ return "CodeFormer"
- self.face_helper.get_inverse_affine(None)
+ def load_net(self) -> torch.Module:
+ for model_path in modelloader.load_models(
+ model_path=self.model_path,
+ model_url=model_url,
+ command_path=self.model_path,
+ download_name=model_download_name,
+ ext_filter=['.pth'],
+ ):
+ return modelloader.load_spandrel_model(
+ model_path,
+ device=devices.device_codeformer,
+ expected_architecture='CodeFormer',
+ ).model
+ raise ValueError("No codeformer model found")
- restored_img = self.face_helper.paste_faces_to_input_image()
- restored_img = restored_img[:, :, ::-1]
+ def get_device(self):
+ return devices.device_codeformer
- if original_resolution != restored_img.shape[0:2]:
- restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
+ def restore(self, np_image, w: float | None = None):
+ if w is None:
+ w = getattr(shared.opts, "code_former_weight", 0.5)
- self.face_helper.clean_all()
+ def restore_face(cropped_face_t):
+ assert self.net is not None
+ return self.net(cropped_face_t, w=w, adain=True)[0]
- if shared.opts.face_restoration_unload:
- self.send_model_to(devices.cpu)
+ return self.restore_with_helper(np_image, restore_face)
- return restored_img
- global codeformer
+def setup_model(dirname: str) -> None:
+ global codeformer
+ try:
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
-
except Exception:
errors.report("Error setting up CodeFormer", exc_info=True)
-
- # sys.path = stored_sys_path
diff --git a/modules/devices.py b/modules/devices.py
index ea1f712f..ff279ac5 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -4,6 +4,7 @@ from functools import lru_cache
import torch
from modules import errors, shared
+from modules import torch_utils
if sys.platform == "darwin":
from modules import mac_specific
@@ -23,6 +24,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
+def cuda_no_autocast(device_id=None) -> bool:
+ if device_id is None:
+ device_id = get_cuda_device_id()
+ return (
+ torch.cuda.get_device_capability(device_id) == (7, 5)
+ and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+ )
+
+
+def get_cuda_device_id():
+ return (
+ int(shared.cmd_opts.device_id)
+ if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+ else 0
+ ) or torch.cuda.current_device()
+
+
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -73,8 +91,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
- if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+ if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -84,6 +101,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
+fp8: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@@ -104,12 +122,51 @@ def cond_cast_float(input):
nv_rng = None
+patch_module_list = [
+ torch.nn.Linear,
+ torch.nn.Conv2d,
+ torch.nn.MultiheadAttention,
+ torch.nn.GroupNorm,
+ torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(self, *args, **kwargs):
+ org_dtype = torch_utils.get_param(self).dtype
+ self.to(dtype)
+ args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+ kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+ result = self.org_forward(*args, **kwargs)
+ self.to(org_dtype)
+ return result
+
+
+@contextlib.contextmanager
+def manual_cast():
+ for module_type in patch_module_list:
+ org_forward = module_type.forward
+ module_type.forward = manual_cast_forward
+ module_type.org_forward = org_forward
+ try:
+ yield None
+ finally:
+ for module_type in patch_module_list:
+ module_type.forward = module_type.org_forward
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
+ if fp8 and device==cpu:
+ return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+ if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+ return manual_cast()
+
+ if has_mps() and shared.cmd_opts.precision != "full":
+ return manual_cast()
+
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
diff --git a/modules/errors.py b/modules/errors.py
index eb234a83..48aa13a1 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -107,8 +107,8 @@ def check_versions():
import torch
import gradio
- expected_torch_version = "2.0.0"
- expected_xformers_version = "0.0.20"
+ expected_torch_version = "2.1.2"
+ expected_xformers_version = "0.0.23.post1"
expected_gradio_version = "3.41.2"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 02a1727d..70041ab0 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,121 +1,7 @@
-import sys
-
-import numpy as np
-import torch
-from PIL import Image
-
-import modules.esrgan_model_arch as arch
-from modules import modelloader, images, devices
+from modules import modelloader, devices, errors
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
-
-
-def mod2normal(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
- if 'conv_first.weight' in state_dict:
- crt_net = {}
- items = list(state_dict)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
- crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
- crt_net['model.3.weight'] = state_dict['upconv1.weight']
- crt_net['model.3.bias'] = state_dict['upconv1.bias']
- crt_net['model.6.weight'] = state_dict['upconv2.weight']
- crt_net['model.6.bias'] = state_dict['upconv2.bias']
- crt_net['model.8.weight'] = state_dict['HRconv.weight']
- crt_net['model.8.bias'] = state_dict['HRconv.bias']
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
- state_dict = crt_net
- return state_dict
-
-
-def resrgan2normal(state_dict, nb=23):
- # this code is copied from https://github.com/victorca25/iNNfer
- if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
- re8x = 0
- crt_net = {}
- items = list(state_dict)
-
- crt_net['model.0.weight'] = state_dict['conv_first.weight']
- crt_net['model.0.bias'] = state_dict['conv_first.bias']
-
- for k in items.copy():
- if "rdb" in k:
- ori_k = k.replace('body.', 'model.1.sub.')
- ori_k = ori_k.replace('.rdb', '.RDB')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[ori_k] = state_dict[k]
- items.remove(k)
-
- crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
- crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
- crt_net['model.3.weight'] = state_dict['conv_up1.weight']
- crt_net['model.3.bias'] = state_dict['conv_up1.bias']
- crt_net['model.6.weight'] = state_dict['conv_up2.weight']
- crt_net['model.6.bias'] = state_dict['conv_up2.bias']
-
- if 'conv_up3.weight' in state_dict:
- # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
- re8x = 3
- crt_net['model.9.weight'] = state_dict['conv_up3.weight']
- crt_net['model.9.bias'] = state_dict['conv_up3.bias']
-
- crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
- crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
- crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
- crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
-
- state_dict = crt_net
- return state_dict
-
-
-def infer_params(state_dict):
- # this code is copied from https://github.com/victorca25/iNNfer
- scale2x = 0
- scalemin = 6
- n_uplayer = 0
- plus = False
-
- for block in list(state_dict):
- parts = block.split(".")
- n_parts = len(parts)
- if n_parts == 5 and parts[2] == "sub":
- nb = int(parts[3])
- elif n_parts == 3:
- part_num = int(parts[1])
- if (part_num > scalemin
- and parts[0] == "model"
- and parts[2] == "weight"):
- scale2x += 1
- if part_num > n_uplayer:
- n_uplayer = part_num
- out_nc = state_dict[block].shape[0]
- if not plus and "conv1x1" in block:
- plus = True
-
- nf = state_dict["model.0.weight"].shape[0]
- in_nc = state_dict["model.0.weight"].shape[1]
- out_nc = out_nc
- scale = 2 ** scale2x
-
- return in_nc, out_nc, nf, nb, plus, scale
+from modules.upscaler_utils import upscale_with_model
class UpscalerESRGAN(Upscaler):
@@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
def do_upscale(self, img, selected_model):
try:
model = self.load_model(selected_model)
- except Exception as e:
- print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
+ except Exception:
+ errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
return img
model.to(devices.device_esrgan)
- img = esrgan_upscale(model, img)
- return img
+ return esrgan_upscale(model, img)
def load_model(self, path: str):
if path.startswith("http"):
@@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler):
else:
filename = path
- state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
-
- if "params_ema" in state_dict:
- state_dict = state_dict["params_ema"]
- elif "params" in state_dict:
- state_dict = state_dict["params"]
- num_conv = 16 if "realesr-animevideov3" in filename else 32
- model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
- model.load_state_dict(state_dict)
- model.eval()
- return model
-
- if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
- nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
- state_dict = resrgan2normal(state_dict, nb)
- elif "conv_first.weight" in state_dict:
- state_dict = mod2normal(state_dict)
- elif "model.0.weight" not in state_dict:
- raise Exception("The file is not a recognized ESRGAN model.")
-
- in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
-
- model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
- model.load_state_dict(state_dict)
- model.eval()
-
- return model
-
-
-def upscale_without_tiling(model, img):
- img = np.array(img)
- img = img[:, :, ::-1]
- img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
- img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
- with torch.no_grad():
- output = model(img)
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
- output = 255. * np.moveaxis(output, 0, 2)
- output = output.astype(np.uint8)
- output = output[:, :, ::-1]
- return Image.fromarray(output, 'RGB')
+ return modelloader.load_spandrel_model(
+ filename,
+ device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+ expected_architecture='ESRGAN',
+ )
def esrgan_upscale(model, img):
- if opts.ESRGAN_tile == 0:
- return upscale_without_tiling(model, img)
-
- grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
- newtiles = []
- scale_factor = 1
-
- for y, h, row in grid.tiles:
- newrow = []
- for tiledata in row:
- x, w, tile = tiledata
-
- output = upscale_without_tiling(model, tile)
- scale_factor = output.width // tile.width
-
- newrow.append([x * scale_factor, w * scale_factor, output])
- newtiles.append([y * scale_factor, h * scale_factor, newrow])
-
- newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = images.combine_grid(newgrid)
- return output
+ return upscale_with_model(
+ model,
+ img,
+ tile_size=opts.ESRGAN_tile,
+ tile_overlap=opts.ESRGAN_tile_overlap,
+ )
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
deleted file mode 100644
index 2b9888ba..00000000
--- a/modules/esrgan_model_arch.py
+++ /dev/null
@@ -1,465 +0,0 @@
-# this file is adapted from https://github.com/victorca25/iNNfer
-
-from collections import OrderedDict
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-####################
-# RRDBNet Generator
-####################
-
-class RRDBNet(nn.Module):
- def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
- act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
- finalact=None, gaussian_noise=False, plus=False):
- super(RRDBNet, self).__init__()
- n_upscale = int(math.log(upscale, 2))
- if upscale == 3:
- n_upscale = 1
-
- self.resrgan_scale = 0
- if in_nc % 16 == 0:
- self.resrgan_scale = 1
- elif in_nc != 4 and in_nc % 4 == 0:
- self.resrgan_scale = 2
-
- fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
- rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
- LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
-
- if upsample_mode == 'upconv':
- upsample_block = upconv_block
- elif upsample_mode == 'pixelshuffle':
- upsample_block = pixelshuffle_block
- else:
- raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
- if upscale == 3:
- upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
- else:
- upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
- HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
- HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
-
- outact = act(finalact) if finalact else None
-
- self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
- *upsampler, HR_conv0, HR_conv1, outact)
-
- def forward(self, x, outm=None):
- if self.resrgan_scale == 1:
- feat = pixel_unshuffle(x, scale=4)
- elif self.resrgan_scale == 2:
- feat = pixel_unshuffle(x, scale=2)
- else:
- feat = x
-
- return self.model(feat)
-
-
-class RRDB(nn.Module):
- """
- Residual in Residual Dense Block
- (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
- """
-
- def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(RRDB, self).__init__()
- # This is for backwards compatibility with existing models
- if nr == 3:
- self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus)
- else:
- RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
- norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
- gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
- self.RDBs = nn.Sequential(*RDB_list)
-
- def forward(self, x):
- if hasattr(self, 'RDB1'):
- out = self.RDB1(x)
- out = self.RDB2(out)
- out = self.RDB3(out)
- else:
- out = self.RDBs(x)
- return out * 0.2 + x
-
-
-class ResidualDenseBlock_5C(nn.Module):
- """
- Residual Dense Block
- The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
- Modified options that can be used:
- - "Partial Convolution based Padding" arXiv:1811.11718
- - "Spectral normalization" arXiv:1802.05957
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
- {Rakotonirina} and A. {Rasoanaivo}
- """
-
- def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
- norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
- spectral_norm=False, gaussian_noise=False, plus=False):
- super(ResidualDenseBlock_5C, self).__init__()
-
- self.noise = GaussianNoise() if gaussian_noise else None
- self.conv1x1 = conv1x1(nf, gc) if plus else None
-
- self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
- if mode == 'CNA':
- last_act = None
- else:
- last_act = act_type
- self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
- norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
- spectral_norm=spectral_norm)
-
- def forward(self, x):
- x1 = self.conv1(x)
- x2 = self.conv2(torch.cat((x, x1), 1))
- if self.conv1x1:
- x2 = x2 + self.conv1x1(x)
- x3 = self.conv3(torch.cat((x, x1, x2), 1))
- x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
- if self.conv1x1:
- x4 = x4 + x2
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
- if self.noise:
- return self.noise(x5.mul(0.2) + x)
- else:
- return x5 * 0.2 + x
-
-
-####################
-# ESRGANplus
-####################
-
-class GaussianNoise(nn.Module):
- def __init__(self, sigma=0.1, is_relative_detach=False):
- super().__init__()
- self.sigma = sigma
- self.is_relative_detach = is_relative_detach
- self.noise = torch.tensor(0, dtype=torch.float)
-
- def forward(self, x):
- if self.training and self.sigma != 0:
- self.noise = self.noise.to(x.device)
- scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
- sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
- x = x + sampled_noise
- return x
-
-def conv1x1(in_planes, out_planes, stride=1):
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-####################
-# SRVGGNetCompact
-####################
-
-class SRVGGNetCompact(nn.Module):
- """A compact VGG-style network structure for super-resolution.
- This class is copied from https://github.com/xinntao/Real-ESRGAN
- """
-
- def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
- super(SRVGGNetCompact, self).__init__()
- self.num_in_ch = num_in_ch
- self.num_out_ch = num_out_ch
- self.num_feat = num_feat
- self.num_conv = num_conv
- self.upscale = upscale
- self.act_type = act_type
-
- self.body = nn.ModuleList()
- # the first conv
- self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
- # the first activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the body structure
- for _ in range(num_conv):
- self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
- # activation
- if act_type == 'relu':
- activation = nn.ReLU(inplace=True)
- elif act_type == 'prelu':
- activation = nn.PReLU(num_parameters=num_feat)
- elif act_type == 'leakyrelu':
- activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
- self.body.append(activation)
-
- # the last conv
- self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
- # upsample
- self.upsampler = nn.PixelShuffle(upscale)
-
- def forward(self, x):
- out = x
- for i in range(0, len(self.body)):
- out = self.body[i](out)
-
- out = self.upsampler(out)
- # add the nearest upsampled image, so that the network learns the residual
- base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
- out += base
- return out
-
-
-####################
-# Upsampler
-####################
-
-class Upsample(nn.Module):
- r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
- The input data is assumed to be of the form
- `minibatch x channels x [optional depth] x [optional height] x width`.
- """
-
- def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
- super(Upsample, self).__init__()
- if isinstance(scale_factor, tuple):
- self.scale_factor = tuple(float(factor) for factor in scale_factor)
- else:
- self.scale_factor = float(scale_factor) if scale_factor else None
- self.mode = mode
- self.size = size
- self.align_corners = align_corners
-
- def forward(self, x):
- return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
-
- def extra_repr(self):
- if self.scale_factor is not None:
- info = f'scale_factor={self.scale_factor}'
- else:
- info = f'size={self.size}'
- info += f', mode={self.mode}'
- return info
-
-
-def pixel_unshuffle(x, scale):
- """ Pixel unshuffle.
- Args:
- x (Tensor): Input feature with shape (b, c, hh, hw).
- scale (int): Downsample ratio.
- Returns:
- Tensor: the pixel unshuffled feature.
- """
- b, c, hh, hw = x.size()
- out_channel = c * (scale**2)
- assert hh % scale == 0 and hw % scale == 0
- h = hh // scale
- w = hw // scale
- x_view = x.view(b, c, h, scale, w, scale)
- return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
-
-
-def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
- """
- Pixel shuffle layer
- (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
- Neural Network, CVPR17)
- """
- conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
- pixel_shuffle = nn.PixelShuffle(upscale_factor)
-
- n = norm(norm_type, out_nc) if norm_type else None
- a = act(act_type) if act_type else None
- return sequential(conv, pixel_shuffle, n, a)
-
-
-def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
- """ Upconv layer """
- upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
- upsample = Upsample(scale_factor=upscale_factor, mode=mode)
- conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
- pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
- return sequential(upsample, conv)
-
-
-
-
-
-
-
-
-####################
-# Basic blocks
-####################
-
-
-def make_layer(basic_block, num_basic_block, **kwarg):
- """Make layers by stacking the same blocks.
- Args:
- basic_block (nn.module): nn.module class for basic block. (block)
- num_basic_block (int): number of blocks. (n_layers)
- Returns:
- nn.Sequential: Stacked blocks in nn.Sequential.
- """
- layers = []
- for _ in range(num_basic_block):
- layers.append(basic_block(**kwarg))
- return nn.Sequential(*layers)
-
-
-def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
- """ activation helper """
- act_type = act_type.lower()
- if act_type == 'relu':
- layer = nn.ReLU(inplace)
- elif act_type in ('leakyrelu', 'lrelu'):
- layer = nn.LeakyReLU(neg_slope, inplace)
- elif act_type == 'prelu':
- layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
- elif act_type == 'tanh': # [-1, 1] range output
- layer = nn.Tanh()
- elif act_type == 'sigmoid': # [0, 1] range output
- layer = nn.Sigmoid()
- else:
- raise NotImplementedError(f'activation layer [{act_type}] is not found')
- return layer
-
-
-class Identity(nn.Module):
- def __init__(self, *kwargs):
- super(Identity, self).__init__()
-
- def forward(self, x, *kwargs):
- return x
-
-
-def norm(norm_type, nc):
- """ Return a normalization layer """
- norm_type = norm_type.lower()
- if norm_type == 'batch':
- layer = nn.BatchNorm2d(nc, affine=True)
- elif norm_type == 'instance':
- layer = nn.InstanceNorm2d(nc, affine=False)
- elif norm_type == 'none':
- def norm_layer(x): return Identity()
- else:
- raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
- return layer
-
-
-def pad(pad_type, padding):
- """ padding layer helper """
- pad_type = pad_type.lower()
- if padding == 0:
- return None
- if pad_type == 'reflect':
- layer = nn.ReflectionPad2d(padding)
- elif pad_type == 'replicate':
- layer = nn.ReplicationPad2d(padding)
- elif pad_type == 'zero':
- layer = nn.ZeroPad2d(padding)
- else:
- raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
- return layer
-
-
-def get_valid_padding(kernel_size, dilation):
- kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- padding = (kernel_size - 1) // 2
- return padding
-
-
-class ShortcutBlock(nn.Module):
- """ Elementwise sum the output of a submodule to its input """
- def __init__(self, submodule):
- super(ShortcutBlock, self).__init__()
- self.sub = submodule
-
- def forward(self, x):
- output = x + self.sub(x)
- return output
-
- def __repr__(self):
- return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
-
-
-def sequential(*args):
- """ Flatten Sequential. It unwraps nn.Sequential. """
- if len(args) == 1:
- if isinstance(args[0], OrderedDict):
- raise NotImplementedError('sequential does not support OrderedDict input.')
- return args[0] # No sequential is needed.
- modules = []
- for module in args:
- if isinstance(module, nn.Sequential):
- for submodule in module.children():
- modules.append(submodule)
- elif isinstance(module, nn.Module):
- modules.append(module)
- return nn.Sequential(*modules)
-
-
-def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
- pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
- spectral_norm=False):
- """ Conv layer with padding, normalization, activation """
- assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
- padding = get_valid_padding(kernel_size, dilation)
- p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
- padding = padding if pad_type == 'zero' else 0
-
- if convtype=='PartialConv2D':
- from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
- c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='DeformConv2D':
- from torchvision.ops import DeformConv2d # not tested
- c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- elif convtype=='Conv3D':
- c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
- else:
- c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
- dilation=dilation, bias=bias, groups=groups)
-
- if spectral_norm:
- c = nn.utils.spectral_norm(c)
-
- a = act(act_type) if act_type else None
- if 'CNA' in mode:
- n = norm(norm_type, out_nc) if norm_type else None
- return sequential(p, c, n, a)
- elif mode == 'NAC':
- if norm_type is None and act_type is not None:
- a = act(act_type, inplace=False)
- n = norm(norm_type, in_nc) if norm_type else None
- return sequential(n, a, p, c)
diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py
new file mode 100644
index 00000000..1cbac236
--- /dev/null
+++ b/modules/face_restoration_utils.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+import logging
+import os
+from functools import cached_property
+from typing import TYPE_CHECKING, Callable
+
+import cv2
+import numpy as np
+import torch
+
+from modules import devices, errors, face_restoration, shared
+
+if TYPE_CHECKING:
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
+logger = logging.getLogger(__name__)
+
+
+def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
+ """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
+ # The image must be channel-last with exactly three channels.
+ # NOTE(review): the message says "RGB" although the documented input order is BGR.
+ assert img.shape[2] == 3, "image must be RGB"
+ # float64 input is downcast before the colour conversion (presumably because cv2.cvtColor does not accept it - confirm).
+ if img.dtype == "float64":
+ img = img.astype("float32")
+ # Swap the first and third channels (BGR -> RGB).
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ # Move the channel axis to the front and make sure the result is float32.
+ return torch.from_numpy(img.transpose(2, 0, 1)).float()
+
+
+def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
+    """
+    Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
+
+    A leading axis of size one is dropped, values are clamped to `min_max` and rescaled
+    to [0..1]. A single-channel tensor is returned as a 2D array without any colour
+    conversion. The input tensor is never modified.
+    """
+    low, high = min_max
+    # Out-of-place clamp: for a float32 CPU input, squeeze/float/detach/cpu all share the
+    # caller's storage, so an in-place clamp_ would silently overwrite the caller's data.
+    tensor = tensor.squeeze(0).float().detach().cpu().clamp(low, high)
+    tensor = (tensor - low) / (high - low)
+    assert tensor.dim() == 3, "tensor must be RGB"
+    img_np = tensor.numpy().transpose(1, 2, 0)
+    if img_np.shape[2] == 1:  # gray image, no RGB/BGR required
+        return np.squeeze(img_np, axis=2)
+    # Same channel swap as COLOR_BGR2RGB; this constant names the direction actually intended.
+    return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+
+
+def create_face_helper(device) -> FaceRestoreHelper:
+ """Create a facexlib FaceRestoreHelper bound to `device`, using 512px square crops, the retinaface_resnet50 detector and face parsing."""
+ # facexlib is imported inside the function rather than at module level.
+ from facexlib.detection import retinaface
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+ # When the retinaface module exposes a module-level `device`, point it at the requested device as well.
+ if hasattr(retinaface, 'device'):
+ retinaface.device = device
+ return FaceRestoreHelper(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ use_parse=True,
+ device=device,
+ )
+
+
+def restore_with_face_helper(
+ np_image: np.ndarray,
+ face_helper: FaceRestoreHelper,
+ restore_face: Callable[[torch.Tensor], torch.Tensor],
+) -> np.ndarray:
+ """
+ Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
+
+ `restore_face` should take a cropped face image and return a restored face image.
+ It receives a batched tensor normalized with mean 0.5 / std 0.5, and its output is read back assuming a (-1, 1) value range.
+ The helper's state is cleared before detection and again when the function exits, including on error.
+ """
+ from torchvision.transforms.functional import normalize
+ # Reverse the channel axis (presumably RGB -> BGR, the order the helper works in - confirm against callers).
+ np_image = np_image[:, :, ::-1]
+ original_resolution = np_image.shape[0:2]
+
+ try:
+ logger.debug("Detecting faces...")
+ face_helper.clean_all()
+ face_helper.read_image(np_image)
+ face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
+ face_helper.align_warp_face()
+ logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
+ for cropped_face in face_helper.cropped_faces:
+ # Scale the crop from 0..255 to 0..1, convert to an RGB tensor, then normalize in place with mean/std 0.5.
+ cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ # NOTE(review): the crop is always sent to devices.device_codeformer, whichever restorer supplied restore_face - confirm this is intended.
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
+
+ # Best effort: a failing inference is reported and the crop keeps its previous (unrestored) value.
+ try:
+ with torch.no_grad():
+ cropped_face_t = restore_face(cropped_face_t)
+ devices.torch_gc()
+ except Exception:
+ errors.report('Failed face-restoration inference', exc_info=True)
+
+ # Map the tensor from (-1, 1) back to a 0..255 uint8 BGR image and hand it to the helper.
+ restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
+ restored_face = (restored_face * 255.0).astype('uint8')
+ face_helper.add_restored_face(restored_face)
+
+ logger.debug("Merging restored faces into image")
+ face_helper.get_inverse_affine(None)
+ img = face_helper.paste_faces_to_input_image()
+ # Reverse the channel axis again, undoing the flip applied on entry.
+ img = img[:, :, ::-1]
+ # Resize back when the pasted result does not have the input's height/width.
+ if original_resolution != img.shape[0:2]:
+ img = cv2.resize(
+ img,
+ (0, 0),
+ fx=original_resolution[1] / img.shape[1],
+ fy=original_resolution[0] / img.shape[0],
+ interpolation=cv2.INTER_LINEAR,
+ )
+ logger.debug("Face restoration complete")
+ finally:
+ face_helper.clean_all()
+ return img
+
+
+class CommonFaceRestoration(face_restoration.FaceRestoration):
+    """
+    Shared base for face restorers that crop faces with a facexlib helper and run a network on each crop.
+
+    Subclasses must implement `get_device()` and `load_net()`. The network is loaded on first
+    use by `restore_with_helper()`; the face helper is created on first access.
+    """
+
+    # `torch.nn.Module` is the real class; `torch.Module` does not exist and breaks annotation resolution.
+    net: torch.nn.Module | None
+    model_url: str
+    model_download_name: str
+
+    def __init__(self, model_path: str):
+        """Remember `model_path` and make sure the directory exists."""
+        super().__init__()
+        self.net = None
+        self.model_path = model_path
+        os.makedirs(model_path, exist_ok=True)
+
+    @cached_property
+    def face_helper(self) -> FaceRestoreHelper:
+        """The face helper, created once on the device reported by `get_device()`."""
+        return create_face_helper(self.get_device())
+
+    def send_model_to(self, device):
+        """Move the loaded network (if any) and the face helper's detection/parsing models to `device`."""
+        # Explicit None check: a module's truthiness can depend on its length (e.g. containers).
+        if self.net is not None:
+            logger.debug("Sending %s to %s", self.net, device)
+            self.net.to(device)
+        if self.face_helper:
+            logger.debug("Sending face helper to %s", device)
+            self.face_helper.face_det.to(device)
+            self.face_helper.face_parse.to(device)
+
+    def get_device(self):
+        """Return the device this restorer runs on; must be overridden."""
+        raise NotImplementedError("get_device must be implemented by subclasses")
+
+    def load_net(self) -> torch.nn.Module:
+        """Load and return the restoration network; must be overridden."""
+        raise NotImplementedError("load_net must be implemented by subclasses")
+
+    def restore_with_helper(
+        self,
+        np_image: np.ndarray,
+        restore_face: Callable[[torch.Tensor], torch.Tensor],
+    ) -> np.ndarray:
+        """
+        Restore every face found in `np_image` using `restore_face`.
+
+        Returns the input image unchanged when the network cannot be loaded. When the
+        `face_restoration_unload` option is set, the models are moved to the CPU afterwards.
+        """
+        try:
+            if self.net is None:
+                self.net = self.load_net()
+        except Exception:
+            logger.warning("Unable to load face-restoration model", exc_info=True)
+            return np_image
+
+        try:
+            self.send_model_to(self.get_device())
+            return restore_with_face_helper(np_image, self.face_helper, restore_face)
+        finally:
+            if shared.opts.face_restoration_unload:
+                self.send_model_to(devices.cpu)
+
+
+def patch_facexlib(dirname: str) -> None:
+ """Replace facexlib's detection and parsing `load_file_from_url` with wrappers that force `save_dir=dirname` and `model_dir=None`."""
+ import facexlib.detection
+ import facexlib.parsing
+
+ # Keep references to the current loaders so the wrappers can delegate to them.
+ det_facex_load_file_from_url = facexlib.detection.load_file_from_url
+ par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
+
+ def update_kwargs(kwargs):
+ # Copy of the caller's kwargs with the two directory arguments overridden.
+ return dict(kwargs, save_dir=dirname, model_dir=None)
+
+ # The wrappers accept keyword arguments only.
+ def facex_load_file_from_url(**kwargs):
+ return det_facex_load_file_from_url(**update_kwargs(kwargs))
+
+ def facex_load_file_from_url2(**kwargs):
+ return par_facex_load_file_from_url(**update_kwargs(kwargs))
+
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 01d668ec..445b0409 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,125 +1,71 @@
+from __future__ import annotations
+
+import logging
import os
-import facexlib
-import gfpgan
+import torch
-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
+from modules import (
+ devices,
+ errors,
+ face_restoration,
+ face_restoration_utils,
+ modelloader,
+ shared,
+)
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_file_path = None
+logger = logging.getLogger(__name__)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
- global loaded_gfpgan_model
- global model_path
- global model_file_path
- if loaded_gfpgan_model is not None:
- loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
- return loaded_gfpgan_model
-
- if gfpgan_constructor is None:
- return None
-
- models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
-
- if len(models) == 1 and models[0].startswith("http"):
- model_file = models[0]
- elif len(models) != 0:
- gfp_models = []
- for item in models:
- if 'GFPGAN' in os.path.basename(item):
- gfp_models.append(item)
- latest_file = max(gfp_models, key=os.path.getctime)
- model_file = latest_file
- else:
- print("Unable to load gfpgan model!")
- return None
-
- if hasattr(facexlib.detection.retinaface, 'device'):
- facexlib.detection.retinaface.device = devices.device_gfpgan
- model_file_path = model_file
- model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
- loaded_gfpgan_model = model
-
- return model
-
-
-def send_model_to(model, device):
- model.gfpgan.to(device)
- model.face_helper.face_det.to(device)
- model.face_helper.face_parse.to(device)
+model_download_name = "GFPGANv1.4.pth"
+gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
+
+
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+    """GFPGAN face restorer whose network is loaded through `modelloader.load_spandrel_model`."""
+
+    def name(self):
+        return "GFPGAN"
+
+    def get_device(self):
+        return devices.device_gfpgan
+
+    # `torch.nn.Module` is the real class; `torch.Module` does not exist.
+    def load_net(self) -> torch.nn.Module:
+        """
+        Load the first candidate model file whose basename contains 'GFPGAN'.
+
+        Raises:
+            ValueError: when no candidate file matches.
+        """
+        for model_path in modelloader.load_models(
+            model_path=self.model_path,
+            model_url=model_url,
+            command_path=self.model_path,
+            download_name=model_download_name,
+            ext_filter=['.pth'],
+        ):
+            if 'GFPGAN' in os.path.basename(model_path):
+                model = modelloader.load_spandrel_model(
+                    model_path,
+                    device=self.get_device(),
+                    expected_architecture='GFPGAN',
+                ).model
+                model.different_w = True  # see https://github.com/chaiNNer-org/spandrel/pull/81
+                return model
+        raise ValueError("No GFPGAN model found")
+
+    def restore(self, np_image):
+        """Restore the faces in `np_image` by running the network on each cropped face."""
+        def restore_face(cropped_face_t):
+            assert self.net is not None
+            # Only the first element of the network's return value is used.
+            return self.net(cropped_face_t, return_rgb=False)[0]
+
+        return self.restore_with_helper(np_image, restore_face)
def gfpgan_fix_faces(np_image):
- model = gfpgann()
- if model is None:
- return np_image
-
- send_model_to(model, devices.device_gfpgan)
-
- np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
- np_image = gfpgan_output_bgr[:, :, ::-1]
-
- model.face_helper.clean_all()
-
- if shared.opts.face_restoration_unload:
- send_model_to(model, devices.cpu)
-
+ if gfpgan_face_restorer:
+ return gfpgan_face_restorer.restore(np_image)
+ logger.warning("GFPGAN face restorer not set up")
return np_image
-gfpgan_constructor = None
+def setup_model(dirname: str) -> None:
+ global gfpgan_face_restorer
-
-def setup_model(dirname):
try:
- os.makedirs(model_path, exist_ok=True)
- from gfpgan import GFPGANer
- from facexlib import detection, parsing # noqa: F401
- global user_path
- global have_gfpgan
- global gfpgan_constructor
- global model_file_path
-
- facexlib_path = model_path
-
- if dirname is not None:
- facexlib_path = dirname
-
- load_file_from_url_orig = gfpgan.utils.load_file_from_url
- facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
- facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
- def my_load_file_from_url(**kwargs):
- return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
-
- def facex_load_file_from_url(**kwargs):
- return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
- def facex_load_file_from_url2(**kwargs):
- return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
- gfpgan.utils.load_file_from_url = my_load_file_from_url
- facexlib.detection.load_file_from_url = facex_load_file_from_url
- facexlib.parsing.load_file_from_url = facex_load_file_from_url2
- user_path = dirname
- have_gfpgan = True
- gfpgan_constructor = GFPGANer
-
- class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
- def name(self):
- return "GFPGAN"
-
- def restore(self, np_image):
- return gfpgan_fix_faces(np_image)
-
- shared.face_restorers.append(FaceRestorerGFPGAN())
+ face_restoration_utils.patch_facexlib(dirname)
+ gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+ shared.face_restorers.append(gfpgan_face_restorer)
except Exception:
errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/hat_model.py b/modules/hat_model.py
new file mode 100644
index 00000000..7f2abb41
--- /dev/null
+++ b/modules/hat_model.py
@@ -0,0 +1,43 @@
+import os
+import sys
+
+from modules import modelloader, devices
+from modules.shared import opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
+
+
+class UpscalerHAT(Upscaler):
+ # Upscaler that registers every .pt/.pth file returned by find_models as a HAT model and loads it via modelloader.load_spandrel_model.
+ def __init__(self, dirname):
+ # name, scalers and user_path are assigned before the base initializer runs.
+ self.name = "HAT"
+ self.scalers = []
+ self.user_path = dirname
+ super().__init__()
+ for file in self.find_models(ext_filter=[".pt", ".pth"]):
+ name = modelloader.friendly_name(file)
+ scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
+ scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ # A model that fails to load is reported on stderr and the image is returned unchanged.
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
+ return img
+ model.to(devices.device_esrgan) # TODO: should probably be device_hat
+ return upscale_with_model(
+ model,
+ img,
+ tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
+ tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
+ )
+
+ def load_model(self, path: str):
+ # Raises FileNotFoundError when `path` is not an existing file; otherwise loads it expecting the 'HAT' architecture.
+ if not os.path.isfile(path):
+ raise FileNotFoundError(f"Model file {path} not found")
+ return modelloader.load_spandrel_model(
+ path,
+ device=devices.device_esrgan, # TODO: should probably be device_hat
+ expected_architecture='HAT',
+ )
diff --git a/modules/images.py b/modules/images.py
index daf4eebe..87a7bf22 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None):
return grid
-Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
+ # Same six fields as the plain namedtuple it replaces, plus a derived tile_count property.
+ @property
+ def tile_count(self) -> int:
+ """
+ The total number of tiles in the grid.
+ """
+ # Adds up the length of element 2 of every entry in `tiles` (presumably the per-row list of tiles - confirm against split_grid).
+ return sum(len(row[2]) for row in self.tiles)
-def split_grid(image, tile_w=512, tile_h=512, overlap=64):
- w = image.width
- h = image.height
+def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
+ w, h = image.size
non_overlap_width = tile_w - overlap
non_overlap_height = tile_h - overlap
@@ -791,3 +796,4 @@ def flatten(img, bgcolor):
img = background
return img.convert('RGB')
+
diff --git a/modules/img2img.py b/modules/img2img.py
index c583290a..e7e8e251 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr
import gradio as gr
from modules import images as imgutil
-from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
+from modules.infotext import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
from modules.sd_models import get_closet_checkpoint_match
@@ -51,7 +51,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
if state.skipped:
state.skipped = False
- if state.interrupted:
+ if state.interrupted or state.stopping_generation:
break
try:
diff --git a/modules/generation_parameters_copypaste.py b/modules/infotext.py
index 4efe53e0..26e9b949 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/infotext.py
@@ -4,12 +4,15 @@ import io
import json
import os
import re
+import sys
import gradio as gr
from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks, processing
+from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
from PIL import Image
+sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
+
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
@@ -28,6 +31,19 @@ class ParamBinding:
self.paste_field_names = paste_field_names or []
+class PasteField(tuple):
+ # A (component, target) pair that still unpacks like a plain 2-tuple but also carries named attributes.
+ def __new__(cls, component, target, *, api=None):
+ # Only component and target become tuple items; `api` is stored as an attribute in __init__.
+ return super().__new__(cls, (component, target))
+
+ def __init__(self, component, target, *, api=None):
+ super().__init__()
+
+ self.api = api
+ self.component = component
+ # `target` is exposed as `label` when it is a string and as `function` when it is callable; the other attribute is None.
+ self.label = target if isinstance(target, str) else None
+ self.function = target if callable(target) else None
+
paste_fields: dict[str, dict] = {}
registered_param_bindings: list[ParamBinding] = []
@@ -84,6 +100,12 @@ def image_from_url_text(filedata):
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+
+ if fields:
+ for i in range(len(fields)):
+ if not isinstance(fields[i], PasteField):
+ fields[i] = PasteField(*fields[i])
+
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
# backwards compatibility for existing extensions
@@ -314,6 +336,14 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "VAE Decoder" not in res:
res["VAE Decoder"] = "Full"
+ if "FP8 weight" not in res:
+ res["FP8 weight"] = "Disable"
+
+ if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
+ res["Cache FP16 weight for LoRA"] = False
+
+ infotext_versions.backcompat(res)
+
skip = set(shared.opts.infotext_skip_pasting)
res = {k: v for k, v in res.items() if k not in skip}
@@ -365,6 +395,48 @@ def create_override_settings_dict(text_pairs):
return res
+def get_override_settings(params, *, skip_fields=None):
+ """Returns a list of settings overrides from the infotext parameters dictionary.
+
+ This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
+ a list of tuples containing the parameter name, setting name, and new value cast to correct type.
+
+ It checks for conditions before adding an override:
+ - ignores settings that match the current value
+ - ignores parameter keys present in skip_fields argument.
+
+ Example input:
+ {"Clip skip": "2"}
+
+ Example output:
+ [("Clip skip", "CLIP_stop_at_last_layers", 2)]
+ """
+
+ res = []
+
+ # (infotext name, setting name) pairs for every option that declares an infotext label, followed by the explicit mapping list.
+ mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
+ for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
+ # skip_fields may be None; it is only used for membership tests.
+ if param_name in (skip_fields or {}):
+ continue
+
+ v = params.get(param_name, None)
+ if v is None:
+ continue
+
+ # The checkpoint is never overridden while the disable_weights_auto_swap option is set.
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+ continue
+
+ # Cast first, so the comparison with the current value is made on the cast value.
+ v = shared.opts.cast_value(setting_name, v)
+ current_value = getattr(shared.opts, setting_name, None)
+
+ if v == current_value:
+ continue
+
+ res.append((param_name, setting_name, v))
+
+ return res
+
+
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
@@ -406,29 +478,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
already_handled_fields = {key: 1 for _, key in paste_fields}
def paste_settings(params):
- vals = {}
-
- mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
- for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
- if param_name in already_handled_fields:
- continue
-
- v = params.get(param_name, None)
- if v is None:
- continue
-
- if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
- continue
-
- v = shared.opts.cast_value(setting_name, v)
- current_value = getattr(shared.opts, setting_name, None)
-
- if v == current_value:
- continue
-
- vals[param_name] = v
+ vals = get_override_settings(params, skip_fields=already_handled_fields)
- vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+ vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py
new file mode 100644
index 00000000..a5afeebf
--- /dev/null
+++ b/modules/infotext_versions.py
@@ -0,0 +1,39 @@
+from modules import shared
+from packaging import version
+import re
+
+
+v160 = version.parse("1.6.0")
+v170_tsnr = version.parse("v1.7.0-225")
+
+
+def parse_version(text):
+ """Parse a version string with `packaging.version`; returns None for None input or when parsing fails."""
+ if text is None:
+ return None
+
+ # When the text has at least two dashes, keep only the part before the second dash.
+ m = re.match(r'([^-]+-[^-]+)-.*', text)
+ if m:
+ text = m.group(1)
+
+ # Any parsing error is treated as "version unknown".
+ try:
+ return version.parse(text)
+ except Exception:
+ return None
+
+
+def backcompat(d):
+ """Checks infotext Version field, and enables backwards compatibility options according to it."""
+
+ # Does nothing unless the auto_backcompat option is enabled.
+ if not shared.opts.auto_backcompat:
+ return
+
+ # A missing or unparseable Version leaves `d` untouched.
+ ver = parse_version(d.get("Version"))
+ if ver is None:
+ return
+
+ # `d` is modified in place; each threshold adds its own key.
+ if ver < v160:
+ d["Old prompt editing timelines"] = True
+
+ if ver < v170_tsnr:
+ d["Downcast alphas_cumprod"] = True
+
diff --git a/modules/initialize.py b/modules/initialize.py
index ac95fc6f..4a3cd98c 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -54,9 +54,6 @@ def initialize():
initialize_util.configure_sigint_handler()
initialize_util.configure_opts_onchange()
- from modules import modelloader
- modelloader.cleanup_models()
-
from modules import sd_models
sd_models.setup_model()
startup_timer.record("setup SD model")
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2e9b6d89..b6767138 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -177,6 +177,8 @@ def configure_opts_onchange():
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+ shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
startup_timer.record("opts onchange")
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 3045560d..35a627ca 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -10,7 +10,7 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -131,7 +131,7 @@ class InterrogateModels:
self.clip_model = self.clip_model.to(devices.device_interrogate)
- self.dtype = next(self.clip_model.parameters()).dtype
+ self.dtype = torch_utils.get_param(self.clip_model).dtype
def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 29506f24..c2cbd8ce 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -314,8 +314,8 @@ def requirements_met(requirements_file):
def prepare_environment():
- torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
- torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
if args.use_ipex:
if platform.system() == "Windows":
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
@@ -338,20 +338,18 @@ def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
- codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
- codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:
@@ -408,15 +406,10 @@ def prepare_environment():
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
- git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
startup_timer.record("clone repositores")
- if not is_installed("lpips"):
- run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
- startup_timer.record("install CodeFormer requirements")
-
if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
diff --git a/modules/modelloader.py b/modules/modelloader.py
index 098bcb79..a7194137 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,13 +1,20 @@
from __future__ import annotations
-import os
-import shutil
import importlib
+import logging
+import os
+from typing import TYPE_CHECKING
from urllib.parse import urlparse
+import torch
+
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
-from modules.paths import script_path, models_path
+
+if TYPE_CHECKING:
+ import spandrel
+
+logger = logging.getLogger(__name__)
def load_file_from_url(
@@ -90,54 +97,6 @@ def friendly_name(file: str):
return model_name
-def cleanup_models():
- # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
- # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
- # somehow auto-register and just do these things...
- root_path = script_path
- src_path = models_path
- dest_path = os.path.join(models_path, "Stable-diffusion")
- move_files(src_path, dest_path, ".ckpt")
- move_files(src_path, dest_path, ".safetensors")
- src_path = os.path.join(root_path, "ESRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(models_path, "BSRGAN")
- dest_path = os.path.join(models_path, "ESRGAN")
- move_files(src_path, dest_path, ".pth")
- src_path = os.path.join(root_path, "gfpgan")
- dest_path = os.path.join(models_path, "GFPGAN")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "SwinIR")
- dest_path = os.path.join(models_path, "SwinIR")
- move_files(src_path, dest_path)
- src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
- dest_path = os.path.join(models_path, "LDSR")
- move_files(src_path, dest_path)
-
-
-def move_files(src_path: str, dest_path: str, ext_filter: str = None):
- try:
- os.makedirs(dest_path, exist_ok=True)
- if os.path.exists(src_path):
- for file in os.listdir(src_path):
- fullpath = os.path.join(src_path, file)
- if os.path.isfile(fullpath):
- if ext_filter is not None:
- if ext_filter not in file:
- continue
- print(f"Moving {file} from {src_path} to {dest_path}.")
- try:
- shutil.move(fullpath, dest_path)
- except Exception:
- pass
- if len(os.listdir(src_path)) == 0:
- print(f"Removing empty folder: {src_path}")
- shutil.rmtree(src_path, True)
- except Exception:
- pass
-
-
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -177,3 +136,26 @@ def load_upscalers():
# Special case for UpscalerNone keeps it at the beginning of the list.
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
)
+
+
+def load_spandrel_model(
+    path: str,
+    *,
+    device: str | torch.device | None,
+    half: bool = False,
+    dtype: str | torch.dtype | None = None,
+    expected_architecture: str | None = None,
+) -> spandrel.ModelDescriptor:
+    """
+    Load the model file at `path` with Spandrel and return its descriptor in eval mode.
+
+    `device` is handed to `spandrel.ModelLoader`.
+    `half` calls `.half()` on the loaded model.
+    `dtype`, when truthy, is applied with `.to(dtype=...)` after `half`, so it wins if both are given.
+    `expected_architecture`, when truthy, is compared against the descriptor's `architecture`;
+    a mismatch is only logged as a warning, the model is still returned.
+    """
+    # Function-level import: at module level spandrel is only imported under TYPE_CHECKING.
+    import spandrel
+
+    model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
+
+    if expected_architecture:
+        if model_descriptor.architecture != expected_architecture:
+            logger.warning(
+                f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
+            )
+
+    net = model_descriptor.model
+    if half:
+        net.half()
+    if dtype:
+        net.to(dtype=dtype)
+    net.eval()
+
+    logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
+    return model_descriptor
diff --git a/modules/paths.py b/modules/paths.py
index 187b9496..03064651 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,7 +38,6 @@ mute_sdxl_imports()
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
- (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 89131a54..b86ecd7f 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
+default_output_dir = os.path.join(data_path, "output")
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 0c59fad4..facea899 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -2,7 +2,7 @@ import os
from PIL import Image
-from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
+from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext as infotext_module  # aliased: run_postprocessing binds a local string named `infotext`
from modules.shared import opts
@@ -86,7 +86,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
basename = ''
forced_filename = None
- infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+ infotext = ", ".join([k if k == v else f'{k}: {infotext_module.quote(v)}' for k, v in pp.info.items() if v is not None])
if opts.enable_pnginfo:
pp.image.info = existing_pnginfo
diff --git a/modules/processing.py b/modules/processing.py
index 6f01c95f..f55b85ed 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -16,7 +16,7 @@ from skimage import exposure
from typing import Any
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -62,18 +62,22 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')
-def apply_overlay(image, paste_loc, index, overlays):
- if overlays is None or index >= len(overlays):
- return image
+def uncrop(image, dest_size, paste_loc):
+    """
+    Place `image` back onto a blank RGBA canvas of size `dest_size`.
+
+    `paste_loc` is an (x, y, w, h) tuple: the image is resized to w x h via
+    `images.resize_image` (resize mode 1) and pasted with its top-left corner at (x, y).
+    Returns the new canvas; the input image object is not modified.
+    """
+    left, top, width, height = paste_loc
+
+    resized = images.resize_image(1, image, width, height)
+    canvas = Image.new('RGBA', dest_size)
+    canvas.paste(resized, (left, top))
+
+    return canvas
+
- overlay = overlays[index]
+
+def apply_overlay(image, paste_loc, overlay):
+ if overlay is None:
+ return image
if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
+ image = uncrop(image, (overlay.width, overlay.height), paste_loc)
image = image.convert('RGBA')
image.alpha_composite(overlay)
@@ -81,9 +85,12 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
+ # Reduce `image` to a single-channel ('L') mask.
+ # For an RGBA image whose alpha band is not uniformly 255, the alpha band is used:
+ # with round=True it is thresholded (values > 128 become 255, the rest 0),
+ # with round=False the alpha values are kept as they are.
+ # Any other image is simply converted to 'L'.
+ # NOTE: the `round` parameter shadows the builtin round() inside this function.
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
- image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ if round:
+ image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ else:
+ image = image.split()[-1].convert("L")
else:
image = image.convert('L')
return image
@@ -106,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
else:
+ sd = sd_model.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+ image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+ image_conditioning = images_tensor_to_samples(image_conditioning,
+ approximation_indexes.get(opts.sd_vae_encode_method))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
@@ -308,7 +330,7 @@ class StableDiffusionProcessing:
c_adm = torch.cat((c_adm, noise_level_emb), 1)
return c_adm
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -320,8 +342,10 @@ class StableDiffusionProcessing:
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
+ if round_image_mask:
+ # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+
else:
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -345,7 +369,7 @@ class StableDiffusionProcessing:
return image_conditioning
- def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,11 +381,17 @@ class StableDiffusionProcessing:
return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
+ # Fallback for models the conditioning_key checks above did not catch: if the weight stored under
+ # 'diffusion_model.input_blocks.0.0.weight' has 9 entries along dim 1 (presumably the input channel
+ # count of an inpainting UNet - TODO confirm), build inpainting conditioning as well.
+ sd = self.sampler.model_wrap.inner_model.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
+     # Forward round_image_mask exactly like the 'hybrid'/'concat' branch does, so a caller asking for
+     # an unrounded mask is not silently given a rounded one on this path.
+     return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
+
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -422,6 +452,8 @@ class StableDiffusionProcessing:
opts.sdxl_crop_top,
self.width,
self.height,
+ opts.fp8_storage,
+ opts.cache_fp16_weight,
)
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
@@ -596,20 +628,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
sample = decode_first_stage(model, batch[i:i + 1])[0]
if check_for_nans:
+
try:
devices.test_for_nans(sample, "vae")
except devices.NansException as e:
- if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+ if shared.opts.auto_vae_precision_bfloat16:
+ autofix_dtype = torch.bfloat16
+ autofix_dtype_text = "bfloat16"
+ autofix_dtype_setting = "Automatically convert VAE to bfloat16"
+ autofix_dtype_comment = ""
+ elif shared.opts.auto_vae_precision:
+ autofix_dtype = torch.float32
+ autofix_dtype_text = "32-bit float"
+ autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
+ autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
+ else:
+ raise e
+
+ if devices.dtype_vae == autofix_dtype:
raise e
errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n"
- "Web UI will now convert VAE into 32-bit float and retry.\n"
- "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
- "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+ f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
+ f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
)
- devices.dtype_vae = torch.float32
+ devices.dtype_vae = autofix_dtype
model.first_stage_model.to(devices.dtype_vae)
batch = batch.to(devices.dtype_vae)
@@ -679,6 +724,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
+ "FP8 weight": opts.fp8_storage if devices.fp8 else None,
+ "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
@@ -699,7 +746,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"User": p.user if opts.add_user_name_to_info else None,
}
- generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
@@ -818,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.skipped:
state.skipped = False
- if state.interrupted:
+ if state.interrupted or state.stopping_generation:
break
sd_models.reload_model_weights() # model can be changed for example by refiner
@@ -864,9 +911,42 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+ def rescale_zero_terminal_snr_abar(alphas_cumprod):
+     """
+     Rescale a cumulative-alpha schedule so that its last entry is (almost) zero.
+
+     Works on sqrt(alphas_cumprod): the last value is subtracted from every entry, the result is
+     scaled so the first entry keeps its original value, and everything is squared again.
+     Returns a new tensor; `alphas_cumprod` itself is not modified.
+     """
+     sqrt_abar = alphas_cumprod.sqrt()
+
+     # Remember both end points before shifting.
+     first = sqrt_abar[0].clone()
+     last = sqrt_abar[-1].clone()
+
+     # Shift so the final timestep becomes zero, then stretch so the first timestep is unchanged.
+     shifted = sqrt_abar - last
+     stretched = shifted * (first / (first - last))
+
+     # Undo the square root.
+     rescaled = stretched ** 2
+     # The last entry is set to a tiny positive constant rather than an exact 0
+     # (presumably to keep later divisions/square roots finite - TODO confirm).
+     rescaled[-1] = 4.8973451890853435e-08
+     return rescaled
+
+ if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+ if opts.use_downcasted_alpha_bar:
+ p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+ p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+ if opts.sd_noise_schedule == "Zero Terminal SNR":
+ p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+ p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+ if p.scripts is not None:
+ ps = scripts.PostSampleArgs(samples_ddim)
+ p.scripts.post_sample(p, ps)
+ samples_ddim = ps.samples
+
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
@@ -922,13 +1002,31 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image
+
+ mask_for_overlay = getattr(p, "mask_for_overlay", None)
+ overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None
+
+ if p.scripts is not None:
+ ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+ p.scripts.postprocess_maskoverlay(p, ppmo)
+ mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
+
if p.color_corrections is not None and i < len(p.color_corrections):
if save_samples and opts.save_images_before_color_correction:
- image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
- image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+ # If the intention is to show the output from the model
+ # that is being composited over the original image,
+ # we need to keep the original image around
+ # and use it in the composite step.
+ original_denoised_image = image.copy()
+
+ # overlay_image may be None (no overlay for this index, or a script cleared it in
+ # postprocess_maskoverlay); apply_overlay() returns the image untouched in that case,
+ # so skip the uncrop as well instead of failing on overlay_image.width.
+ if p.paste_to is not None and overlay_image is not None:
+     original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
+
+ image = apply_overlay(image, p.paste_to, overlay_image)
if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -938,16 +1036,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
- if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+
+ if mask_for_overlay is not None:
if opts.return_mask or opts.save_mask:
- image_mask = p.mask_for_overlay.convert('RGB')
+ image_mask = mask_for_overlay.convert('RGB')
if save_samples and opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.return_mask:
output_images.append(image_mask)
if opts.return_mask_composite or opts.save_mask_composite:
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if save_samples and opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask_composite:
@@ -1025,6 +1124,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_sampler_name: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''
+ force_task_id: str = None
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
@@ -1097,7 +1197,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if self.hr_checkpoint_name:
+ if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
if self.hr_checkpoint_info is None:
@@ -1351,12 +1451,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
+ mask_round: bool = True
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
inpainting_mask_invert: int = 0
initial_noise_multiplier: float = None
latent_mask: Image = None
+ force_task_id: str = None
image_mask: Any = field(default=None, init=False)
@@ -1396,7 +1498,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if image_mask is not None:
# image_mask is passed in as RGBA by Gradio to support alpha masks,
# but we still want to support binary masks.
- image_mask = create_binary_mask(image_mask)
+ image_mask = create_binary_mask(image_mask, round=self.mask_round)
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
@@ -1442,7 +1544,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# Save init image
if opts.save_init_img:
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
- images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+ images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
image = images.flatten(img, opts.img2img_background_color)
@@ -1503,7 +1605,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0]
- latmask = np.around(latmask)
+ if self.mask_round:
+ latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
@@ -1515,7 +1618,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
@@ -1527,7 +1630,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
- samples = samples * self.nmask + self.init_latent * self.mask
+ blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+ if self.scripts is not None:
+ mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
+ self.scripts.on_mask_blend(self, mba)
+ blended_samples = mba.blended_latent
+
+ samples = blended_samples
del x
devices.torch_gc()
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
index 29ccb78f..e9941413 100644
--- a/modules/processing_scripts/refiner.py
+++ b/modules/processing_scripts/refiner.py
@@ -1,6 +1,7 @@
import gradio as gr
from modules import scripts, sd_models
+from modules.infotext import PasteField
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
@@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
return None if info is None else info.title
self.infotext_fields = [
- (enable_refiner, lambda d: 'Refiner' in d),
- (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
- (refiner_switch_at, 'Refiner switch at'),
+ PasteField(enable_refiner, lambda d: 'Refiner' in d),
+ PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
+ PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
]
return enable_refiner, refiner_checkpoint, refiner_switch_at
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index dc9c2da5..60293278 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -3,6 +3,7 @@ import json
import gradio as gr
from modules import scripts, ui, errors
+from modules.infotext import PasteField
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
@@ -51,12 +52,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
self.infotext_fields = [
- (self.seed, "Seed"),
- (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
- (subseed, "Variation seed"),
- (subseed_strength, "Variation seed strength"),
- (seed_resize_from_w, "Seed resize from-1"),
- (seed_resize_from_h, "Seed resize from-2"),
+ PasteField(self.seed, "Seed", api="seed"),
+ PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
+ PasteField(subseed, "Variation seed", api="subseed"),
+ PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
+ # api names must match the component they are attached to: the width component carries the width name.
+ PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_w"),
+ PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_h"),
]
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
diff --git a/modules/progress.py b/modules/progress.py
index 69921de7..85255e82 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -8,10 +8,13 @@ from pydantic import BaseModel, Field
from modules.shared import opts
import modules.shared as shared
-
+from collections import OrderedDict
+import string
+import random
+from typing import List
current_task = None
-pending_tasks = {}
+pending_tasks = OrderedDict()
finished_tasks = []
recorded_results = []
recorded_results_limit = 2
@@ -34,6 +37,11 @@ def finish_task(id_task):
if len(finished_tasks) > 16:
finished_tasks.pop(0)
+def create_task_id(task_type):
+    """Return an id of the form ``task(<task_type>-<7 random uppercase letters/digits>)``."""
+    suffix_length = 7
+    alphabet = string.ascii_uppercase + string.digits
+    suffix = ''.join(random.choices(alphabet, k=suffix_length))
+    return f"task({task_type}-{suffix})"
def record_results(id_task, res):
recorded_results.append((id_task, res))
@@ -44,6 +52,9 @@ def record_results(id_task, res):
def add_task_to_queue(id_job):
pending_tasks[id_job] = time.time()
+# Payload returned by get_pending_tasks() for the /internal/pending-tasks route:
+# `tasks` holds the ids currently in pending_tasks, `size` is how many there are.
+class PendingTasksResponse(BaseModel):
+ size: int = Field(title="Pending task size")
+ tasks: List[str] = Field(title="Pending task ids")
class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@@ -63,9 +74,16 @@ class ProgressResponse(BaseModel):
def setup_progress_api(app):
+ # Registers the internal task endpoints on `app` and returns the result of the progress registration.
+ # The pending-tasks route declares its response_model like the progress route below does,
+ # so both endpoints expose a declared response schema.
+ app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"], response_model=PendingTasksResponse)
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+def get_pending_tasks():
+    """Return the ids of all queued tasks, in the iteration order of `pending_tasks`, plus their count."""
+    queued_ids = list(pending_tasks)
+    return PendingTasksResponse(size=len(queued_ids), tasks=queued_ids)
+
+
def progressapi(req: ProgressRequest):
active = req.id_task == current_task
queued = req.id_task in pending_tasks
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 02841c30..4d35b695 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,12 +1,9 @@
import os
-import numpy as np
-from PIL import Image
-from realesrgan import RealESRGANer
-
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler):
@@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler):
self.name = "RealESRGAN"
self.user_path = path
super().__init__()
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
- from realesrgan import RealESRGANer # noqa: F401
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
- self.enable = True
- self.scalers = []
- scalers = self.load_models(path)
+ self.enable = True
+ self.scalers = []
+ scalers = get_realesrgan_models(self)
- local_model_paths = self.find_models(ext_filter=[".pth"])
- for scaler in scalers:
- if scaler.local_data_path.startswith("http"):
- filename = modelloader.friendly_name(scaler.local_data_path)
- local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
- if local_model_candidates:
- scaler.local_data_path = local_model_candidates[0]
+ local_model_paths = self.find_models(ext_filter=[".pth"])
+ for scaler in scalers:
+ if scaler.local_data_path.startswith("http"):
+ filename = modelloader.friendly_name(scaler.local_data_path)
+ local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
+ if local_model_candidates:
+ scaler.local_data_path = local_model_candidates[0]
- if scaler.name in opts.realesrgan_enabled_models:
- self.scalers.append(scaler)
-
- except Exception:
- errors.report("Error importing Real-ESRGAN", exc_info=True)
- self.enable = False
- self.scalers = []
+ if scaler.name in opts.realesrgan_enabled_models:
+ self.scalers.append(scaler)
def do_upscale(self, img, path):
if not self.enable:
@@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler):
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
- upsampler = RealESRGANer(
- scale=info.scale,
- model_path=info.local_data_path,
- model=info.model(),
- half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
- tile=opts.ESRGAN_tile,
- tile_pad=opts.ESRGAN_tile_overlap,
+ model_descriptor = modelloader.load_spandrel_model(
+ info.local_data_path,
device=self.device,
+ half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
+ )
+ return upscale_with_model(
+ model_descriptor,
+ img,
+ tile_size=opts.ESRGAN_tile,
+ tile_overlap=opts.ESRGAN_tile_overlap,
+ # TODO: `outscale`?
)
-
- upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-
- image = Image.fromarray(upsampled)
- return image
def load_model(self, path):
for scaler in self.scalers:
@@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler):
return scaler
raise ValueError(f"Unable to find model info: {path}")
- def load_models(self, _):
- return get_realesrgan_models(self)
-
-def get_realesrgan_models(scaler):
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
- models = [
- UpscalerData(
- name="R-ESRGAN General 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN General WDN 4xV3",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN AnimeVideo",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
- ),
- UpscalerData(
- name="R-ESRGAN 4x+",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
- ),
- UpscalerData(
- name="R-ESRGAN 4x+ Anime6B",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
- scale=4,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
- ),
- UpscalerData(
- name="R-ESRGAN 2x+",
- path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
- scale=2,
- upscaler=scaler,
- model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
- ),
- ]
- return models
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
+def get_realesrgan_models(scaler: UpscalerRealESRGAN):
+ return [
+ UpscalerData(
+ name="R-ESRGAN General 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN General WDN 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN AnimeVideo",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 4x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 4x+ Anime6B",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+ scale=4,
+ upscaler=scaler,
+ ),
+ UpscalerData(
+ name="R-ESRGAN 2x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+ scale=2,
+ upscaler=scaler,
+ ),
+ ]
diff --git a/modules/scripts.py b/modules/scripts.py
index 7f9454eb..017aed5a 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object()
+class MaskBlendArgs:
+ def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
+ self.current_latent = current_latent
+ self.nmask = nmask
+ self.init_latent = init_latent
+ self.mask = mask
+ self.blended_latent = blended_latent
+
+ self.denoiser = denoiser
+ self.is_final_blend = denoiser is None
+ self.sigma = sigma
+
+class PostSampleArgs:
+ def __init__(self, samples):
+ self.samples = samples
class PostprocessImageArgs:
def __init__(self, image):
self.image = image
+class PostProcessMaskOverlayArgs:
+ def __init__(self, index, mask_for_overlay, overlay_image):
+ self.index = index
+ self.mask_for_overlay = mask_for_overlay
+ self.overlay_image = overlay_image
class PostprocessBatchListArgs:
def __init__(self, images):
@@ -206,6 +226,25 @@ class Script:
pass
+ def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
+ """
+ Called in inpainting mode when the original content is blended with the inpainted content.
+ This is called at every step in the denoising process and once at the end.
+ If is_final_blend is true, this is called for the final blending stage.
+ Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+ """
+
+ pass
+
+ def post_sample(self, p, ps: PostSampleArgs, *args):
+ """
+ Called after the samples have been generated,
+ but before they have been decoded by the VAE, if applicable.
+ Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+ """
+
+ pass
+
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
@@ -213,6 +252,13 @@ class Script:
pass
+ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+ """
+ Called for every image after it has been generated; ppmo carries index, mask_for_overlay and overlay_image.
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -520,7 +566,12 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
for script_data in auto_processing_scripts + scripts_data:
- script = script_data.script_class()
+ try:
+ script = script_data.script_class()
+ except Exception:
+ errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True)
+ continue
+
script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -645,6 +696,8 @@ class ScriptRunner:
self.setup_ui_for_section(None, self.selectable_scripts)
def select_script(script_index):
+ if script_index is None:
+ script_index = 0
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
@@ -688,7 +741,7 @@ class ScriptRunner:
def run(self, p, *args):
script_index = args[0]
- if script_index == 0:
+ if script_index == 0 or script_index is None:
return None
script = self.selectable_scripts[script_index-1]
@@ -767,6 +820,22 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+ def post_sample(self, p, ps: PostSampleArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.post_sample(p, ps, *script_args)
+ except Exception:
+ errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
+ def on_mask_blend(self, p, mba: MaskBlendArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.on_mask_blend(p, mba, *script_args)
+ except Exception:
+ errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
@@ -775,6 +844,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+ def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_maskoverlay(p, ppmo, *script_args)
+ except Exception:
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
+
def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 8863107a..273a7edd 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
would be on the meta device.
"""
- if state_dict == sd:
+ if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
original(module, state_dict, strict=strict)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 9355f1e1..50bc209e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -348,10 +348,28 @@ class SkipWritingToConfig:
SkipWritingToConfig.skip = self.previous
+def check_fp8(model):
+ if model is None:
+ return None
+ if devices.get_optimal_device_name() == "mps":
+ enable_fp8 = False
+ elif shared.opts.fp8_storage == "Enable":
+ enable_fp8 = True
+ elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+ enable_fp8 = True
+ else:
+ enable_fp8 = False
+ return enable_fp8
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
+ if devices.fp8:
+ # prevent the model from loading the state dict in fp8
+ model.half()
+
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -383,6 +401,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.no_half:
model.float()
+ model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
@@ -396,7 +415,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
+ alphas_cumprod = model.alphas_cumprod
+ model.alphas_cumprod = None
model.half()
+ model.alphas_cumprod = alphas_cumprod
+ model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
@@ -404,6 +427,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
devices.dtype_unet = torch.float16
timer.record("apply half()")
+ for module in model.modules():
+ if hasattr(module, 'fp16_weight'):
+ del module.fp16_weight
+ if hasattr(module, 'fp16_bias'):
+ del module.fp16_bias
+
+ if check_fp8(model):
+ devices.fp8 = True
+ first_stage = model.first_stage_model
+ model.first_stage_model = None
+ for module in model.modules():
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ if shared.opts.cache_fp16_weight:
+ module.fp16_weight = module.weight.data.clone().cpu().half()
+ if module.bias is not None:
+ module.fp16_bias = module.bias.data.clone().cpu().half()
+ module.to(torch.float8_e4m3fn)
+ model.first_stage_model = first_stage
+ timer.record("apply fp8")
+ else:
+ devices.fp8 = False
+
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -651,6 +696,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
else:
weight_dtype_conversion = {
'first_stage_model': None,
+ 'alphas_cumprod': None,
'': torch.float16,
}
@@ -746,7 +792,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
return None
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
checkpoint_info = info or select_checkpoint()
timer = Timer()
@@ -758,11 +804,14 @@ def reload_model_weights(sd_model=None, info=None):
current_checkpoint_info = None
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
- if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+ if check_fp8(sd_model) != devices.fp8:
+ # load from state dict again to prevent extra numerical errors
+ forced_reload = True
+ elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
- if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index deab2f6e..b38137eb 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
+config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
- return config_sdxl
+ if diffusion_model_input.shape[1] == 9:
+ return config_sdxl_inpainting
+ else:
+ return config_sdxl
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 01123321..0de17af3 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -6,6 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
+from modules import torch_utils
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -34,6 +35,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+ sd = self.model.state_dict()
+ diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+ if diffusion_model_input is not None:
+ if diffusion_model_input.shape[1] == 9:
+ x = torch.cat([x] + cond['c_concat'], dim=1)
+
return self.model(x, t, cond)
@@ -84,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
- dtype = next(model.model.diffusion_model.parameters()).dtype
+ dtype = torch_utils.get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt'
@@ -93,7 +100,7 @@ def extend_sdxl(model):
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
- model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+ model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
model.conditioner.wrapped = torch.nn.Module()
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38..eb9d5daf 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
+
+ # NOTE: masking before denoising can cause the original latents to be oversmoothed
+ # as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ # If we use masks, blending between the denoised and original latent images occurs here.
+ def apply_blend(current_latent):
+ blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+ if self.p.scripts is not None:
+ from modules import scripts
+ mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+ self.p.scripts.on_mask_blend(self.p, mba)
+ blended_latent = mba.blended_latent
+
+ return blended_latent
+
+ # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
+ x = apply_blend(x)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ # Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ denoised = apply_blend(denoised)
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index b17a8f93..777dd8d0 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
- return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+ return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
@@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
+ self.model_wrap = self.model_wrap_cfg.inner_model
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 991971ad..e1392472 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -67,14 +67,14 @@ def reload_hypernetworks():
def get_infotext_names():
- from modules import generation_parameters_copypaste, shared
+ from modules import infotext, shared
res = {}
for info in shared.opts.data_labels.values():
if info.infotext:
res[info.infotext] = 1
- for tab_data in generation_parameters_copypaste.paste_fields.values():
+ for tab_data in infotext.paste_fields.values():
for _, name in tab_data.get("fields") or []:
if isinstance(name, str):
res[name] = 1
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 86e7636c..cca3f7be 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -1,7 +1,8 @@
+import os
import gradio as gr
-from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
+from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401
from modules.shared_cmd_options import cmd_opts
from modules.options import options_section, OptionInfo, OptionHTML, categories
@@ -74,14 +75,14 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
- "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
- "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
- "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
+ "outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs),
+ "outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs),
+ "outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs),
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
- "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
- "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
- "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
- "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
+ "outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs),
+ "outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs),
+ "outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs),
+ "outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
@@ -176,6 +177,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
+ "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"),
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
@@ -206,9 +208,12 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+ "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+ "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
}))
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
+ "auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
@@ -216,6 +221,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd"
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
+ "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
}))
options_templates.update(options_section(('interrogate', "Interrogate"), {
@@ -256,6 +262,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
"keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
+ "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
}))
@@ -280,6 +287,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives",
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
+ "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"),
}))
options_templates.update(options_section(('ui', "User interface", "ui"), {
@@ -332,6 +340,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), {
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
+ "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
@@ -354,6 +363,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
+ 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
}))
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
diff --git a/modules/shared_state.py b/modules/shared_state.py
index a68789cc..33996691 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -12,6 +12,7 @@ log = logging.getLogger(__name__)
class State:
skipped = False
interrupted = False
+ stopping_generation = False
job = ""
job_no = 0
job_count = 0
@@ -79,6 +80,10 @@ class State:
self.interrupted = True
log.info("Received interrupt request")
+ def stop_generating(self):
+ self.stopping_generation = True
+ log.info("Received stop generating request")
+
def nextjob(self):
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
@@ -91,6 +96,7 @@ class State:
obj = {
"skipped": self.skipped,
"interrupted": self.interrupted,
+ "stopping_generation": self.stopping_generation,
"job": self.job,
"job_count": self.job_count,
"job_timestamp": self.job_timestamp,
@@ -114,6 +120,7 @@ class State:
self.id_live_preview = 0
self.skipped = False
self.interrupted = False
+ self.stopping_generation = False
self.textinfo = None
self.job = job
devices.torch_gc()
diff --git a/modules/styles.py b/modules/styles.py
index 7fb6c2e1..026c4300 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -30,38 +30,29 @@ def apply_styles_to_prompt(prompt, styles):
return prompt
-def unwrap_style_text_from_prompt(style_text, prompt):
- """
- Checks the prompt to see if the style text is wrapped around it. If so,
- returns True plus the prompt text without the style text. Otherwise, returns
- False with the original prompt.
+def extract_style_text_from_prompt(style_text, prompt):
+ """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
- Note that the "cleaned" version of the style text is only used for matching
- purposes here. It isn't returned; the original style text is not modified.
+ extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
+ extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
+ extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
"""
- stripped_prompt = prompt
- stripped_style_text = style_text
+
+ stripped_prompt = prompt.strip()
+ stripped_style_text = style_text.strip()
+
if "{prompt}" in stripped_style_text:
- # Work out whether the prompt is wrapped in the style text. If so, we
- # return True and the "inner" prompt text that isn't part of the style.
- try:
- left, right = stripped_style_text.split("{prompt}", 2)
- except ValueError as e:
- # If the style text has multple "{prompt}"s, we can't split it into
- # two parts. This is an error, but we can't do anything about it.
- print(f"Unable to compare style text to prompt:\n{style_text}")
- print(f"Error: {e}")
- return False, prompt
+ left, right = stripped_style_text.split("{prompt}", 2)
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
- prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
+ prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
return True, prompt
else:
- # Work out whether the given prompt ends with the style text. If so, we
- # return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
- prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
- if prompt.endswith(", "):
+ prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
+
+ if prompt.endswith(', '):
prompt = prompt[:-2]
+
return True, prompt
return False, prompt
@@ -76,15 +67,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt
- match_positive, extracted_positive = unwrap_style_text_from_prompt(
- style.prompt, prompt
- )
+ match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
if not match_positive:
return False, prompt, negative_prompt
- match_negative, extracted_negative = unwrap_style_text_from_prompt(
- style.negative_prompt, negative_prompt
- )
+ match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
if not match_negative:
return False, prompt, negative_prompt
@@ -98,10 +85,8 @@ class StyleDatabase:
self.path = path
folder, file = os.path.split(self.path)
- self.default_file = file.split("*")[0] + ".csv"
- if self.default_file == ".csv":
- self.default_file = "styles.csv"
- self.default_path = os.path.join(folder, self.default_file)
+ filename, _, ext = file.partition('*')
+ self.default_path = os.path.join(folder, filename + ext)
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -155,10 +140,8 @@ class StyleDatabase:
row["name"], prompt, negative_prompt, path
)
- def get_style_paths(self) -> list():
- """
- Returns a list of all distinct paths, including the default path, of
- files that styles are loaded from."""
+ def get_style_paths(self) -> set:
+ """Returns a set of all distinct paths of files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
@@ -172,9 +155,9 @@ class StyleDatabase:
style_paths.add(style.path)
# Remove any paths for styles that are just list dividers
- style_paths.remove("do_not_save")
+ style_paths.discard("do_not_save")
- return list(style_paths)
+ return style_paths
def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
@@ -196,20 +179,7 @@ class StyleDatabase:
# The path argument is deprecated, but kept for backwards compatibility
_ = path
- # Update any styles without a path to the default path
- for style in list(self.styles.values()):
- if not style.path:
- self.styles[style.name] = style._replace(path=self.default_path)
-
- # Create a list of all distinct paths, including the default path
- style_paths = set()
- style_paths.add(self.default_path)
- for _, style in self.styles.items():
- if style.path:
- style_paths.add(style.path)
-
- # Remove any paths for styles that are just list dividers
- style_paths.remove("do_not_save")
+ style_paths = self.get_style_paths()
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index b669edd0..5abf616b 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -26,11 +26,9 @@ environment_whitelist = {
"OPENCLIP_PACKAGE",
"STABLE_DIFFUSION_REPO",
"K_DIFFUSION_REPO",
- "CODEFORMER_REPO",
"BLIP_REPO",
"STABLE_DIFFUSION_COMMIT_HASH",
"K_DIFFUSION_COMMIT_HASH",
- "CODEFORMER_COMMIT_HASH",
"BLIP_COMMIT_HASH",
"COMMANDLINE_ARGS",
"IGNORE_CMD_ARGS_ERRORS",
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 04dda585..c6bcab15 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -11,7 +11,6 @@ import safetensors.torch
import numpy as np
from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
@@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
def tensorboard_setup(log_directory):
+ from torch.utils.tensorboard import SummaryWriter
os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
return SummaryWriter(
log_dir=os.path.join(log_directory, "tensorboard"),
@@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
+ tensorboard_writer = None
if shared.opts.training_enable_tensorboard:
- tensorboard_writer = tensorboard_setup(log_directory)
+ try:
+ tensorboard_writer = tensorboard_setup(log_directory)
+ except ImportError:
+ errors.report("Error initializing tensorboard", exc_info=True)
pin_memory = shared.opts.pin_memory
@@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ if tensorboard_writer and shared.opts.training_tensorboard_save_images:
tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
diff --git a/modules/torch_utils.py b/modules/torch_utils.py
new file mode 100644
index 00000000..e5b52393
--- /dev/null
+++ b/modules/torch_utils.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import torch.nn
+
+
+def get_param(model) -> torch.nn.Parameter:
+ """
+ Find the first parameter in a model or module.
+ """
+ if hasattr(model, "model") and hasattr(model.model, "parameters"):
+ # Unpeel a model descriptor to get at the actual Torch module.
+ model = model.model
+
+ for param in model.parameters():
+ return param
+
+ raise ValueError(f"No parameters found in model {model!r}")
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e4e18ceb..3a481915 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -2,7 +2,7 @@ from contextlib import closing
import modules.scripts
from modules import processing
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules.infotext import create_override_settings_dict
from modules.shared import opts
import modules.shared as shared
from modules.ui import plaintext_to_html
diff --git a/modules/ui.py b/modules/ui.py
index d80486dd..378529c7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -21,14 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript
from modules.shared import opts, cmd_opts
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext as parameters_copypaste
import modules.hypernetworks.ui as hypernetworks_ui
import modules.textual_inversion.ui as textual_inversion_ui
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
from modules import prompt_parser
from modules.sd_hijack import model_hijack
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext import image_from_url_text, PasteField
create_setting_component = ui_settings.create_setting_component
@@ -177,7 +177,6 @@ def update_negative_prompt_token_counter(text, steps):
return update_token_counter(text, steps, is_positive=False)
-
def setup_progressbar(*args, **kwargs):
pass
@@ -436,28 +435,28 @@ def create_ui():
)
txt2img_paste_fields = [
- (toprow.prompt, "Prompt"),
- (toprow.negative_prompt, "Negative prompt"),
- (steps, "Steps"),
- (sampler_name, "Sampler"),
- (cfg_scale, "CFG scale"),
- (width, "Size-1"),
- (height, "Size-2"),
- (batch_size, "Batch size"),
- (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
- (denoising_strength, "Denoising strength"),
- (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
- (hr_scale, "Hires upscale"),
- (hr_upscaler, "Hires upscaler"),
- (hr_second_pass_steps, "Hires steps"),
- (hr_resize_x, "Hires resize-1"),
- (hr_resize_y, "Hires resize-2"),
- (hr_checkpoint_name, "Hires checkpoint"),
- (hr_sampler_name, "Hires sampler"),
- (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
- (hr_prompt, "Hires prompt"),
- (hr_negative_prompt, "Hires negative prompt"),
- (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
+ PasteField(toprow.prompt, "Prompt", api="prompt"),
+ PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
+ PasteField(steps, "Steps", api="steps"),
+ PasteField(sampler_name, "Sampler", api="sampler_name"),
+ PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
+ PasteField(width, "Size-1", api="width"),
+ PasteField(height, "Size-2", api="height"),
+ PasteField(batch_size, "Batch size", api="batch_size"),
+ PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
+ PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
+ PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
+ PasteField(hr_scale, "Hires upscale", api="hr_scale"),
+ PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
+ PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
+ PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
+ PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
+ PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
+ PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
+ PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
+ PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
+ PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
+ PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
@@ -1086,6 +1085,7 @@ def create_ui():
)
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
+ ui_settings_from_file = loadsave.ui_settings.copy()
settings = ui_settings.UiSettings()
settings.create_ui(loadsave, dummy_component)
@@ -1146,7 +1146,8 @@ def create_ui():
modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
- loadsave.dump_defaults()
+ if ui_settings_from_file != loadsave.ui_settings:
+ loadsave.dump_defaults()
demo.ui_loadsave = loadsave
return demo
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 032ec4af..fd32676f 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -8,10 +8,10 @@ import gradio as gr
import subprocess as sp
from modules import call_queue, shared
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext import image_from_url_text
import modules.images
from modules.ui_components import ToolButton
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext as parameters_copypaste
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index fe5d3ba3..790af135 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -10,7 +10,7 @@ import json
import html
from fastapi.exceptions import HTTPException
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext import image_from_url_text
from modules.ui_components import ToolButton
extra_pages = []
@@ -223,7 +223,10 @@ class ExtraNetworksPage:
onclick = item.get("onclick", None)
if onclick is None:
- onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ if "negative_prompt" in item:
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ else:
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index 36a807fc..87aeb6f3 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -5,7 +5,7 @@ import os.path
import gradio as gr
-from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks
+from modules import infotext, images, sysinfo, errors, ui_extra_networks
class UserMetadataEditor:
@@ -181,7 +181,7 @@ class UserMetadataEditor:
index = len(gallery) - 1 if index >= len(gallery) else index
img_info = gallery[index if index >= 0 else 0]
- image = generation_parameters_copypaste.image_from_url_text(img_info)
+ image = infotext.image_from_url_text(img_info)
geninfo, items = images.read_info_from_image(image)
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
index 0d368f8b..a86c368e 100644
--- a/modules/ui_gradio_extensions.py
+++ b/modules/ui_gradio_extensions.py
@@ -1,17 +1,12 @@
import os
import gradio as gr
-from modules import localization, shared, scripts
-from modules.paths import script_path, data_path, cwd
+from modules import localization, shared, scripts, util
+from modules.paths import script_path, data_path
def webpath(fn):
- if fn.startswith(cwd):
- web_path = os.path.relpath(fn, cwd)
- else:
- web_path = os.path.abspath(fn)
-
- return f'file={web_path}?{os.path.getmtime(fn)}'
+ return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}'
def javascript_html():
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
index 7826786c..693ff75c 100644
--- a/modules/ui_loadsave.py
+++ b/modules/ui_loadsave.py
@@ -144,7 +144,7 @@ class UiLoadsave:
json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)
def dump_defaults(self):
- """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
+ """saves default values to a file unless the file is present and there was an error loading default values at start"""
if self.error_loading and os.path.exists(self.filename):
return
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index 13d888e4..b74a1532 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,6 +1,6 @@
import gradio as gr
from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext as parameters_copypaste
def create_ui():
diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py
index 88838f97..1abc9117 100644
--- a/modules/ui_toprow.py
+++ b/modules/ui_toprow.py
@@ -79,11 +79,11 @@ class Toprow:
def create_prompts(self):
with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
- self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
- self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+ self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"])
self.prompt_img.change(
fn=modules.images.image_data,
@@ -106,8 +106,14 @@ class Toprow:
outputs=[],
)
+ def interrupt_function():
+ if shared.state.job_count > 1 and shared.opts.interrupt_after_current:
+ shared.state.stop_generating()
+ else:
+ shared.state.interrupt()
+
self.interrupt.click(
- fn=lambda: shared.state.interrupt(),
+ fn=interrupt_function,
inputs=[],
outputs=[],
)
diff --git a/modules/upscaler.py b/modules/upscaler.py
index b256e085..3aee69db 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -98,6 +98,9 @@ class UpscalerData:
self.scale = scale
self.model = model
+ def __repr__(self):
+ return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"
+
class UpscalerNone(Upscaler):
name = "None"
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
new file mode 100644
index 00000000..f5cb92d5
--- /dev/null
+++ b/modules/upscaler_utils.py
@@ -0,0 +1,140 @@
+import logging
+from typing import Callable
+
+import numpy as np
+import torch
+import tqdm
+from PIL import Image
+
+from modules import images, shared, torch_utils
+
+logger = logging.getLogger(__name__)
+
+
+def upscale_without_tiling(model, img: Image.Image):
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+ img = torch.from_numpy(img).float()
+
+ param = torch_utils.get_param(model)
+ img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
+
+ with torch.no_grad():
+ output = model(img)
+
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output = 255. * np.moveaxis(output, 0, 2)
+ output = output.astype(np.uint8)
+ output = output[:, :, ::-1]
+ return Image.fromarray(output, 'RGB')
+
+
+def upscale_with_model(
+ model: Callable[[torch.Tensor], torch.Tensor],
+ img: Image.Image,
+ *,
+ tile_size: int,
+ tile_overlap: int = 0,
+ desc="tiled upscale",
+) -> Image.Image:
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img)
+ output = upscale_without_tiling(model, img)
+ logger.debug("=> %s", output)
+ return output
+
+ grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
+ newtiles = []
+
+ with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
+ for y, h, row in grid.tiles:
+ newrow = []
+ for x, w, tile in row:
+ logger.debug("Tile (%d, %d) %s...", x, y, tile)
+ output = upscale_without_tiling(model, tile)
+ scale_factor = output.width // tile.width
+ logger.debug("=> %s (scale factor %s)", output, scale_factor)
+ newrow.append([x * scale_factor, w * scale_factor, output])
+ p.update(1)
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+ newgrid = images.Grid(
+ newtiles,
+ tile_w=grid.tile_w * scale_factor,
+ tile_h=grid.tile_h * scale_factor,
+ image_w=grid.image_w * scale_factor,
+ image_h=grid.image_h * scale_factor,
+ overlap=grid.overlap * scale_factor,
+ )
+ return images.combine_grid(newgrid)
+
+
+def tiled_upscale_2(
+ img,
+ model,
+ *,
+ tile_size: int,
+ tile_overlap: int,
+ scale: int,
+ device,
+ desc="Tiled upscale",
+):
+ # Alternative implementation of `upscale_with_model` originally used by
+ # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
+ # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
+ # Pillow space without weighting.
+ b, c, h, w = img.size()
+ tile_size = min(tile_size, h, w)
+
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img.shape)
+ return model(img)
+
+ stride = tile_size - tile_overlap
+ h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
+ w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
+ result = torch.zeros(
+ b,
+ c,
+ h * scale,
+ w * scale,
+ device=device,
+ ).type_as(img)
+ weights = torch.zeros_like(result)
+ logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
+ with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
+ for h_idx in h_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ for w_idx in w_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ in_patch = img[
+ ...,
+ h_idx : h_idx + tile_size,
+ w_idx : w_idx + tile_size,
+ ]
+ out_patch = model(in_patch)
+
+ result[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch)
+
+ out_patch_mask = torch.ones_like(out_patch)
+
+ weights[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch_mask)
+
+ pbar.update(1)
+
+ output = result.div_(weights)
+
+ return output
diff --git a/modules/util.py b/modules/util.py
index 60afc067..4861bcb0 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -2,7 +2,7 @@ import os
import re
from modules import shared
-from modules.paths_internal import script_path
+from modules.paths_internal import script_path, cwd
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
@@ -56,3 +56,13 @@ def ldm_print(*args, **kwargs):
return
print(*args, **kwargs)
+
+
+def truncate_path(target_path, base_path=cwd):
+ abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)
+ try:
+ if os.path.commonpath([abs_target, abs_base]) == abs_base:
+ return os.path.relpath(abs_target, abs_base)
+ except ValueError:
+ pass
+ return abs_target
diff --git a/modules/xlmr.py b/modules/xlmr.py
index a407a3ca..319771b7 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
+from modules import torch_utils
+
+
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = next(self.parameters()).device
+ device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
index a727e865..f6055504 100644
--- a/modules/xlmr_m18.py
+++ b/modules/xlmr_m18.py
@@ -4,6 +4,8 @@ import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
+from modules import torch_utils
+
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -68,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = next(self.parameters()).device
+ device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index d933c790..f7687a66 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -27,6 +27,71 @@ def torch_xpu_gc():
has_xpu = check_for_xpu()
+
+# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627
+# Here we implement a slicing algorithm to split large batch size into smaller chunks,
+# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.
+# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,
+# which is the best trade-off between VRAM usage and performance.
+ARC_SINGLE_ALLOCATION_LIMIT = {}
+orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention
+def torch_xpu_scaled_dot_product_attention(
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs
+):
+ # cast to same dtype first
+ key = key.to(query.dtype)
+ value = value.to(query.dtype)
+
+ N = query.shape[:-2] # Batch size
+ L = query.size(-2) # Target sequence length
+ E = query.size(-1) # Embedding dimension of the query and key
+ S = key.size(-2) # Source sequence length
+ Ev = value.size(-1) # Embedding dimension of the value
+
+ total_batch_size = torch.numel(torch.empty(N))
+ device_id = query.device.index
+ if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:
+ ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)
+ batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))
+
+ if total_batch_size <= batch_size_limit:
+ return orig_sdp_attn_func(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ *args, **kwargs
+ )
+
+ query = torch.reshape(query, (-1, L, E))
+ key = torch.reshape(key, (-1, S, E))
+ value = torch.reshape(value, (-1, S, Ev))
+ if attn_mask is not None:
+ attn_mask = attn_mask.view(-1, L, S)
+ chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit
+ outputs = []
+ for i in range(chunk_count):
+ attn_mask_chunk = (
+ None
+ if attn_mask is None
+ else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]
+ )
+ chunk_output = orig_sdp_attn_func(
+ query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],
+ attn_mask_chunk,
+ dropout_p,
+ is_causal,
+ *args, **kwargs
+ )
+ outputs.append(chunk_output)
+ result = torch.cat(outputs, dim=0)
+ return torch.reshape(result, (*N, L, Ev))
+
+
if has_xpu:
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
CondFunc('torch.Generator',
@@ -48,3 +113,12 @@ if has_xpu:
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
+ CondFunc('torch.bmm',
+ lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
+ lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
+ CondFunc('torch.cat',
+ lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
+ lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
+ CondFunc('torch.nn.functional.scaled_dot_product_attention',
+ lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),
+ lambda orig_func, query, *args, **kwargs: query.is_xpu)