aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-08-27 21:32:28 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-08-27 21:32:28 +0300
commitc30aee2f4b4c3a46f6ae878b880d0b837609f63d (patch)
tree3f959080aea4321023b23d8996c4b3109e02bf82 /webui.py
parent4e0fdca2f4bd333d8eae5b8cf4a36caba61efc86 (diff)
fixed all lines PyCharm was nagging me about
fixed input verification not working properly with long textual inversion tokens in some cases (plus it will prevent incorrect outputs for forks that use the :::: prompt weighing method) changed process_images to object class with same fields as args it was previously accepting changed options system to make it possible to explicitly specify gradio objects with args
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py561
1 files changed, 287 insertions, 274 deletions
diff --git a/webui.py b/webui.py
index 13e5112a..8de1bcf2 100644
--- a/webui.py
+++ b/webui.py
@@ -1,14 +1,13 @@
-import argparse, os, sys, glob
+import argparse
+import os
+import sys
from collections import namedtuple
-
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
-from itertools import islice
-from einops import rearrange, repeat
from torch import autocast
import mimetypes
import random
@@ -22,14 +21,13 @@ import k_diffusion.sampling
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
-import ldm.modules.encoders.modules
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
logging.set_verbosity_error()
-except:
+except Exception:
pass
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
@@ -41,13 +39,13 @@ opt_C = 4
opt_f = 8
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
-invalid_filename_chars = '<>:"/\|?*\n'
+invalid_filename_chars = '<>:"/\\|?*\n'
config_filename = "config.json"
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
-parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
+parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
@@ -64,7 +62,7 @@ css_hide_progressbar = """
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
- *[SamplerData(x[0], lambda m, funcname=x[1]: KDiffusionSampler(m, funcname)) for x in [
+ *[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('Euler', 'sample_euler'),
@@ -72,8 +70,8 @@ samplers = [
('DPM 2', 'sample_dpm_2'),
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])],
- SamplerData('DDIM', lambda m: DDIMSampler(model)),
- SamplerData('PLMS', lambda m: PLMSSampler(model)),
+ SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
+ SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
]
samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
@@ -102,7 +100,7 @@ try:
),
]
have_realesrgan = True
-except:
+except Exception:
print("Error loading Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -111,24 +109,30 @@ except:
class Options:
+ class OptionInfo:
+ def __init__(self, default=None, label="", component=None, component_args=None):
+ self.default = default
+ self.label = label
+ self.component = component
+ self.component_args = component_args
+
data = None
data_labels = {
- "outdir": ("", "Output dictectory; if empty, defaults to 'outputs/*'"),
- "samples_save": (True, "Save indiviual samples"),
- "samples_format": ('png', 'File format for indiviual samples'),
- "grid_save": (True, "Save image grids"),
- "grid_format": ('png', 'File format for grids'),
- "grid_extended_filename": (False, "Add extended info (seed, prompt) to filename when saving grid"),
- "n_rows": (-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", -1, 16),
- "jpeg_quality": (80, "Quality for saved jpeg images", 1, 100),
- "verify_input": (True, "Check input, and produce warning if it's too long"),
- "enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
- "prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
- "sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
+ "outdir": OptionInfo("", "Output dictectory; if empty, defaults to 'outputs/*'"),
+ "samples_save": OptionInfo(True, "Save indiviual samples"),
+ "samples_format": OptionInfo('png', 'File format for indiviual samples'),
+ "grid_save": OptionInfo(True, "Save image grids"),
+ "grid_format": OptionInfo('png', 'File format for grids'),
+ "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
+ "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+ "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
+ "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
+ "prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
+ "sd_upscale_overlap": OptionInfo(64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", gr.Slider, {"minimum": 0, "maximum": 256, "step": 16}),
}
def __init__(self):
- self.data = {k: v[0] for k, v in self.data_labels.items()}
+ self.data = {k: v.default for k, v in self.data_labels.items()}
def __setattr__(self, key, value):
if self.data is not None:
@@ -143,7 +147,7 @@ class Options:
return self.data[item]
if item in self.data_labels:
- return self.data_labels[item][0]
+ return self.data_labels[item].default
return super(Options, self).__getattribute__(item)
@@ -156,11 +160,6 @@ class Options:
self.data = json.load(file)
-def chunk(it, size):
- it = iter(it)
- return iter(lambda: tuple(islice(it, size)), ())
-
-
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
@@ -181,36 +180,6 @@ def load_model_from_config(config, ckpt, verbose=False):
return model
-class CFGDenoiser(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
-
- def forward(self, x, sigma, uncond, cond, cond_scale):
- x_in = torch.cat([x] * 2)
- sigma_in = torch.cat([sigma] * 2)
- cond_in = torch.cat([uncond, cond])
- uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
- return uncond + (cond - uncond) * cond_scale
-
-
-class KDiffusionSampler:
- def __init__(self, m, funcname):
- self.model = m
- self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
- self.funcname = funcname
- self.func = getattr(k_diffusion.sampling, self.funcname)
-
- def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
- sigmas = self.model_wrap.get_sigmas(S)
- x = x_T * sigmas[0]
- model_wrap_cfg = CFGDenoiser(self.model_wrap)
-
- samples_ddim = self.func(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
-
- return samples_ddim, None
-
-
def create_random_tensors(shape, seeds):
xs = []
for seed in seeds:
@@ -256,7 +225,7 @@ def plaintext_to_html(text):
return text
-def load_GFPGAN():
+def load_gfpgan():
model_name = 'GFPGANv1.3'
model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
@@ -358,7 +327,7 @@ def combine_grid(grid):
def draw_prompt_matrix(im, width, height, all_prompts):
- def wrap(text, d, font, line_length):
+ def wrap(text, font, line_length):
lines = ['']
for word in text.split():
line = f'{lines[-1]} {word}'.strip()
@@ -368,16 +337,16 @@ def draw_prompt_matrix(im, width, height, all_prompts):
lines.append(word)
return '\n'.join(lines)
- def draw_texts(pos, x, y, texts, sizes):
+ def draw_texts(pos, draw_x, draw_y, texts, sizes):
for i, (text, size) in enumerate(zip(texts, sizes)):
active = pos & (1 << i) != 0
if not active:
text = '\u0336'.join(text) + '\u0336'
- d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
+ d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
- y += size[1] + line_spacing
+ draw_y += size[1] + line_spacing
fontsize = (width + height) // 25
line_spacing = fontsize // 2
@@ -399,8 +368,8 @@ def draw_prompt_matrix(im, width, height, all_prompts):
d = ImageDraw.Draw(result)
boundary = math.ceil(len(prompts) / 2)
- prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]]
- prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]]
+ prompts_horiz = [wrap(x, fnt, width) for x in prompts[:boundary]]
+ prompts_vert = [wrap(x, fnt, pad_left) for x in prompts[boundary:]]
sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
@@ -458,25 +427,6 @@ def resize_image(resize_mode, im, width, height):
return res
-def check_prompt_length(prompt, comments):
- """this function tests if prompt is too long, and if so, adds a message to comments"""
-
- tokenizer = model.cond_stage_model.tokenizer
- max_length = model.cond_stage_model.max_length
-
- info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
- ovf = info['overflowing_tokens'][0]
- overflowing_count = ovf.shape[0]
- if overflowing_count == 0:
- return
-
- vocab = {v: k for k, v in tokenizer.get_vocab().items()}
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-
- comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
-
def wrap_gradio_call(func):
def f(*p1, **p2):
t = time.perf_counter()
@@ -494,7 +444,7 @@ def wrap_gradio_call(func):
GFPGAN = None
if os.path.exists(cmd_opts.gfpgan_dir):
try:
- GFPGAN = load_GFPGAN()
+ GFPGAN = load_gfpgan()
print("Loaded GFPGAN")
except Exception:
print("Error loading GFPGAN:", file=sys.stderr)
@@ -506,11 +456,11 @@ class StableDiffuionModelHijack:
word_embeddings = {}
word_embeddings_checksums = {}
fixes = None
- used_custom_terms = []
+ comments = None
dir_mtime = None
- def load_textual_inversion_embeddings(self, dir, model):
- mt = os.path.getmtime(dir)
+ def load_textual_inversion_embeddings(self, dirname, model):
+ mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
@@ -543,10 +493,10 @@ class StableDiffuionModelHijack:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
- for fn in os.listdir(dir):
+ for fn in os.listdir(dirname):
try:
- process_file(os.path.join(dir, fn), fn)
- except:
+ process_file(os.path.join(dirname, fn), fn)
+ except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
@@ -561,10 +511,10 @@ class StableDiffuionModelHijack:
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
- def __init__(self, wrapped, embeddings):
+ def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
- self.embeddings = embeddings
+ self.hijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
@@ -586,12 +536,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.token_mults[ident] = mult
def forward(self, text):
- self.embeddings.fixes = []
- self.embeddings.used_custom_terms = []
+ self.hijack.fixes = []
+ self.hijack.comments = []
remade_batch_tokens = []
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length - 2
+ used_custom_terms = []
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@@ -611,7 +562,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
- possible_matches = self.embeddings.ids_lookup.get(token, None)
+ possible_matches = self.hijack.ids_lookup.get(token, None)
mult_change = self.token_mults.get(token)
if mult_change is not None:
@@ -628,7 +579,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append(mult)
i += len(ids) - 1
found = True
- self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
+ used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
@@ -637,6 +588,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
i += 1
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+
+ self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@@ -645,9 +604,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
- self.embeddings.fixes.append(fixes)
+ self.hijack.fixes.append(fixes)
batch_multipliers.append(multipliers)
+ if len(used_custom_terms) > 0:
+ self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state
@@ -679,71 +641,123 @@ class EmbeddingsWithFixes(nn.Module):
for offset, word in fixes:
tensor[offset] = self.embeddings.word_embeddings[word]
-
return inputs_embeds
-def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
+class StableDiffusionProcessing:
+ def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None):
+ self.outpath: str = outpath
+ self.prompt: str = prompt
+ self.seed: int = seed
+ self.sampler_index: int = sampler_index
+ self.batch_size: int = batch_size
+ self.n_iter: int = n_iter
+ self.steps: int = steps
+ self.cfg_scale: float = cfg_scale
+ self.width: int = width
+ self.height: int = height
+ self.prompt_matrix: bool = prompt_matrix
+ self.use_GFPGAN: bool = use_GFPGAN
+ self.do_not_save_grid: bool = do_not_save_grid
+ self.extra_generation_params: dict = extra_generation_params
+
+ def init(self):
+ pass
+
+ def sample(self, x, conditioning, unconditional_conditioning):
+ raise NotImplementedError()
+
+
+class VanillaStableDiffusionSampler:
+ def __init__(self, constructor):
+ self.sampler = constructor(sd_model)
+
+ def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
+ samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+ return samples_ddim
+
+
+class CFGDenoiser(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+
+ def forward(self, x, sigma, uncond, cond, cond_scale):
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigma] * 2)
+ cond_in = torch.cat([uncond, cond])
+ uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+ return uncond + (cond - uncond) * cond_scale
+
+
+class KDiffusionSampler:
+ def __init__(self, funcname):
+ self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
+ self.funcname = funcname
+ self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+
+ def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
+ sigmas = self.model_wrap.get_sigmas(p.steps)
+ x = x * sigmas[0]
+
+ samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
+ return samples_ddim
+
+
+def process_images(p: StableDiffusionProcessing):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
- assert prompt is not None
+ prompt = p.prompt
+ model = sd_model
+
+ assert p.prompt is not None
torch_gc()
- if seed == -1:
- seed = random.randrange(4294967294)
- seed = int(seed)
+ seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
- os.makedirs(outpath, exist_ok=True)
+ os.makedirs(p.outpath, exist_ok=True)
- sample_path = os.path.join(outpath, "samples")
+ sample_path = os.path.join(p.outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
- grid_count = len(os.listdir(outpath)) - 1
+ grid_count = len(os.listdir(p.outpath)) - 1
comments = []
prompt_matrix_parts = []
- if prompt_matrix:
+ if p.prompt_matrix:
all_prompts = []
prompt_matrix_parts = prompt.split("|")
combination_count = 2 ** (len(prompt_matrix_parts) - 1)
for combination_num in range(combination_count):
- selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1<<n)]
+ selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
if opts.prompt_matrix_add_to_start:
selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
else:
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
- all_prompts.append( ", ".join(selected_prompts))
+ all_prompts.append(", ".join(selected_prompts))
- n_iter = math.ceil(len(all_prompts) / batch_size)
+ p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
all_seeds = len(all_prompts) * [seed]
- print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
+ print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
else:
-
- if opts.verify_input:
- try:
- check_prompt_length(prompt, comments)
- except:
- import traceback
- print("Error verifying input:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- all_prompts = batch_size * n_iter * [prompt]
+ all_prompts = p.batch_size * p.n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))]
generation_params = {
- "Steps": steps,
- "Sampler": samplers[sampler_index].name,
- "CFG scale": cfg_scale,
+ "Steps": p.steps,
+ "Sampler": samplers[p.sampler_index].name,
+ "CFG scale": p.cfg_scale,
"Seed": seed,
- "GFPGAN": ("GFPGAN" if use_GFPGAN and GFPGAN is not None else None)
+ "GFPGAN": ("GFPGAN" if p.use_GFPGAN and GFPGAN is not None else None)
}
- if extra_generation_params is not None:
- generation_params.update(extra_generation_params)
+ if p.extra_generation_params is not None:
+ generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
@@ -755,32 +769,32 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
output_images = []
with torch.no_grad(), autocast("cuda"), model.ema_scope():
- init_data = func_init()
+ p.init()
- for n in range(n_iter):
- prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
- seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+ for n in range(p.n_iter):
+ prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
uc = model.get_learned_conditioning(len(prompts) * [""])
c = model.get_learned_conditioning(prompts)
- if len(model_hijack.used_custom_terms) > 0:
- comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
+ if len(model_hijack.comments) > 0:
+ comments += model_hijack.comments
# we manually generate all input noises because each one should have a specific seed
- x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
+ x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
- samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc)
+ samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
- if prompt_matrix or opts.samples_save or opts.grid_save:
+ if p.prompt_matrix or opts.samples_save or opts.grid_save:
for i, x_sample in enumerate(x_samples_ddim):
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
- if use_GFPGAN and GFPGAN is not None:
+ if p.use_GFPGAN and GFPGAN is not None:
torch_gc()
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
x_sample = restored_img
@@ -791,44 +805,44 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
output_images.append(image)
base_count += 1
- if (prompt_matrix or opts.grid_save) and not do_not_save_grid:
- if prompt_matrix:
- grid = image_grid(output_images, batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
+ if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid:
+ if p.prompt_matrix:
+ grid = image_grid(output_images, p.batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
try:
- grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
- except:
+ grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
+ except Exception:
import traceback
print("Error creating prompt_matrix text:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output_images.insert(0, grid)
else:
- grid = image_grid(output_images, batch_size)
+ grid = image_grid(output_images, p.batch_size)
- save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
+ save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
grid_count += 1
torch_gc()
return output_images, seed, infotext()
-def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
- outpath = opts.outdir or "outputs/txt2img-samples"
+class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
+ sampler = None
- sampler = samplers[sampler_index].constructor(model)
+ def init(self):
+ self.sampler = samplers[self.sampler_index].constructor()
- def init():
- pass
-
- def sample(init_data, x, conditioning, unconditional_conditioning):
- samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
+ def sample(self, x, conditioning, unconditional_conditioning):
+ samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples_ddim
- output_images, seed, info = process_images(
+
+def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
+ outpath = opts.outdir or "outputs/txt2img-samples"
+
+ p = StableDiffusionProcessingTxt2Img(
outpath=outpath,
- func_init=init,
- func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
@@ -842,7 +856,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
use_GFPGAN=use_GFPGAN
)
- del sampler
+ output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info)
@@ -858,7 +872,7 @@ class Flagging(gr.FlaggingCallback):
os.makedirs("log/images", exist_ok=True)
# those must match the "txt2img" function
- prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
+ prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
filenames = []
@@ -896,7 +910,6 @@ txt2img_interface = gr.Interface(
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
- gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
@@ -914,73 +927,97 @@ txt2img_interface = gr.Interface(
)
-def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
- outpath = opts.outdir or "outputs/img2img-samples"
+class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
+ sampler = None
- sampler = samplers_for_img2img[sampler_index].constructor(model)
+ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs):
+ super().__init__(**kwargs)
- assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ self.init_images = init_images
+ self.resize_mode: int = resize_mode
+ self.denoising_strength: float = denoising_strength
+ self.init_latent = None
+
+ def init(self):
+ self.sampler = samplers_for_img2img[self.sampler_index].constructor()
- def init():
- image = init_img.convert("RGB")
- image = resize_image(resize_mode, image, width, height)
- image = np.array(image).astype(np.float32) / 255.0
- image = image[None].transpose(0, 3, 1, 2)
- image = torch.from_numpy(image)
+ imgs = []
+ for img in self.init_images:
+ image = img.convert("RGB")
+ image = resize_image(self.resize_mode, image, self.width, self.height)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+ imgs.append(image)
- init_image = 2. * image - 1.
- init_image = init_image.to(device)
- init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
- init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
+ if len(imgs) == 1:
+ batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
+ elif len(imgs) <= self.batch_size:
+ self.batch_size = len(imgs)
+ batch_images = np.array(imgs)
+ else:
+ raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
- return init_latent,
+ image = torch.from_numpy(batch_images)
+ image = 2. * image - 1.
+ image = image.to(device)
- def sample(init_data, x, conditioning, unconditional_conditioning):
- t_enc = int(denoising_strength * ddim_steps)
+ self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))
- x0, = init_data
+ def sample(self, x, conditioning, unconditional_conditioning):
+ t_enc = int(self.denoising_strength * self.steps)
- sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
- noise = x * sigmas[ddim_steps - t_enc - 1]
+ sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
+ noise = x * sigmas[self.steps - t_enc - 1]
- xi = x0 + noise
- sigma_sched = sigmas[ddim_steps - t_enc - 1:]
- model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
- samples_ddim = sampler.func(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
+ xi = self.init_latent + noise
+ sigma_sched = sigmas[self.steps - t_enc - 1:]
+ samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
return samples_ddim
+
+def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
+ outpath = opts.outdir or "outputs/img2img-samples"
+
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+
+ p = StableDiffusionProcessingImg2Img(
+ outpath=outpath,
+ prompt=prompt,
+ seed=seed,
+ sampler_index=sampler_index,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=prompt_matrix,
+ use_GFPGAN=use_GFPGAN,
+ init_images=[init_img],
+ resize_mode=resize_mode,
+ denoising_strength=denoising_strength,
+ extra_generation_params={"Denoising Strength": denoising_strength}
+ )
+
if loopback:
output_images, info = None, None
history = []
initial_seed = None
for i in range(n_iter):
- output_images, seed, info = process_images(
- outpath=outpath,
- func_init=init,
- func_sample=sample,
- prompt=prompt,
- seed=seed,
- sampler_index=sampler_index,
- batch_size=1,
- n_iter=1,
- steps=ddim_steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- prompt_matrix=prompt_matrix,
- use_GFPGAN=use_GFPGAN,
- do_not_save_grid=True,
- extra_generation_params={"Denoising Strength": denoising_strength},
- )
+ p.n_iter = 1
+ p.batch_size = 1
+ p.do_not_save_grid = True
+
+ output_images, seed, info = process_images(p)
if initial_seed is None:
initial_seed = seed
- init_img = output_images[0]
- seed = seed + 1
- denoising_strength = max(denoising_strength * 0.95, 0.1)
- history.append(init_img)
+ p.init_img = output_images[0]
+ p.seed = seed + 1
+ p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
+ history.append(output_images[0])
grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1)
@@ -1000,39 +1037,36 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
+ p.n_iter = 1
+ p.do_not_save_grid = True
+
+ work = []
+ work_results = []
+
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ work.append(tiledata[2])
+
+ batch_count = math.ceil(len(work) / p.batch_size)
+ print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")
- print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.")
+ for i in range(batch_count):
+ p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
+ output_images, seed, info = process_images(p)
+
+ if initial_seed is None:
+ initial_seed = seed
+ initial_info = info
+
+ p.seed = seed + 1
+ work_results += output_images
+
+ image_index = 0
for y, h, row in grid.tiles:
for tiledata in row:
- init_img = tiledata[2]
-
- output_images, seed, info = process_images(
- outpath=outpath,
- func_init=init,
- func_sample=sample,
- prompt=prompt,
- seed=seed,
- sampler_index=sampler_index,
- batch_size=1, # since process_images can't work with multiple different images we have to do this for now
- n_iter=1,
- steps=ddim_steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- prompt_matrix=prompt_matrix,
- use_GFPGAN=use_GFPGAN,
- do_not_save_grid=True,
- extra_generation_params={"Denoising Strength": denoising_strength},
- )
-
- if initial_seed is None:
- initial_seed = seed
- initial_info = info
-
- seed += 1
-
- tiledata[2] = output_images[0]
+ tiledata[2] = work_results[image_index]
+ image_index += 1
combined_image = combine_grid(grid)
@@ -1044,25 +1078,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
info = initial_info
else:
- output_images, seed, info = process_images(
- outpath=outpath,
- func_init=init,
- func_sample=sample,
- prompt=prompt,
- seed=seed,
- sampler_index=sampler_index,
- batch_size=batch_size,
- n_iter=n_iter,
- steps=ddim_steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- prompt_matrix=prompt_matrix,
- use_GFPGAN=use_GFPGAN,
- extra_generation_params={"Denoising Strength": denoising_strength},
- )
-
- del sampler
+ output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info)
@@ -1178,22 +1194,19 @@ def run_settings(*args):
def create_setting_component(key):
def fun():
- return opts.data[key] if key in opts.data else opts.data_labels[key][0]
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
+
+ info = opts.data_labels[key]
+ t = type(info.default)
- labelinfo = opts.data_labels[key]
- t = type(labelinfo[0])
- label = labelinfo[1]
- if t == str:
- item = gr.Textbox(label=label, value=fun, lines=1)
+ if info.component is not None:
+ item = info.component(label=info.label, value=fun, **(info.component_args or {}))
+ elif t == str:
+ item = gr.Textbox(label=info.label, value=fun, lines=1)
elif t == int:
- if len(labelinfo) == 5:
- item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
- elif len(labelinfo) == 4:
- item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
- else:
- item = gr.Number(label=label, value=fun)
+ item = gr.Number(label=info.label, value=fun)
elif t == bool:
- item = gr.Checkbox(label=label, value=fun)
+ item = gr.Checkbox(label=info.label, value=fun)
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
@@ -1219,14 +1232,14 @@ interfaces = [
(settings_interface, "Settings"),
]
-config = OmegaConf.load(cmd_opts.config)
-model = load_model_from_config(config, cmd_opts.ckpt)
+sd_config = OmegaConf.load(cmd_opts.config)
+sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-model = (model if cmd_opts.no_half else model.half()).to(device)
+sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
model_hijack = StableDiffuionModelHijack()
-model_hijack.hijack(model)
+model_hijack.hijack(sd_model)
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],