Diffstat (limited to 'modules')
 modules/api/api.py                             |  16
 modules/api/models.py                          |   5
 modules/devices.py                             |  28
 modules/hypernetworks/hypernetwork.py          |  11
 modules/processing.py                          |  45
 modules/script_callbacks.py                    |  18
 modules/scripts.py                             |  11
 modules/sd_hijack.py                           |   6
 modules/sd_hijack_clip.py                      | 363
 modules/sd_hijack_clip_old.py                  |  81
 modules/sd_samplers.py                         |   9
 modules/shared.py                              |   5
 modules/textual_inversion/logging.py           |  24
 modules/textual_inversion/textual_inversion.py |  33
 modules/ui.py                                  |  90
 modules/ui_components.py                       |   8
 modules/ui_extensions.py                       |  44
 17 files changed, 545 insertions(+), 252 deletions(-)
diff --git a/modules/api/api.py b/modules/api/api.py
index 48a70a44..2103709b 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -11,10 +11,10 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack
+from modules import sd_samplers, deepbooru, sd_hijack, images
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.extras import run_extras, run_pnginfo
+from modules.extras import run_extras
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
@@ -233,9 +233,17 @@ class Api:
if(not req.image.strip()):
return PNGInfoResponse(info="")
- result = run_pnginfo(decode_base64_to_image(req.image.strip()))
+ image = decode_base64_to_image(req.image.strip())
+ if image is None:
+ return PNGInfoResponse(info="")
+
+ geninfo, items = images.read_info_from_image(image)
+ if geninfo is None:
+ geninfo = ""
+
+ items = {**{'parameters': geninfo}, **items}
- return PNGInfoResponse(info=result[1])
+ return PNGInfoResponse(info=geninfo, items=items)
def progressapi(self, req: ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
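The reworked handler now returns both the raw parameters string and a dict of every metadata item found in the image. A minimal client sketch, assuming a local webui instance with the API enabled on the default port, the route /sdapi/v1/png-info, and a hypothetical sample.png on disk:

import base64
import requests

with open("sample.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

# POST the base64-encoded image; the response now carries both fields.
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/png-info", json={"image": encoded})
data = resp.json()
print(data["info"])          # generation parameters string ("" if nothing is embedded)
print(data["items"].keys())  # all metadata items, including 'parameters'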
diff --git a/modules/api/models.py b/modules/api/models.py
index 4a632c68..d8198a27 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -157,7 +157,8 @@ class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel):
- info: str = Field(title="Image info", description="A string with all the info the image had")
+ info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
+ items: dict = Field(title="Items", description="An object containing all the info the image had")
class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
@@ -258,4 +259,4 @@ class EmbeddingItem(BaseModel):
class EmbeddingsResponse(BaseModel):
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
- skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
diff --git a/modules/devices.py b/modules/devices.py
index 800510b7..caeb0276 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs):
return orig_tensor_numpy(self, *args, **kwargs)
-# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
- torch.Tensor.to = tensor_to_fix
- torch.nn.functional.layer_norm = layer_norm_fix
- torch.Tensor.numpy = numpy_fix
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+orig_cumsum = torch.cumsum
+orig_Tensor_cumsum = torch.Tensor.cumsum
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+ if input.device.type == 'mps':
+ output_dtype = kwargs.get('dtype', input.dtype)
+ if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
+ return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+ return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps():
+ if version.parse(torch.__version__) < version.parse("1.13"):
+ # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+ torch.Tensor.to = tensor_to_fix
+ torch.nn.functional.layer_norm = layer_norm_fix
+ torch.Tensor.numpy = numpy_fix
+ elif version.parse(torch.__version__) > version.parse("1.13.1"):
+ if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
+ torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
+ torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+ orig_narrow = torch.narrow
+ torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
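The workaround only reroutes cumsum calls whose output dtype MPS currently mishandles; everything else passes straight through. A small standalone sketch of the same dispatch pattern (on a non-MPS device it is a plain pass-through):

import torch

BROKEN_DTYPES = (torch.bool, torch.int8, torch.int16, torch.int64)

def cumsum_fix(input, cumsum_func, *args, **kwargs):
    # Move the tensor to CPU for the affected dtypes, then bring the result back.
    if input.device.type == 'mps':
        output_dtype = kwargs.get('dtype', input.dtype)
        if output_dtype in BROKEN_DTYPES:
            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
    return cumsum_func(input, *args, **kwargs)

x = torch.tensor([1, 1, 1])
print(cumsum_fix(x, torch.cumsum, 0))  # tensor([1, 2, 3])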
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6a9b1398..b0cfbe71 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -13,7 +13,7 @@ import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers
-from modules.textual_inversion import textual_inversion
+from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
@@ -457,7 +457,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
-
+
+ if shared.opts.save_training_settings_to_txt:
+ saved_params = dict(
+ model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
+ )
+ logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
+
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
diff --git a/modules/processing.py b/modules/processing.py
index 7e853287..82157bc9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -50,9 +50,9 @@ def apply_color_correction(correction, original_image):
correction,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
-
+
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
-
+
return image
@@ -466,9 +466,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
- if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
+ if k == 'sd_hypernetwork':
+ shared.reload_hypernetworks() # make onchange call for changing hypernet
+
+ if k == 'sd_model_checkpoint':
+ sd_models.reload_model_weights() # make onchange call for changing SD model
+ p.sd_model = shared.sd_model
+
+ if k == 'sd_vae':
+ sd_vae.reload_vae_weights() # make onchange call for changing VAE
res = process_images_inner(p)
@@ -538,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
+ cached_uc = [None, None]
+ cached_c = [None, None]
+
+ def get_conds_with_caching(function, required_prompts, steps, cache):
+ """
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
+ using a cache to store the result if the same arguments have been used before.
+
+ cache is an array containing two elements. The first element is a tuple
+ representing the previously used arguments, or None if no arguments
+ have been used before. The second element is where the previously
+ computed result is stored.
+ """
+
+ if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ return cache[1]
+
+ with devices.autocast():
+ cache[1] = function(shared.sd_model, required_prompts, steps)
+
+ cache[0] = (required_prompts, steps)
+ return cache[1]
+
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -565,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
- with devices.autocast():
- uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
- c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
+ uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
+ c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -683,7 +711,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = 0
self.truncate_y = 0
-
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
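get_conds_with_caching is a one-slot memo keyed on (required_prompts, steps), so repeated batches with identical prompts skip the conditioning pass. The same pattern in isolation, with a hypothetical expensive_fn standing in for the conditioning functions:

def get_with_caching(function, key_args, cache):
    """cache is a two-element list: cache[0] holds the last key, cache[1] the last result."""
    if cache[0] is not None and key_args == cache[0]:
        return cache[1]
    cache[1] = function(*key_args)
    cache[0] = key_args
    return cache[1]

calls = []
def expensive_fn(prompts, steps):  # hypothetical stand-in for get_learned_conditioning
    calls.append(1)
    return f"conds for {prompts} @ {steps}"

cache = [None, None]
get_with_caching(expensive_fn, (("a cat",), 20), cache)
get_with_caching(expensive_fn, (("a cat",), 20), cache)  # served from the cache
print(len(calls))  # 1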
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index de69fd9f..608c5300 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -71,6 +71,7 @@ callback_map = dict(
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
+ callbacks_script_unloaded=[],
)
@@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid')
+def script_unloaded_callback():
+ for c in reversed(callback_map['callbacks_script_unloaded']):
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'script_unloaded')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -202,7 +211,7 @@ def on_app_started(callback):
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
- passed as an argument"""
+ passed as an argument; this function is also called when the script is reloaded. """
add_callback(callback_map['callbacks_model_loaded'], callback)
@@ -279,3 +288,10 @@ def on_image_grid(callback):
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
"""
add_callback(callback_map['callbacks_image_grid'], callback)
+
+
+def on_script_unloaded(callback):
+ """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
+ the script did should be reverted here"""
+
+ add_callback(callback_map['callbacks_script_unloaded'], callback)
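A hedged sketch of how an extension script could use the new hook to undo its own monkey-patching when the webui reloads scripts; the patched function and names here are purely illustrative:

# inside an extension's scripts/*.py file
import torch.nn.functional as F
from modules import script_callbacks

original_gelu = F.gelu  # illustrative: whatever the extension replaced at load time

def restore_patches():
    # revert any hooks/hijacks the extension installed
    F.gelu = original_gelu

script_callbacks.on_script_unloaded(restore_patches)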
diff --git a/modules/scripts.py b/modules/scripts.py
index 722f8685..35164093 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,4 +1,5 @@
import os
+import re
import sys
import traceback
from collections import namedtuple
@@ -128,6 +129,15 @@ class Script:
"""unused"""
return ""
+ def elem_id(self, item_id):
+ """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
+
+ need_tabname = self.show(True) == self.show(False)
+ tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+ title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
+
+ return f'script_{tabname}{title}_{item_id}'
+
current_basedir = paths.script_path
@@ -280,7 +290,6 @@ class ScriptRunner:
script.group = group
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
- dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
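The new elem_id() helper is meant to be called from a script's ui() method so every control gets a stable, per-tab HTML id. An illustrative sketch (the Script subclass and its labels are hypothetical):

import gradio as gr
from modules import scripts

class ExampleScript(scripts.Script):
    def title(self):
        return "Example Script"

    def show(self, is_img2img):
        return True  # shown on both tabs, so elem_id() prefixes the tab name

    def ui(self, is_img2img):
        # yields ids like "script_txt2txt_example_script_strength" on the txt2img tab
        strength = gr.Slider(0.0, 1.0, value=0.5, label="Strength",
                             elem_id=self.elem_id("strength"))
        return [strength]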
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index bd101e5b..cfdb09d6 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -147,10 +147,10 @@ class StableDiffusionModelHijack:
def clear_comments(self):
self.comments = []
- def tokenize(self, text):
- _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+ def get_prompt_lengths(self, text):
+ _, token_count = self.clip.process_texts([text])
- return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module):
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index ca92b142..5520c9b2 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -1,30 +1,89 @@
import math
+from collections import namedtuple
import torch
-from modules import prompt_parser, devices
+from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts
-def get_target_prompt_token_count(token_count):
- return math.ceil(max(token_count, 1) / 75) * 75
+
+class PromptChunk:
+ """
+ This object contains token ids, weights (multipliers such as (word:1.4)) and textual inversion embedding info for a chunk of prompt.
+ If a prompt is short, it is represented by one PromptChunk; otherwise, multiple chunks are necessary.
+ Each PromptChunk contains exactly 77 tokens, which includes the start and end tokens,
+ so just 75 tokens come from the prompt.
+ """
+
+ def __init__(self):
+ self.tokens = []
+ self.multipliers = []
+ self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
+chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
+are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+ """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
+ have unlimited prompt length and assign weights to tokens in prompt.
+ """
+
def __init__(self, wrapped, hijack):
super().__init__()
+
self.wrapped = wrapped
- self.hijack = hijack
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+ depending on model."""
+
+ self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
+ self.chunk_length = 75
+
+ def empty_chunk(self):
+ """creates an empty PromptChunk and returns it"""
+
+ chunk = PromptChunk()
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
+ return chunk
+
+ def get_target_prompt_token_count(self, token_count):
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
def tokenize(self, texts):
+ """Converts a batch of texts into a batch of token ids"""
+
raise NotImplementedError
def encode_with_transformers(self, tokens):
+ """
+ Converts a batch of token ids (in python lists) into a single tensor with the numeric representation of those tokens;
+ all python lists with tokens are assumed to have the same length, usually 77.
+ If the input is a list with B elements and each element has T tokens, the expected output shape is (B, T, C), where C depends on
+ the model - it can be 768 or 1024.
+ Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
+ """
+
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+ transformers. nvpt is used as a maximum length in tokens. If the text produces fewer tokens than nvpt, only that many are returned."""
+
raise NotImplementedError
- def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ def tokenize_line(self, line):
+ """
+ This transforms a single prompt into a list of PromptChunk objects - as many as needed to
+ represent the prompt.
+ Returns the list and the total number of tokens in the prompt.
+ """
+
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
@@ -32,205 +91,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
tokenized = self.tokenize([text for text, _ in parsed])
- fixes = []
- remade_tokens = []
- multipliers = []
+ chunks = []
+ chunk = PromptChunk()
+ token_count = 0
last_comma = -1
- for tokens, (text, weight) in zip(tokenized, parsed):
- i = 0
- while i < len(tokens):
- token = tokens[i]
+ def next_chunk():
+ """puts current chunk into the list of results and produces the next one - empty"""
+ nonlocal token_count
+ nonlocal last_comma
+ nonlocal chunk
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ token_count += len(chunk.tokens)
+ to_add = self.chunk_length - len(chunk.tokens)
+ if to_add > 0:
+ chunk.tokens += [self.id_end] * to_add
+ chunk.multipliers += [1.0] * to_add
+
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+ last_comma = -1
+ chunks.append(chunk)
+ chunk = PromptChunk()
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ position = 0
+ while position < len(tokens):
+ token = tokens[position]
if token == self.comma_token:
- last_comma = len(remade_tokens)
- elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
- last_comma += 1
- reloc_tokens = remade_tokens[last_comma:]
- reloc_mults = multipliers[last_comma:]
+ last_comma = len(chunk.tokens)
+
+ # this is when we are at the end of the allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+ break_location = last_comma + 1
- remade_tokens = remade_tokens[:last_comma]
- length = len(remade_tokens)
+ reloc_tokens = chunk.tokens[break_location:]
+ reloc_mults = chunk.multipliers[break_location:]
- rem = int(math.ceil(length / 75)) * 75 - length
- remade_tokens += [self.id_end] * rem + reloc_tokens
- multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+ chunk.tokens = chunk.tokens[:break_location]
+ chunk.multipliers = chunk.multipliers[:break_location]
+ next_chunk()
+ chunk.tokens = reloc_tokens
+ chunk.multipliers = reloc_mults
+
+ if len(chunk.tokens) == self.chunk_length:
+ next_chunk()
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
if embedding is None:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- iteration = len(remade_tokens) // 75
- if (len(remade_tokens) + emb_len) // 75 != iteration:
- rem = (75 * (iteration + 1) - len(remade_tokens))
- remade_tokens += [self.id_end] * rem
- multipliers += [1.0] * rem
- iteration += 1
- fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- token_count = len(remade_tokens)
- prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens)
-
- remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
- multipliers = multipliers + [1.0] * tokens_to_add
-
- return remade_tokens, fixes, multipliers, token_count
-
- def process_text(self, texts):
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
+ chunk.tokens.append(token)
+ chunk.multipliers.append(weight)
+ position += 1
+ continue
+
+ emb_len = int(embedding.vec.shape[0])
+ if len(chunk.tokens) + emb_len > self.chunk_length:
+ next_chunk()
+
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+ chunk.tokens += [0] * emb_len
+ chunk.multipliers += [weight] * emb_len
+ position += embedding_length_in_tokens
+
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
+ next_chunk()
+
+ return chunks, token_count
+
+ def process_texts(self, texts):
+ """
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+ length, in tokens, of all texts.
+ """
+
token_count = 0
cache = {}
- batch_multipliers = []
+ batch_chunks = []
for line in texts:
if line in cache:
- remade_tokens, fixes, multipliers = cache[line]
+ chunks = cache[line]
else:
- remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ chunks, current_token_count = self.tokenize_line(line)
token_count = max(current_token_count, token_count)
- cache[line] = (remade_tokens, fixes, multipliers)
+ cache[line] = chunks
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
+ batch_chunks.append(chunks)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+ return batch_chunks, token_count
- def process_text_old(self, texts):
- id_start = self.id_start
- id_end = self.id_end
- maxlen = self.wrapped.max_length # you get to stay at 77
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
+ def forward(self, texts):
+ """
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ An example shape returned by this function can be: (2, 77, 768).
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+ """
- cache = {}
- batch_tokens = self.tokenize(texts)
- batch_multipliers = []
- for tokens in batch_tokens:
- tuple_tokens = tuple(tokens)
+ if opts.use_old_emphasis_implementation:
+ import modules.sd_hijack_clip_old
+ return modules.sd_hijack_clip_old.forward_old(self, texts)
- if tuple_tokens in cache:
- remade_tokens, fixes, multipliers = cache[tuple_tokens]
- else:
- fixes = []
- remade_tokens = []
- multipliers = []
- mult = 1.0
-
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
- if mult_change is not None:
- mult *= mult_change
- i += 1
- elif embedding is None:
- remade_tokens.append(token)
- multipliers.append(mult)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
- token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
- def forward(self, text):
- use_old = opts.use_old_emphasis_implementation
- if use_old:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
- else:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
- self.hijack.comments += hijack_comments
-
- if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
- if use_old:
- self.hijack.fixes = hijack_fixes
- return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
- z = None
- i = 0
- while max(map(len, remade_batch_tokens)) != 0:
- rem_tokens = [x[75:] for x in remade_batch_tokens]
- rem_multipliers = [x[75:] for x in batch_multipliers]
-
- self.hijack.fixes = []
- for unfiltered in hijack_fixes:
- fixes = []
- for fix in unfiltered:
- if fix[0] == i:
- fixes.append(fix[1])
- self.hijack.fixes.append(fixes)
-
- tokens = []
- multipliers = []
- for j in range(len(remade_batch_tokens)):
- if len(remade_batch_tokens[j]) > 0:
- tokens.append(remade_batch_tokens[j][:75])
- multipliers.append(batch_multipliers[j][:75])
- else:
- tokens.append([self.id_end] * 75)
- multipliers.append([1.0] * 75)
-
- z1 = self.process_tokens(tokens, multipliers)
- z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
- remade_batch_tokens = rem_tokens
- batch_multipliers = rem_multipliers
- i += 1
+ batch_chunks, token_count = self.process_texts(texts)
- return z
+ used_embeddings = {}
+ chunk_count = max([len(x) for x in batch_chunks])
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
- batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+ zs = []
+ for i in range(chunk_count):
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+ tokens = [x.tokens for x in batch_chunk]
+ multipliers = [x.multipliers for x in batch_chunk]
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
+ for fixes in self.hijack.fixes:
+ for position, embedding in fixes:
+ used_embeddings[embedding.name] = embedding
+
+ z = self.process_tokens(tokens, multipliers)
+ zs.append(z)
+
+ if len(used_embeddings) > 0:
+ embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+ return torch.hstack(zs)
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ """
+ sends one single prompt chunk to be encoded by transformers neural network.
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
+ corresponds to one token.
+ """
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
@@ -239,8 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
- batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
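Each chunk is encoded to a (B, 77, C) tensor and the chunks are concatenated along the token axis, so a two-chunk prompt produces (B, 154, C). A minimal shape check of that final hstack step:

import torch

B, T, C = 2, 77, 768            # batch, tokens per chunk, embedding width (SD1)
zs = [torch.zeros(B, T, C) for _ in range(2)]  # two encoded chunks

z = torch.hstack(zs)            # same as torch.cat(zs, dim=1) for 3-D tensors
print(z.shape)                  # torch.Size([2, 154, 768])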
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
new file mode 100644
index 00000000..6d9fbbe6
--- /dev/null
+++ b/modules/sd_hijack_clip_old.py
@@ -0,0 +1,81 @@
+from modules import sd_hijack_clip
+from modules import shared
+
+
+def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ id_start = self.id_start
+ id_end = self.id_end
+ maxlen = self.wrapped.max_length # you get to stay at 77
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_tokens = self.tokenize(texts)
+ batch_multipliers = []
+ for tokens in batch_tokens:
+ tuple_tokens = tuple(tokens)
+
+ if tuple_tokens in cache:
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
+ else:
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+ mult = 1.0
+
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+ if mult_change is not None:
+ mult *= mult_change
+ i += 1
+ elif embedding is None:
+ remade_tokens.append(token)
+ multipliers.append(mult)
+ i += 1
+ else:
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
+
+ self.hijack.comments += hijack_comments
+
+ if len(used_custom_terms) > 0:
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 3851a77f..01221b89 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -463,6 +463,13 @@ class KDiffusionSampler:
return extra_params_kwargs
def get_sigmas(self, p, steps):
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+ discard_next_to_last_sigma = True
+ p.extra_generation_params["Discard penultimate sigma"] = True
+
+ steps += 1 if discard_next_to_last_sigma else 0
+
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
@@ -472,7 +479,7 @@ class KDiffusionSampler:
else:
sigmas = self.model_wrap.get_sigmas(steps)
- if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
+ if discard_next_to_last_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
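Discarding the next-to-last sigma removes the penultimate noise level while keeping the final one; because one extra step is requested up front, the number of actual sampling steps stays the same. A small sketch of that tensor surgery with made-up sigma values:

import torch

sigmas = torch.tensor([14.6, 7.9, 3.5, 1.2, 0.4, 0.0])  # illustrative schedule

trimmed = torch.cat([sigmas[:-2], sigmas[-1:]])
print(trimmed)  # tensor([14.6000,  7.9000,  3.5000,  1.2000,  0.0000])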
diff --git a/modules/shared.py b/modules/shared.py
index d7a81db1..a6712dae 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -366,6 +366,7 @@ options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
+ "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -433,7 +434,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
- 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"),
+ 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
@@ -446,6 +447,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
+ 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
}))
options_templates.update(options_section((None, "Hidden options"), {
@@ -579,6 +581,7 @@ latent_upscale_modes = {
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
"Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
"Latent (nearest)": {"mode": "nearest", "antialias": False},
+ "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
}
sd_upscalers = []
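The new "Latent (nearest-exact)" entry exposes PyTorch's nearest-exact interpolation mode (available since torch 1.11), which the hires fix presumably feeds to torch.nn.functional.interpolate when upscaling latents. A minimal sketch of that call with an illustrative latent shape:

import torch
import torch.nn.functional as F

latent = torch.randn(1, 4, 64, 64)   # illustrative SD latent
up = F.interpolate(latent, scale_factor=2, mode="nearest-exact")
print(up.shape)                       # torch.Size([1, 4, 128, 128])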
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
new file mode 100644
index 00000000..8b1981d5
--- /dev/null
+++ b/modules/textual_inversion/logging.py
@@ -0,0 +1,24 @@
+import datetime
+import json
+import os
+
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
+saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
+saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+
+
+def save_settings_to_file(log_directory, all_params):
+ now = datetime.datetime.now()
+ params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
+
+ keys = saved_params_all
+ if all_params.get('preview_from_txt2img'):
+ keys = keys | saved_params_previews
+
+ params.update({k: v for k, v in all_params.items() if k in keys})
+
+ filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
+ with open(os.path.join(log_directory, filename), "w") as file:
+ json.dump(params, file, indent=4)
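The trainers pass locals() into this function, and only keys whitelisted in the saved_params_* sets end up in the JSON. A hedged usage sketch with purely illustrative values:

import os
import tempfile
from modules.textual_inversion.logging import save_settings_to_file

log_directory = tempfile.mkdtemp()
save_settings_to_file(log_directory, {
    "model_name": "v1-5-pruned-emaonly",   # illustrative values
    "learn_rate": 0.005,
    "batch_size": 1,
    "steps": 10000,
    "unrelated_local": object(),            # ignored: not in the whitelist
})
print(os.listdir(log_directory))  # ['settings-....json']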
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 71e07bcc..45882ed6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,6 +1,7 @@
import os
import sys
import traceback
+import inspect
import torch
import tqdm
@@ -17,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
caption_image_overlay)
+from modules.textual_inversion.logging import save_settings_to_file
+
class Embedding:
def __init__(self, vec, name, step=None):
@@ -76,7 +79,6 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding
- # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
@@ -149,18 +151,19 @@ class EmbeddingDatabase:
else:
self.skipped_embeddings[name] = embedding
- for fn in os.listdir(self.embeddings_dir):
- try:
- fullfn = os.path.join(self.embeddings_dir, fn)
+ for root, dirs, fns in os.walk(self.embeddings_dir):
+ for fn in fns:
+ try:
+ fullfn = os.path.join(root, fn)
- if os.stat(fullfn).st_size == 0:
- continue
+ if os.stat(fullfn).st_size == 0:
+ continue
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
@@ -229,6 +232,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
@@ -292,13 +296,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
None
if clip_grad:
- clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -307,6 +311,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ if shared.opts.save_training_settings_to_txt:
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
diff --git a/modules/ui.py b/modules/ui.py
index 04091e67..6c765262 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
-from modules.ui_components import FormRow, FormGroup, ToolButton
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -256,6 +256,20 @@ def add_style(name: str, prompt: str, negative_prompt: str):
return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+ from modules import processing, devices
+
+ if not enable:
+ return ""
+
+ p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+
+ with devices.autocast():
+ p.init([""], [0], [0])
+
+ return f"resize to: <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
+
+
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
@@ -368,7 +382,7 @@ def update_token_counter(text, steps):
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
- tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
@@ -435,11 +449,9 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=1, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
- prompt_style.save_to_config = True
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
- prompt_style2.save_to_config = True
return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
@@ -550,6 +562,8 @@ Requested path was: {f}
os.startfile(path)
elif platform.system() == "Darwin":
sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
else:
sp.Popen(["xdg-open", path])
@@ -560,7 +574,7 @@ Requested path was: {f}
generation_info = None
with gr.Column():
with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder')
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}')
@@ -576,13 +590,13 @@ Requested path was: {f}
if tabname != "extras":
with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group():
- html_info = gr.HTML()
- html_log = gr.HTML()
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
- generation_info = gr.Textbox(visible=False)
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
@@ -624,9 +638,9 @@ Requested path was: {f}
)
else:
- html_info_x = gr.HTML()
- html_info = gr.HTML()
- html_log = gr.HTML()
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
@@ -636,7 +650,6 @@ def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
- sampler_index.save_to_config = True
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
@@ -707,6 +720,7 @@ def create_ui():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
@@ -730,6 +744,17 @@ def create_ui():
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+ hr_resolution_preview_args = dict(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False
+ )
+
+ for input in hr_resolution_preview_inputs:
+ input.change(**hr_resolution_preview_args)
+
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -791,6 +816,7 @@ def create_ui():
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
+ show_progress = False,
)
txt2img_paste_fields = [
@@ -1696,7 +1722,9 @@ def create_ui():
if os.path.exists("html/footer.html"):
with open("html/footer.html", encoding="utf8") as file:
- gr.HTML(file.read(), elem_id="footer")
+ footer = file.read()
+ footer = footer.format(versions=versions_html())
+ gr.HTML(footer, elem_id="footer")
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
@@ -1790,7 +1818,7 @@ def create_ui():
if init_field is not None:
init_field(saved_value)
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
apply_field(x, 'visible')
if type(x) == gr.Slider:
@@ -1811,11 +1839,8 @@ def create_ui():
if type(x) == gr.Number:
apply_field(x, 'value')
- # Since there are many dropdowns that shouldn't be saved,
- # we only mark dropdowns that should be saved.
- if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
+ if type(x) == gr.Dropdown:
apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
- apply_field(x, 'visible')
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
@@ -1857,3 +1882,30 @@ def reload_javascript():
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
+
+
+def versions_html():
+ import torch
+ import launch
+
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+ commit = launch.commit_hash()
+ short_commit = commit[0:8]
+
+ if shared.xformers_available:
+ import xformers
+ xformers_version = xformers.__version__
+ else:
+ xformers_version = "N/A"
+
+ return f"""
+python: <span title="{sys.version}">{python_version}</span>
+ • 
+torch: {torch.__version__}
+ • 
+xformers: {xformers_version}
+ • 
+gradio: {gr.__version__}
+ • 
+commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
+"""
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 91eb0e3d..cac001dc 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -23,3 +23,11 @@ class FormGroup(gr.Group, gr.components.FormComponent):
def get_block_name(self):
return "group"
+
+
+class FormHTML(gr.HTML, gr.components.FormComponent):
+ """Same as gr.HTML but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "html"
+
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index eec9586f..742e745e 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -162,15 +162,15 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
-def install_extension_from_index(url, hide_tags):
+def install_extension_from_index(url, hide_tags, sort_column):
ext_table, message = install_extension_from_url(None, url)
- code, _ = refresh_available_extensions_from_data(hide_tags)
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
return code, ext_table, message
-def refresh_available_extensions(url, hide_tags):
+def refresh_available_extensions(url, hide_tags, sort_column):
global available_extensions
import urllib.request
@@ -179,18 +179,28 @@ def refresh_available_extensions(url, hide_tags):
available_extensions = json.loads(text)
- code, tags = refresh_available_extensions_from_data(hide_tags)
+ code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
return url, code, gr.CheckboxGroup.update(choices=tags), ''
-def refresh_available_extensions_for_tags(hide_tags):
- code, _ = refresh_available_extensions_from_data(hide_tags)
+def refresh_available_extensions_for_tags(hide_tags, sort_column):
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
return code, ''
-def refresh_available_extensions_from_data(hide_tags):
+sort_ordering = [
+ # (reverse, order_by_function)
+ (True, lambda x: x.get('added', 'z')),
+ (False, lambda x: x.get('added', 'z')),
+ (False, lambda x: x.get('name', 'z')),
+ (True, lambda x: x.get('name', 'z')),
+ (False, lambda x: 'z'),
+]
+
+
+def refresh_available_extensions_from_data(hide_tags, sort_column):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -210,8 +220,11 @@ def refresh_available_extensions_from_data(hide_tags):
<tbody>
"""
- for ext in extlist:
+ sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
+
+ for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
name = ext.get("name", "noname")
+ added = ext.get('added', 'unknown')
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
@@ -233,7 +246,7 @@ def refresh_available_extensions_from_data(hide_tags):
code += f"""
<tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
- <td>{html.escape(description)}</td>
+ <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
<td>{install_code}</td>
</tr>
@@ -291,25 +304,32 @@ def create_ui():
with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
+ sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
- inputs=[available_extensions_index, hide_tags],
+ inputs=[available_extensions_index, hide_tags, sort_column],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
- inputs=[extension_to_install, hide_tags],
+ inputs=[extension_to_install, hide_tags, sort_column],
outputs=[available_extensions_table, extensions_table, install_result],
)
hide_tags.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
- inputs=[hide_tags],
+ inputs=[hide_tags, sort_column],
+ outputs=[available_extensions_table, install_result]
+ )
+
+ sort_column.change(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+ inputs=[hide_tags, sort_column],
outputs=[available_extensions_table, install_result]
)
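The Order radio uses type="index", so sort_column arrives as an integer that indexes into sort_ordering, and a stable sort with the selected key and direction does the rest. The same lookup-and-sort pattern in isolation, with made-up sample data:

sort_ordering = [
    # (reverse, key_function) - index matches the Radio choice order
    (True,  lambda x: x.get('added', 'z')),   # newest first
    (False, lambda x: x.get('added', 'z')),   # oldest first
    (False, lambda x: x.get('name', 'z')),    # a-z
    (True,  lambda x: x.get('name', 'z')),    # z-a
    (False, lambda x: 'z'),                   # internal order (stable sort keeps it)
]

extensions = [
    {"name": "b-ext", "added": "2023-01-02"},
    {"name": "a-ext", "added": "2023-01-05"},
]

sort_column = 0  # "newest first"
reverse, key = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
for ext in sorted(extensions, key=key, reverse=reverse):
    print(ext["name"])  # a-ext, then b-ext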