-rw-r--r--  .gitignore                                        1
-rw-r--r--  javascript/ui.js                                 25
-rw-r--r--  modules/api/api.py                                2
-rw-r--r--  modules/api/models.py                             3
-rw-r--r--  modules/hashes.py                                84
-rw-r--r--  modules/hypernetworks/hypernetwork.py            44
-rw-r--r--  modules/processing.py                             2
-rw-r--r--  modules/sd_models.py                            119
-rw-r--r--  modules/sd_samplers.py                           15
-rw-r--r--  modules/shared.py                                17
-rw-r--r--  modules/textual_inversion/textual_inversion.py    6
-rw-r--r--  modules/ui.py                                     2
-rw-r--r--  modules/ui_progress.py                            2
-rw-r--r--  script.js                                         4
-rw-r--r--  webui.py                                          2
15 files changed, 241 insertions(+), 87 deletions(-)
diff --git a/.gitignore b/.gitignore
index 21fa26a7..0b1d17ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,3 +32,4 @@ notification.mp3
/extensions
/test/stdout.txt
/test/stderr.txt
+/cache.json
diff --git a/javascript/ui.js b/javascript/ui.js
index a41dd26f..1e04a8f4 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -143,14 +143,6 @@ function confirm_clear_prompt(prompt, negative_prompt) {
opts = {}
-function apply_settings(jsdata){
- console.log(jsdata)
-
- opts = JSON.parse(jsdata)
-
- return jsdata
-}
-
onUiUpdate(function(){
if(Object.keys(opts).length != 0) return;
@@ -160,7 +152,7 @@ onUiUpdate(function(){
textarea = json_elem.querySelector('textarea')
jsdata = textarea.value
opts = JSON.parse(jsdata)
-
+ executeCallbacks(optionsChangedCallbacks);
Object.defineProperty(textarea, 'value', {
set: function(newValue) {
@@ -171,6 +163,8 @@ onUiUpdate(function(){
if (oldValue != newValue) {
opts = JSON.parse(textarea.value)
}
+
+ executeCallbacks(optionsChangedCallbacks);
},
get: function() {
var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
@@ -201,6 +195,19 @@ onUiUpdate(function(){
}
})
+
+onOptionsChanged(function(){
+ elem = gradioApp().getElementById('sd_checkpoint_hash')
+ sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
+ shorthash = sd_checkpoint_hash.substr(0,10)
+
+ if(elem && elem.textContent != shorthash){
+ elem.textContent = shorthash
+ elem.title = sd_checkpoint_hash
+ elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
+ }
+})
+
let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
diff --git a/modules/api/api.py b/modules/api/api.py
index 5767ba90..9814bbc2 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -371,7 +371,7 @@ class Api:
return upscalers
def get_sd_models(self):
- return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
+ return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
diff --git a/modules/api/models.py b/modules/api/models.py
index c78095ca..1eb1fcf1 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -224,7 +224,8 @@ class UpscalerItem(BaseModel):
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
- hash: str = Field(title="Hash")
+ hash: Optional[str] = Field(title="Short hash")
+ sha256: Optional[str] = Field(title="sha256 hash")
filename: str = Field(title="Filename")
config: str = Field(title="Config file")
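With the two fields above, each entry returned by the sd-models endpoint now carries both the legacy short hash and the full sha256. A sketch of what one entry might look like, with illustrative values:

    {
        "title": "v1-5-pruned-emaonly.safetensors",
        "model_name": "v1-5-pruned-emaonly",
        "hash": "6ce0161689",  # first 10 hex digits of the sha256; None until calculated
        "sha256": "6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa",
        "filename": "models/Stable-diffusion/v1-5-pruned-emaonly.safetensors",
        "config": "configs/v1-inference.yaml",
    }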
diff --git a/modules/hashes.py b/modules/hashes.py
new file mode 100644
index 00000000..14231771
--- /dev/null
+++ b/modules/hashes.py
@@ -0,0 +1,84 @@
+import hashlib
+import json
+import os.path
+
+import filelock
+
+
+cache_filename = "cache.json"
+cache_data = None
+
+
+def dump_cache():
+ with filelock.FileLock(cache_filename+".lock"):
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+
+def cache(subsection):
+ global cache_data
+
+ if cache_data is None:
+ with filelock.FileLock(cache_filename+".lock"):
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def calculate_sha256(filename):
+ hash_sha256 = hashlib.sha256()
+
+ with open(filename, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
+
+
+def sha256_from_cache(filename, title):
+ hashes = cache("hashes")
+ ondisk_mtime = os.path.getmtime(filename)
+
+ if title not in hashes:
+ return None
+
+ cached_sha256 = hashes[title].get("sha256", None)
+ cached_mtime = hashes[title].get("mtime", 0)
+
+ if ondisk_mtime > cached_mtime or cached_sha256 is None:
+ return None
+
+ return cached_sha256
+
+
+def sha256(filename, title):
+ hashes = cache("hashes")
+
+ sha256_value = sha256_from_cache(filename, title)
+ if sha256_value is not None:
+ return sha256_value
+
+ print(f"Calculating sha256 for {filename}: ", end='')
+ sha256_value = calculate_sha256(filename)
+ print(f"{sha256_value}")
+
+ hashes[title] = {
+ "mtime": os.path.getmtime(filename),
+ "sha256": sha256_value,
+ }
+
+ dump_cache()
+
+ return sha256_value
+
+
+
+
+
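The new module keeps a single cache.json guarded by a filelock; digests are keyed by a caller-chosen title and recomputed whenever the file's on-disk mtime is newer than the cached one. A minimal usage sketch (paths and titles illustrative):

    from modules import hashes

    # First call reads the file in 4 KB chunks, prints the digest, and writes cache.json.
    full = hashes.sha256("models/Stable-diffusion/model.safetensors", "checkpoint/model.safetensors")

    # Later calls are served from the cache as long as the file's mtime hasn't advanced.
    cached = hashes.sha256_from_cache("models/Stable-diffusion/model.safetensors", "checkpoint/model.safetensors")
    assert cached == full

    shorthash = full[0:10]  # the 10-character form used in the UI and infotexts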
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 83cbb4f0..3aebefa8 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -225,7 +225,7 @@ class Hypernetwork:
torch.save(state_dict, filename)
if shared.opts.save_optimizer_state and self.optimizer_state_dict:
- optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
+ optimizer_saved_dict['hash'] = self.shorthash()
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
torch.save(optimizer_saved_dict, filename + '.optim')
@@ -237,32 +237,33 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
- print(self.layer_structure)
- optional_info = state_dict.get('optional_info', None)
- if optional_info is not None:
- print(f"INFO:\n {optional_info}\n")
- self.optional_info = optional_info
+ self.optional_info = state_dict.get('optional_info', None)
self.activation_func = state_dict.get('activation_func', None)
- print(f"Activation function is {self.activation_func}")
self.weight_init = state_dict.get('weight_initialization', 'Normal')
- print(f"Weight initialization is {self.weight_init}")
self.add_layer_norm = state_dict.get('is_layer_norm', False)
- print(f"Layer norm is set to {self.add_layer_norm}")
self.dropout_structure = state_dict.get('dropout_structure', None)
self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
- print(f"Dropout usage is set to {self.use_dropout}" )
self.activate_output = state_dict.get('activate_output', True)
- print(f"Activate last layer is set to {self.activate_output}")
self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
# Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.
if self.dropout_structure is None:
- print("Using previous dropout structure")
self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
- print(f"Dropout structure is set to {self.dropout_structure}")
-    optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
-    if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
+    if shared.opts.print_hypernet_extra:
+        if self.optional_info is not None:
+            print(f"  INFO:\n {self.optional_info}\n")
+
+        print(f"  Layer structure: {self.layer_structure}")
+        print(f"  Activation function: {self.activation_func}")
+        print(f"  Weight initialization: {self.weight_init}")
+        print(f"  Layer norm: {self.add_layer_norm}")
+        print(f"  Dropout usage: {self.use_dropout}")
+        print(f"  Activate last layer: {self.activate_output}")
+        print(f"  Dropout structure: {self.dropout_structure}")
+
+    optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
+
+    if self.shorthash() == optimizer_saved_dict.get('hash', None):
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
else:
self.optimizer_state_dict = None
@@ -289,6 +290,11 @@ class Hypernetwork:
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
self.eval()
+ def shorthash(self):
+ sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
+
+ return sha256[0:10]
+
def list_hypernetworks(path):
res = {}
@@ -296,7 +302,7 @@ def list_hypernetworks(path):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
- res[name + f"({sd_models.model_hash(filename)})"] = filename
+ res[name] = filename
return res
@@ -509,7 +515,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
- model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
**{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
)
logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
@@ -737,7 +743,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
try:
- hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint = checkpoint.shorthash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
hypernetwork.name = hypernetwork_name
hypernetwork.save(filename)
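The new Hypernetwork.shorthash() resolves through the same hashes cache, keyed as hypernet/<name>, and the .optim state above is only reused when its stored 'hash' matches it. Roughly equivalent to (name and path illustrative):

    from modules import hashes

    # What Hypernetwork.shorthash() computes for a hypernet named "anime"
    # loaded from models/hypernetworks/anime.pt:
    shorthash = hashes.sha256("models/hypernetworks/anime.pt", "hypernet/anime")[0:10]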
diff --git a/modules/processing.py b/modules/processing.py
index ae04cab7..849f6b19 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
+ "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
diff --git a/modules/sd_models.py b/modules/sd_models.py
index c466f273..e5a0bc63 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,17 +14,58 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
+from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
+checkpoint_aliases = {}
checkpoints_loaded = collections.OrderedDict()
+
+class CheckpointInfo:
+ def __init__(self, filename):
+ self.filename = filename
+ abspath = os.path.abspath(filename)
+
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ elif abspath.startswith(model_path):
+ name = abspath.replace(model_path, '')
+ else:
+ name = os.path.basename(filename)
+
+ if name.startswith("\\") or name.startswith("/"):
+ name = name[1:]
+
+ self.title = name
+ self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
+ self.hash = model_hash(filename)
+
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
+
+ self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
+
+ def register(self):
+ checkpoints_list[self.title] = self
+ for id in self.ids:
+        checkpoint_aliases[id] = self
+
+ def calculate_shorthash(self):
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
+ self.shorthash = self.sha256[0:10]
+
+ if self.shorthash not in self.ids:
+ self.ids += [self.shorthash, self.sha256]
+ self.register()
+
+ return self.shorthash
+
+
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -43,10 +84,14 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- convert = lambda name: int(name) if name.isdigit() else name.lower()
- alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
- return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def checkpoint_tiles():
+ def convert(name):
+ return int(name) if name.isdigit() else name.lower()
+
+ def alphanumeric_key(key):
+ return [convert(c) for c in re.split('([0-9]+)', key)]
+
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def find_checkpoint_config(info):
@@ -62,48 +107,38 @@ def find_checkpoint_config(info):
def list_models():
checkpoints_list.clear()
+    checkpoint_aliases.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
- def modeltitle(path, shorthash):
- abspath = os.path.abspath(path)
-
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
- elif abspath.startswith(model_path):
- name = abspath.replace(model_path, '')
- else:
- name = os.path.basename(path)
-
- if name.startswith("\\") or name.startswith("/"):
- name = name[1:]
-
- shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
-
- return f'{name} [{shorthash}]', shortname
-
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
- h = model_hash(cmd_ckpt)
- title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
- shared.opts.data['sd_model_checkpoint'] = title
+ checkpoint_info = CheckpointInfo(cmd_ckpt)
+ checkpoint_info.register()
+
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
+
for filename in model_list:
-    h = model_hash(filename)
-    title, short_model_name = modeltitle(filename, h)
-    checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+    checkpoint_info = CheckpointInfo(filename)
+    checkpoint_info.register()

-def get_closet_checkpoint_match(searchString):
-    applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
-    if len(applicable) > 0:
-        return applicable[0]
+def get_closet_checkpoint_match(search_string):
+    checkpoint_info = checkpoint_aliases.get(search_string, None)
+    if checkpoint_info is not None:
+        return checkpoint_info
+
+    found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
+    if found:
+        return found[0]
+
return None
def model_hash(filename):
+ """old hash that only looks at a small part of the file and is prone to collisions"""
+
try:
with open(filename, "rb") as file:
import hashlib
@@ -119,7 +154,7 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
- checkpoint_info = checkpoints_list.get(model_checkpoint, None)
+    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -189,9 +224,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
-def load_model_weights(model, checkpoint_info, vae_file="auto"):
- checkpoint_file = checkpoint_info.filename
- sd_model_hash = checkpoint_info.hash
+def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
+ sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -201,9 +235,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
- sd = read_state_dict(checkpoint_file)
+ sd = read_state_dict(checkpoint_info.filename)
model.load_state_dict(sd, strict=False)
del sd
@@ -235,14 +269,15 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
- model.sd_model_checkpoint = checkpoint_file
+ model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+ vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
sd_vae.load_vae(model, vae_file)
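Taken together, a checkpoint is now reachable by any of its registered ids: the old-style hash, the model name, the title, the name [hash] form, and, once calculate_shorthash() has run, the sha256 and its 10-character prefix. A sketch of the lookup behaviour, with illustrative names:

    from modules import sd_models

    sd_models.list_models()

    # All of these resolve to the same CheckpointInfo via checkpoint_aliases:
    by_name = sd_models.get_closet_checkpoint_match("v1-5-pruned-emaonly")              # model_name
    by_title = sd_models.get_closet_checkpoint_match("v1-5-pruned-emaonly.safetensors") # title
    by_hash = sd_models.get_closet_checkpoint_match("6ce0161689")                       # sha256 shorthash, once calculated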
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 01221b89..7616fded 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -138,7 +138,7 @@ def samples_to_image_grid(samples, approximation=None):
def store_latent(decoded):
state.current_latent = decoded
- if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
shared.state.current_image = sample_to_image(decoded)
@@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler:
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
@@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler:
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
-
+
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
@@ -352,6 +351,11 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+ if opts.live_preview_content == "Prompt":
+ store_latent(x_out[0:uncond.shape[0]])
+ elif opts.live_preview_content == "Negative prompt":
+ store_latent(x_out[-uncond.shape[0]:])
+
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
@@ -423,7 +427,8 @@ class KDiffusionSampler:
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
- store_latent(latent)
+ if opts.live_preview_content == "Combined":
+ store_latent(latent)
self.last_latent = latent
if self.stop_at is not None and step > self.stop_at:
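The preview routing added above has three modes; the sketch below paraphrases it with stand-in tensors (names and shapes illustrative, not code from the tree). Conditioned (prompt) latents lead the denoiser batch, unconditional ones sit at the end:

    import numpy as np  # stand-in for torch tensors, illustration only

    def select_preview(x_out, n_uncond, mode):
        """Mirror of the live_preview_content routing in CFGDenoiser above."""
        if mode == "Prompt":
            return x_out[0:n_uncond]   # conditioned (prompt) slice of the batch
        elif mode == "Negative prompt":
            return x_out[-n_uncond:]   # unconditional (negative prompt) slice
        return None                    # "Combined" comes from d["denoised"] in callback_state instead

    batch = np.zeros((2, 4, 64, 64))   # 1 cond + 1 uncond latent
    preview = select_preview(batch, 1, "Prompt")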
diff --git a/modules/shared.py b/modules/shared.py
index b90ded52..e0ec3136 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -176,7 +176,7 @@ class State:
self.interrupted = True
def nextjob(self):
- if opts.show_progress_every_n_steps == -1:
+ if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
@@ -224,7 +224,7 @@ class State:
if not parallel_processing_allowed:
return
- if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable:
self.do_set_current_image()
def do_set_current_image(self):
@@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), {
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
+ "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -422,13 +423,11 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
- "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
- "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
+ "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -443,6 +442,13 @@ options_templates.update(options_section(('ui', "User interface"), {
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
+options_templates.update(options_section(('ui', "Live previews"), {
+ "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
+ "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
+}))
+
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
@@ -457,6 +463,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
+ "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
options_templates.update()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6939efcc..63935878 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -407,7 +407,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
- save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
latent_sampling_method = ds.latent_sampling_method
@@ -584,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_mid = '[{}]'.format(checkpoint.shorthash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
@@ -626,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
try:
- embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint = checkpoint.shorthash
embedding.sd_checkpoint_name = checkpoint.model_name
if remove_cached_checksum:
embedding.cached_checksum = None
diff --git a/modules/ui.py b/modules/ui.py
index e86a624b..2625ae32 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1841,4 +1841,6 @@ xformers: {xformers_version}
gradio: {gr.__version__}
 • 
commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
+ • 
+checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
diff --git a/modules/ui_progress.py b/modules/ui_progress.py
index 592fda55..7cd312e4 100644
--- a/modules/ui_progress.py
+++ b/modules/ui_progress.py
@@ -52,7 +52,7 @@ def check_progress_call(id_part):
image = gr.update(visible=False)
preview_visibility = gr.update(visible=False)
- if opts.show_progress_every_n_steps != 0:
+ if opts.live_previews_enable:
shared.state.set_current_image()
image = shared.state.current_image
diff --git a/script.js b/script.js
index 21960d91..3345e32b 100644
--- a/script.js
+++ b/script.js
@@ -14,6 +14,7 @@ function get_uiCurrentTabContent() {
uiUpdateCallbacks = []
uiTabChangeCallbacks = []
+optionsChangedCallbacks = []
let uiCurrentTab = null
function onUiUpdate(callback){
@@ -22,6 +23,9 @@ function onUiUpdate(callback){
function onUiTabChange(callback){
uiTabChangeCallbacks.push(callback)
}
+function onOptionsChanged(callback){
+ optionsChangedCallbacks.push(callback)
+}
function runCallback(x, m){
try {
diff --git a/webui.py b/webui.py
index 47d372c7..1fff80da 100644
--- a/webui.py
+++ b/webui.py
@@ -78,6 +78,8 @@ def initialize():
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
+ shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
+
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)