author     Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2023-08-14 13:34:51 +0800
committer  Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2023-08-14 13:34:51 +0800
commit     e7c03ccdcefb2a80129703931ef1f8455708945b (patch)
tree       6769bdb3a32bdf99cf25422206ff03457bdfa86c /modules
parent     d9cc27cb29926c9cc5dce331da8fbaf996cf4973 (diff)
parent     007ecfbb29771aa7cdcf0263ab1811bc75fa5446 (diff)
Merge branch 'dev' into extra-norm-module
Diffstat (limited to 'modules')
-rw-r--r--  modules/launch_utils.py                          11
-rw-r--r--  modules/mac_specific.py                           3
-rwxr-xr-x  modules/processing.py                           384
-rw-r--r--  modules/processing_scripts/refiner.py            18
-rw-r--r--  modules/processing_scripts/seed.py                2
-rw-r--r--  modules/scripts.py                               17
-rw-r--r--  modules/sd_hijack_optimizations.py               13
-rw-r--r--  modules/sd_samplers_common.py                    25
-rw-r--r--  modules/sd_samplers_kdiffusion.py                11
-rw-r--r--  modules/sd_samplers_timesteps_impl.py             4
-rw-r--r--  modules/shared_options.py                         4
-rw-r--r--  modules/sub_quadratic_attention.py                4
-rw-r--r--  modules/ui_extra_networks_checkpoints.py          1
-rw-r--r--  modules/ui_extra_networks_hypernets.py            6
-rw-r--r--  modules/ui_extra_networks_textual_inversion.py    3
-rw-r--r--  modules/ui_extra_networks_user_metadata.py        4
16 files changed, 319 insertions, 191 deletions
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 65eb684f..449a8755 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -173,9 +173,12 @@ def git_clone(url, dir, name, commithash=None):
if current_hash == commithash:
return
- run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
+ if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
+ run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
- run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+ run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
+
+ run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return
@@ -319,12 +322,12 @@ def prepare_environment():
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
- k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try:
- # the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
+ # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
os.remove(os.path.join(script_path, "tmp", "restart"))
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
except OSError:
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index bce527cc..89256c5b 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
if has_mps:
- # MPS fix for randn in torchsde
- CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
-
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
diff --git a/modules/processing.py b/modules/processing.py
index 6ad105d7..74366655 100755
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
import json
import logging
import math
import os
import sys
import hashlib
+from dataclasses import dataclass, field
import torch
import numpy as np
@@ -11,7 +13,7 @@ from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
-from typing import Any, Dict, List
+from typing import Any
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
@@ -104,97 +106,160 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+@dataclass(repr=False)
class StableDiffusionProcessing:
- """
- The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
- """
+ sd_model: object = None
+ outpath_samples: str = None
+ outpath_grids: str = None
+ prompt: str = ""
+ prompt_for_display: str = None
+ negative_prompt: str = ""
+ styles: list[str] = field(default_factory=list)
+ seed: int = -1
+ subseed: int = -1
+ subseed_strength: float = 0
+ seed_resize_from_h: int = -1
+ seed_resize_from_w: int = -1
+ seed_enable_extras: bool = True
+ sampler_name: str = None
+ batch_size: int = 1
+ n_iter: int = 1
+ steps: int = 50
+ cfg_scale: float = 7.0
+ width: int = 512
+ height: int = 512
+ restore_faces: bool = None
+ tiling: bool = None
+ do_not_save_samples: bool = False
+ do_not_save_grid: bool = False
+ extra_generation_params: dict[str, Any] = None
+ overlay_images: list = None
+ eta: float = None
+ do_not_reload_embeddings: bool = False
+ denoising_strength: float = 0
+ ddim_discretize: str = None
+ s_min_uncond: float = None
+ s_churn: float = None
+ s_tmax: float = None
+ s_tmin: float = None
+ s_noise: float = None
+ override_settings: dict[str, Any] = None
+ override_settings_restore_afterwards: bool = True
+ sampler_index: int = None
+ refiner_checkpoint: str = None
+ refiner_switch_at: float = None
+ token_merging_ratio = 0
+ token_merging_ratio_hr = 0
+ disable_extra_networks: bool = False
+
+ scripts_value: scripts.ScriptRunner = field(default=None, init=False)
+ script_args_value: list = field(default=None, init=False)
+ scripts_setup_complete: bool = field(default=False, init=False)
+
cached_uc = [None, None]
cached_c = [None, None]
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
- if sampler_index is not None:
+ comments: dict = None
+ sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
+ is_using_inpainting_conditioning: bool = field(default=False, init=False)
+ paste_to: tuple | None = field(default=None, init=False)
+
+ is_hr_pass: bool = field(default=False, init=False)
+
+ c: tuple = field(default=None, init=False)
+ uc: tuple = field(default=None, init=False)
+
+ rng: rng.ImageRNG | None = field(default=None, init=False)
+ step_multiplier: int = field(default=1, init=False)
+ color_corrections: list = field(default=None, init=False)
+
+ all_prompts: list = field(default=None, init=False)
+ all_negative_prompts: list = field(default=None, init=False)
+ all_seeds: list = field(default=None, init=False)
+ all_subseeds: list = field(default=None, init=False)
+ iteration: int = field(default=0, init=False)
+ main_prompt: str = field(default=None, init=False)
+ main_negative_prompt: str = field(default=None, init=False)
+
+ prompts: list = field(default=None, init=False)
+ negative_prompts: list = field(default=None, init=False)
+ seeds: list = field(default=None, init=False)
+ subseeds: list = field(default=None, init=False)
+ extra_network_data: dict = field(default=None, init=False)
+
+ user: str = field(default=None, init=False)
+
+ sd_model_name: str = field(default=None, init=False)
+ sd_model_hash: str = field(default=None, init=False)
+ sd_vae_name: str = field(default=None, init=False)
+ sd_vae_hash: str = field(default=None, init=False)
+
+ def __post_init__(self):
+ if self.sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
- self.outpath_samples: str = outpath_samples
- self.outpath_grids: str = outpath_grids
- self.prompt: str = prompt
- self.prompt_for_display: str = None
- self.negative_prompt: str = (negative_prompt or "")
- self.styles: list = styles or []
- self.seed: int = seed
- self.subseed: int = subseed
- self.subseed_strength: float = subseed_strength
- self.seed_resize_from_h: int = seed_resize_from_h
- self.seed_resize_from_w: int = seed_resize_from_w
- self.sampler_name: str = sampler_name
- self.batch_size: int = batch_size
- self.n_iter: int = n_iter
- self.steps: int = steps
- self.cfg_scale: float = cfg_scale
- self.width: int = width
- self.height: int = height
- self.restore_faces: bool = restore_faces
- self.tiling: bool = tiling
- self.do_not_save_samples: bool = do_not_save_samples
- self.do_not_save_grid: bool = do_not_save_grid
- self.extra_generation_params: dict = extra_generation_params or {}
- self.overlay_images = overlay_images
- self.eta = eta
- self.do_not_reload_embeddings = do_not_reload_embeddings
- self.paste_to = None
- self.color_corrections = None
- self.denoising_strength: float = denoising_strength
+ self.comments = {}
+
self.sampler_noise_scheduler_override = None
- self.ddim_discretize = ddim_discretize or opts.ddim_discretize
- self.s_min_uncond = s_min_uncond or opts.s_min_uncond
- self.s_churn = s_churn or opts.s_churn
- self.s_tmin = s_tmin or opts.s_tmin
- self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
- self.s_noise = s_noise if s_noise is not None else opts.s_noise
- self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
- self.override_settings_restore_afterwards = override_settings_restore_afterwards
- self.is_using_inpainting_conditioning = False
- self.disable_extra_networks = False
- self.token_merging_ratio = 0
- self.token_merging_ratio_hr = 0
-
- if not seed_enable_extras:
+ self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
+ self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
+ self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
+ self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
+ self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
+
+ self.extra_generation_params = self.extra_generation_params or {}
+ self.override_settings = self.override_settings or {}
+ self.script_args = self.script_args or {}
+
+ self.refiner_checkpoint_info = None
+
+ if not self.seed_enable_extras:
self.subseed = -1
self.subseed_strength = 0
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
- self.scripts = None
- self.script_args = script_args
- self.all_prompts = None
- self.all_negative_prompts = None
- self.all_seeds = None
- self.all_subseeds = None
- self.iteration = 0
- self.is_hr_pass = False
- self.sampler = None
- self.main_prompt = None
- self.main_negative_prompt = None
-
- self.prompts = None
- self.negative_prompts = None
- self.extra_network_data = None
- self.seeds = None
- self.subseeds = None
-
- self.step_multiplier = 1
self.cached_uc = StableDiffusionProcessing.cached_uc
self.cached_c = StableDiffusionProcessing.cached_c
- self.uc = None
- self.c = None
- self.rng: rng.ImageRNG = None
-
- self.user = None
@property
def sd_model(self):
return shared.sd_model
+ @sd_model.setter
+ def sd_model(self, value):
+ pass
+
+ @property
+ def scripts(self):
+ return self.scripts_value
+
+ @scripts.setter
+ def scripts(self, value):
+ self.scripts_value = value
+
+ if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
+ self.setup_scripts()
+
+ @property
+ def script_args(self):
+ return self.script_args_value
+
+ @script_args.setter
+ def script_args(self, value):
+ self.script_args_value = value
+
+ if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
+ self.setup_scripts()
+
+ def setup_scripts(self):
+ self.scripts_setup_complete = True
+
+ self.scripts.setup_scrips(self)
+
+ def comment(self, text):
+ self.comments[text] = 1
+
def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@@ -398,7 +463,7 @@ class Processed:
self.subseed = subseed
self.subseed_strength = p.subseed_strength
self.info = info
- self.comments = comments
+ self.comments = "".join(f"{comment}\n" for comment in p.comments)
self.width = p.width
self.height = p.height
self.sampler_name = p.sampler_name
@@ -408,7 +473,10 @@ class Processed:
self.batch_size = p.batch_size
self.restore_faces = p.restore_faces
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
- self.sd_model_hash = shared.sd_model.sd_model_hash
+ self.sd_model_name = p.sd_model_name
+ self.sd_model_hash = p.sd_model_hash
+ self.sd_vae_name = p.sd_vae_name
+ self.sd_vae_hash = p.sd_vae_hash
self.seed_resize_from_w = p.seed_resize_from_w
self.seed_resize_from_h = p.seed_resize_from_h
self.denoising_strength = getattr(p, 'denoising_strength', None)
@@ -459,7 +527,10 @@ class Processed:
"batch_size": self.batch_size,
"restore_faces": self.restore_faces,
"face_restoration_model": self.face_restoration_model,
+ "sd_model_name": self.sd_model_name,
"sd_model_hash": self.sd_model_hash,
+ "sd_vae_name": self.sd_vae_name,
+ "sd_vae_hash": self.sd_vae_hash,
"seed_resize_from_w": self.seed_resize_from_w,
"seed_resize_from_h": self.seed_resize_from_h,
"denoising_strength": self.denoising_strength,
@@ -578,10 +649,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
"Face restoration": opts.face_restoration_model if p.restore_faces else None,
"Size": f"{p.width}x{p.height}",
- "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
- "VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
- "VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
+ "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
+ "Model": p.sd_model_name if opts.add_model_name_to_info else None,
+ "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
+ "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -670,14 +741,19 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.tiling is None:
p.tiling = opts.tiling
- p.loaded_vae_name = sd_vae.get_loaded_vae_name()
- p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()
+ if p.refiner_checkpoint not in (None, "", "None"):
+ p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
+ if p.refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
+
+ p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
+ p.sd_model_hash = shared.sd_model.sd_model_hash
+ p.sd_vae_name = sd_vae.get_loaded_vae_name()
+ p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
- comments = {}
-
p.setup_prompts()
if type(seed) == list:
@@ -757,7 +833,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds()
for comment in model_hijack.comments:
- comments[comment] = 1
+ p.comment(comment)
p.extra_generation_params.update(model_hijack.extra_generation_params)
@@ -886,7 +962,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images_list=output_images,
seed=p.all_seeds[0],
info=infotexts[0],
- comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,
@@ -910,49 +985,51 @@ def old_hires_fix_first_pass_dimensions(width, height):
return width, height
+@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
- sampler = None
+ enable_hr: bool = False
+ denoising_strength: float = 0.75
+ firstphase_width: int = 0
+ firstphase_height: int = 0
+ hr_scale: float = 2.0
+ hr_upscaler: str = None
+ hr_second_pass_steps: int = 0
+ hr_resize_x: int = 0
+ hr_resize_y: int = 0
+ hr_checkpoint_name: str = None
+ hr_sampler_name: str = None
+ hr_prompt: str = ''
+ hr_negative_prompt: str = ''
+
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
- super().__init__(**kwargs)
- self.enable_hr = enable_hr
- self.denoising_strength = denoising_strength
- self.hr_scale = hr_scale
- self.hr_upscaler = hr_upscaler
- self.hr_second_pass_steps = hr_second_pass_steps
- self.hr_resize_x = hr_resize_x
- self.hr_resize_y = hr_resize_y
- self.hr_upscale_to_x = hr_resize_x
- self.hr_upscale_to_y = hr_resize_y
- self.hr_checkpoint_name = hr_checkpoint_name
- self.hr_checkpoint_info = None
- self.hr_sampler_name = hr_sampler_name
- self.hr_prompt = hr_prompt
- self.hr_negative_prompt = hr_negative_prompt
- self.all_hr_prompts = None
- self.all_hr_negative_prompts = None
- self.latent_scale_mode = None
-
- if firstphase_width != 0 or firstphase_height != 0:
+ hr_checkpoint_info: dict = field(default=None, init=False)
+ hr_upscale_to_x: int = field(default=0, init=False)
+ hr_upscale_to_y: int = field(default=0, init=False)
+ truncate_x: int = field(default=0, init=False)
+ truncate_y: int = field(default=0, init=False)
+ applied_old_hires_behavior_to: tuple = field(default=None, init=False)
+ latent_scale_mode: dict = field(default=None, init=False)
+ hr_c: tuple | None = field(default=None, init=False)
+ hr_uc: tuple | None = field(default=None, init=False)
+ all_hr_prompts: list = field(default=None, init=False)
+ all_hr_negative_prompts: list = field(default=None, init=False)
+ hr_prompts: list = field(default=None, init=False)
+ hr_negative_prompts: list = field(default=None, init=False)
+ hr_extra_network_data: list = field(default=None, init=False)
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ if self.firstphase_width != 0 or self.firstphase_height != 0:
self.hr_upscale_to_x = self.width
self.hr_upscale_to_y = self.height
- self.width = firstphase_width
- self.height = firstphase_height
-
- self.truncate_x = 0
- self.truncate_y = 0
- self.applied_old_hires_behavior_to = None
-
- self.hr_prompts = None
- self.hr_negative_prompts = None
- self.hr_extra_network_data = None
+ self.width = self.firstphase_width
+ self.height = self.firstphase_height
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
- self.hr_c = None
- self.hr_uc = None
def calculate_target_resolution(self):
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
@@ -1146,6 +1223,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+ self.sampler = None
+ devices.torch_gc()
+
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False
@@ -1230,7 +1310,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return super().get_conds()
-
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()
@@ -1243,32 +1322,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return res
+@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
- sampler = None
-
- def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
- super().__init__(**kwargs)
-
- self.init_images = init_images
- self.resize_mode: int = resize_mode
- self.denoising_strength: float = denoising_strength
- self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
- self.init_latent = None
- self.image_mask = mask
- self.latent_mask = None
- self.mask_for_overlay = None
- self.mask_blur_x = mask_blur_x
- self.mask_blur_y = mask_blur_y
- if mask_blur is not None:
- self.mask_blur = mask_blur
- self.inpainting_fill = inpainting_fill
- self.inpaint_full_res = inpaint_full_res
- self.inpaint_full_res_padding = inpaint_full_res_padding
- self.inpainting_mask_invert = inpainting_mask_invert
- self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
+ init_images: list = None
+ resize_mode: int = 0
+ denoising_strength: float = 0.75
+ image_cfg_scale: float = None
+ mask: Any = None
+ mask_blur_x: int = 4
+ mask_blur_y: int = 4
+ mask_blur: int = None
+ inpainting_fill: int = 0
+ inpaint_full_res: bool = True
+ inpaint_full_res_padding: int = 0
+ inpainting_mask_invert: int = 0
+ initial_noise_multiplier: float = None
+ latent_mask: Image = None
+
+ image_mask: Any = field(default=None, init=False)
+
+ nmask: torch.Tensor = field(default=None, init=False)
+ image_conditioning: torch.Tensor = field(default=None, init=False)
+ init_img_hash: str = field(default=None, init=False)
+ mask_for_overlay: Image = field(default=None, init=False)
+ init_latent: torch.Tensor = field(default=None, init=False)
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ self.image_mask = self.mask
self.mask = None
- self.nmask = None
- self.image_conditioning = None
+ self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
@property
def mask_blur(self):
@@ -1278,15 +1362,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
@mask_blur.setter
def mask_blur(self, value):
- self.mask_blur_x = value
- self.mask_blur_y = value
-
- @mask_blur.deleter
- def mask_blur(self):
- del self.mask_blur_x
- del self.mask_blur_y
+ if isinstance(value, int):
+ self.mask_blur_x = value
+ self.mask_blur_y = value
def init(self, all_prompts, all_seeds, all_subseeds):
+ self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
+
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None
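
The processing.py hunks above replace the hand-written __init__ methods with @dataclass fields plus __post_init__. Below is a minimal sketch (not part of this patch) of what construction looks like after the change, assuming a loaded webui environment; the prompt text, output path, and the final process_images() call are illustrative assumptions rather than code from this commit.

from modules import shared
from modules.processing import StableDiffusionProcessingTxt2Img, process_images

# Fields that used to be __init__ keyword arguments are now dataclass fields,
# so keyword construction is unchanged for callers.
p = StableDiffusionProcessingTxt2Img(
    sd_model=shared.sd_model,
    outpath_samples="outputs/txt2img-images",  # hypothetical output directory
    prompt="a watercolor fox",                 # illustrative prompt
    negative_prompt="blurry",
    seed=-1,
    sampler_name="DPM++ 2M",
    steps=20,
    cfg_scale=7.0,
    width=512,
    height=512,
    enable_hr=False,
)

# s_churn / s_tmin / s_tmax / s_noise were left at None, so __post_init__
# resolves them from opts, as shown in the hunk above.
processed = process_images(p)
print(processed.info)
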
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
index 773ec5d0..3c5b37d2 100644
--- a/modules/processing_scripts/refiner.py
+++ b/modules/processing_scripts/refiner.py
@@ -38,18 +38,12 @@ class ScriptRefiner(scripts.Script):
return enable_refiner, refiner_checkpoint, refiner_switch_at
- def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
+ def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
# the actual implementation is in sd_samplers_common.py, apply_refiner
- p.refiner_checkpoint_info = None
- p.refiner_switch_at = None
-
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
- return
-
- refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
- if refiner_checkpoint_info is None:
- raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')
-
- p.refiner_checkpoint_info = refiner_checkpoint_info
- p.refiner_switch_at = refiner_switch_at
+ p.refiner_checkpoint_info = None
+ p.refiner_switch_at = None
+ else:
+ p.refiner_checkpoint = refiner_checkpoint
+ p.refiner_switch_at = refiner_switch_at
diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py
index cc90775a..96b44dfb 100644
--- a/modules/processing_scripts/seed.py
+++ b/modules/processing_scripts/seed.py
@@ -58,7 +58,7 @@ class ScriptSeed(scripts.ScriptBuiltin):
return self.seed, subseed, subseed_strength
- def before_process(self, p, seed, subseed, subseed_strength):
+ def setup(self, p, seed, subseed, subseed_strength):
p.seed = seed
if subseed_strength > 0:
diff --git a/modules/scripts.py b/modules/scripts.py
index c6459b45..d4a9da94 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -106,9 +106,16 @@ class Script:
pass
+ def setup(self, p, *args):
+ """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
+ args contains all values returned by components from ui().
+ """
+ pass
+
+
def before_process(self, p, *args):
"""
- This function is called very early before processing begins for AlwaysVisible scripts.
+ This function is called very early as processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
args contains all values returned by components from ui()
"""
@@ -706,6 +713,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+ def setup_scrips(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.setup(p, *script_args)
+ except Exception:
+ errors.report(f"Error running setup: {script.filename}", exc_info=True)
+
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
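
For extension authors, the setup() hook added above is invoked from ScriptRunner.setup_scrips() once the processing object has both scripts and script_args assigned, before before_process() runs. Below is a minimal sketch of an AlwaysVisible script using it; the checkbox and the generation-parameter key are illustrative assumptions, not part of this commit.

import gradio as gr
from modules import scripts

class ExampleSetupScript(scripts.Script):
    def title(self):
        return "Example setup hook"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        enabled = gr.Checkbox(label="Enable example", value=False)
        return [enabled]

    def setup(self, p, enabled):
        # Runs when the processing object is set up, before any processing starts.
        if enabled:
            p.extra_generation_params["Example setup"] = "on"

    def before_process(self, p, enabled):
        # Still runs very early once processing actually begins.
        pass
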
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 0e810eec..7f9e328d 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import math
import psutil
+import platform
import torch
from torch import einsum
@@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
class SdOptimizationSubQuad(SdOptimization):
name = "sub-quadratic"
cmd_opt = "opt_sub_quad_attention"
- priority = 10
+
+ @property
+ def priority(self):
+ return 1000 if shared.device.type == 'mps' else 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
@property
def priority(self):
- return 1000 if not torch.cuda.is_available() else 10
+ return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
@@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
if chunk_threshold is None:
- chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+ if q.device.type == 'mps':
+ chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+ else:
+ chunk_threshold_bytes = int(get_available_vram() * 0.7)
elif chunk_threshold == 0:
chunk_threshold_bytes = None
else:
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 40c7aae0..07fc4434 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -92,7 +92,15 @@ def images_tensor_to_samples(image, approximation=None, model=None):
model = shared.sd_model
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
- x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+ if len(image) > 1:
+ x_latent = torch.stack([
+ model.get_first_stage_encoding(
+ model.encode_first_stage(torch.unsqueeze(img, 0))
+ )[0]
+ for img in image
+ ])
+ else:
+ x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
return x_latent
@@ -145,7 +153,7 @@ def apply_refiner(cfg_denoiser):
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
- if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
+ if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
@@ -276,19 +284,19 @@ class Sampler:
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
s_noise = getattr(opts, 's_noise', p.s_noise)
- if s_churn != self.s_churn:
+ if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
extra_params_kwargs['s_churn'] = s_churn
p.s_churn = s_churn
p.extra_generation_params['Sigma churn'] = s_churn
- if s_tmin != self.s_tmin:
+ if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
extra_params_kwargs['s_tmin'] = s_tmin
p.s_tmin = s_tmin
p.extra_generation_params['Sigma tmin'] = s_tmin
- if s_tmax != self.s_tmax:
+ if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
extra_params_kwargs['s_tmax'] = s_tmax
p.s_tmax = s_tmax
p.extra_generation_params['Sigma tmax'] = s_tmax
- if s_noise != self.s_noise:
+ if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
extra_params_kwargs['s_noise'] = s_noise
p.s_noise = s_noise
p.extra_generation_params['Sigma noise'] = s_noise
@@ -305,5 +313,8 @@ class Sampler:
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ raise NotImplementedError()
-
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ raise NotImplementedError()
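
The apply_refiner hunk above changes the switch comparison from <= to <, so the refiner takes over on the step where the completed ratio first reaches refiner_switch_at rather than one step later. A tiny self-contained illustration, assuming completed_ratio is simply steps completed divided by total steps (the exact computation lives outside this hunk):

def switches_to_refiner(step, total_steps, refiner_switch_at):
    completed_ratio = step / total_steps
    # old: `completed_ratio <= refiner_switch_at` kept the base model even when
    # the ratio exactly equalled the threshold; the new `<` switches at it.
    return completed_ratio >= refiner_switch_at

# With 10 steps and refiner_switch_at=0.8, the refiner now activates at step 8
# (ratio 0.8) instead of step 9 (ratio 0.9).
for step in range(1, 11):
    print(step, switches_to_refiner(step, 10, 0.8))
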
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 1f8e9c4b..0bacfe8d 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -22,6 +22,9 @@ samplers_k_diffusion = [
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
+ ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
+ ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
+ ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
@@ -42,6 +45,12 @@ sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_dpm_fast': ['s_noise'],
+ 'sample_dpm_2_ancestral': ['s_noise'],
+ 'sample_dpmpp_2s_ancestral': ['s_noise'],
+ 'sample_dpmpp_sde': ['s_noise'],
+ 'sample_dpmpp_2m_sde': ['s_noise'],
+ 'sample_dpmpp_3m_sde': ['s_noise'],
}
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
@@ -67,6 +76,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model, options=None):
super().__init__(funcname)
+ self.extra_params = sampler_extra_params.get(funcname, [])
+
self.options = options or {}
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
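
The sampler_extra_params additions above, together with the new `self.extra_params = sampler_extra_params.get(funcname, [])` lookup, mean each k-diffusion sampler only receives the sigma parameters listed for it. A small standalone sketch of that lookup, with the table trimmed to two entries for brevity:

sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpmpp_3m_sde': ['s_noise'],
}

def extra_params_for(funcname):
    # mirrors KDiffusionSampler.__init__ above: unlisted samplers get no extras
    return sampler_extra_params.get(funcname, [])

print(extra_params_for('sample_dpmpp_3m_sde'))  # ['s_noise']
print(extra_params_for('sample_dpmpp_2m'))      # []
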
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index 48d7e649..d32e3521 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
- alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
@@ -42,7 +42,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
- alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+ alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
extra_args = {} if extra_args is None else extra_args
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 9ae51f18..7f6c3658 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -285,12 +285,12 @@ options_templates.update(options_section(('ui', "Live previews"), {
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"),
- "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; applies to Euler a and other samplers that have a in them"),
+ "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
- 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
+ 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 497568eb..ae4ee4bb 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -58,7 +58,7 @@ def _summarize_chunk(
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
- torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+ torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
- torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+ torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 77885022..ebb5249f 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -19,6 +19,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
return {
"name": checkpoint.name_for_extra,
"filename": checkpoint.filename,
+ "shorthash": checkpoint.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 514a4562..4cedf085 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -2,6 +2,7 @@ import os
from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js
+from modules.hashes import sha256_from_cache
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
@@ -14,13 +15,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True):
full_path = shared.hypernetworks[name]
path, ext = os.path.splitext(full_path)
+ sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
+ shorthash = sha256[0:10] if sha256 else None
return {
"name": name,
"filename": full_path,
+ "shorthash": shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
- "search_term": self.search_terms_from_path(path),
+ "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index 73134698..55ef0ea7 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -19,9 +19,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
return {
"name": name,
"filename": embedding.filename,
+ "shorthash": embedding.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
- "search_term": self.search_terms_from_path(embedding.filename),
+ "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
"prompt": quote_js(embedding.name),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index cda471e4..b11622a1 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -93,11 +93,13 @@ class UserMetadataEditor:
item = self.page.items.get(name, {})
try:
filename = item["filename"]
+ shorthash = item.get("shorthash", None)
stats = os.stat(filename)
params = [
('Filename: ', os.path.basename(filename)),
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
+ ('Hash: ', shorthash),
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
]
@@ -115,7 +117,7 @@ class UserMetadataEditor:
errors.display(e, f"reading metadata info for {name}")
params = []
- table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
+ table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')