aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-08-23 00:34:49 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-08-23 00:34:49 +0300
commit3395c29127e2dfc4467f04b40b2aec7ef3ec1196 (patch)
tree028664723fe8d0538c35cad9cc14ae8303ba85e7 /webui.py
parentb63d0726cd26aa8124e6d2e6f339b474a1563459 (diff)
added prompt matrix feature
all images in batches now have proper seeds, not just the first one added code to remove bad characters from filenames added code to flag output which writes it to csv and saves images renamed some fields in UI for clarity
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py168
1 files changed, 127 insertions, 41 deletions
diff --git a/webui.py b/webui.py
index bb53e5ff..5b990a5f 100644
--- a/webui.py
+++ b/webui.py
@@ -8,12 +8,12 @@ from omegaconf import OmegaConf
from PIL import Image
from itertools import islice
from einops import rearrange, repeat
-from torchvision.utils import make_grid
from torch import autocast
from contextlib import contextmanager, nullcontext
import mimetypes
import random
import math
+import csv
import k_diffusion as K
from ldm.util import instantiate_from_config
@@ -28,6 +28,8 @@ mimetypes.add_type('application/javascript', '.js')
opt_C = 4
opt_f = 8
+invalid_filename_chars = '<>:"/\|?*'
+
parser = argparse.ArgumentParser()
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
@@ -127,13 +129,14 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
model = model.half().to(device)
-def image_grid(imgs, batch_size):
+def image_grid(imgs, batch_size, round_down=False):
if opt.n_rows > 0:
rows = opt.n_rows
elif opt.n_rows == 0:
rows = batch_size
else:
- rows = round(math.sqrt(len(imgs)))
+ rows = math.sqrt(len(imgs))
+ rows = int(rows) if round_down else round(rows)
cols = math.ceil(len(imgs) / rows)
@@ -146,7 +149,7 @@ def image_grid(imgs, batch_size):
return grid
-def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
+def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
torch.cuda.empty_cache()
outpath = opt.outdir or "outputs/txt2img-samples"
@@ -155,6 +158,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
seed = random.randrange(4294967294)
seed = int(seed)
+ keep_same_seed = False
is_PLMS = sampler_name == 'PLMS'
is_DDIM = sampler_name == 'DDIM'
@@ -177,59 +181,99 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
batch_size = n_samples
assert prompt is not None
- data = [batch_size * [prompt]]
+ prompts = batch_size * [prompt]
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
+ prompt_matrix_prompts = []
+ comment = ""
+ if prompt_matrix:
+ keep_same_seed = True
+ comment = "Image prompts:\n\n"
+
+ items = prompt.split("|")
+ combination_count = 2 ** (len(items)-1)
+ for combination_num in range(combination_count):
+ current = items[0]
+ label = 'A'
+
+ for n, text in enumerate(items[1:]):
+ if combination_num & (2**n) > 0:
+ current += ("" if text.strip().startswith(",") else ", ") + text
+ label += chr(ord('B') + n)
+
+ comment += " - " + label + "\n"
+
+ prompt_matrix_prompts.append(current)
+ n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size)
+
+ comment += "\nwhere:\n"
+ for n, text in enumerate(items):
+ comment += " " + chr(ord('A') + n) + " = " + items[n] + "\n"
+
precision_scope = autocast if opt.precision == "autocast" else nullcontext
output_images = []
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
for n in range(n_iter):
- for batch_index, prompts in enumerate(data):
- uc = None
- if cfg_scale != 1.0:
- uc = model.get_learned_conditioning(batch_size * [""])
- if isinstance(prompts, tuple):
- prompts = list(prompts)
- c = model.get_learned_conditioning(prompts)
- shape = [opt_C, height // opt_f, width // opt_f]
-
- current_seed = seed + n * len(data) + batch_index
+ if prompt_matrix:
+ prompts = prompt_matrix_prompts[n*batch_size:(n+1)*batch_size]
+
+ uc = None
+ if cfg_scale != 1.0:
+ uc = model.get_learned_conditioning(len(prompts) * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ c = model.get_learned_conditioning(prompts)
+ shape = [opt_C, height // opt_f, width // opt_f]
+
+ batch_seed = seed if keep_same_seed else seed + n * len(prompts)
+
+ # we manually generate all input noises because each one should have a specific seed
+ xs = []
+ for i in range(len(prompts)):
+ current_seed = seed if keep_same_seed else batch_seed + i
torch.manual_seed(current_seed)
+ xs.append(torch.randn(shape, device=device))
+ x = torch.stack(xs)
- if is_Kdif:
- sigmas = model_wrap.get_sigmas(ddim_steps)
- x = torch.randn([n_samples, *shape], device=device) * sigmas[0] # for GPU draw
- model_wrap_cfg = CFGDenoiser(model_wrap)
- samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
+ if is_Kdif:
+ sigmas = model_wrap.get_sigmas(ddim_steps)
+ x = x * sigmas[0]
+ model_wrap_cfg = CFGDenoiser(model_wrap)
+ samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
- elif sampler is not None:
- samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=None)
+ elif sampler is not None:
+ samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x)
- x_samples_ddim = model.decode_first_stage(samples_ddim)
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
- if not opt.skip_save or not opt.skip_grid:
- for x_sample in x_samples_ddim:
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
- x_sample = x_sample.astype(np.uint8)
+ if not opt.skip_save or not opt.skip_grid:
+ for i, x_sample in enumerate(x_samples_ddim):
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = x_sample.astype(np.uint8)
+
+ if use_GFPGAN and GFPGAN is not None:
+ cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
+ x_sample = restored_img
+
+ image = Image.fromarray(x_sample)
+ filename = f"{base_count:05}-{seed if keep_same_seed else batch_seed + i}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
+
+ image.save(os.path.join(sample_path, filename))
+
+ output_images.append(image)
+ base_count += 1
- if use_GFPGAN and GFPGAN is not None:
- cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
- x_sample = restored_img
- image = Image.fromarray(x_sample)
- image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
- output_images.append(image)
- base_count += 1
if not opt.skip_grid:
# additionally, save as grid
- grid = image_grid(output_images, batch_size)
+ grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
@@ -242,8 +286,49 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip()
+ if len(comment) > 0:
+ info += "\n\n" + comment
+
return output_images, seed, info
+class Flagging(gr.FlaggingCallback):
+
+ def setup(self, components, flagging_dir: str):
+ pass
+
+ def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
+ os.makedirs("log/images", exist_ok=True)
+
+ # those must match the "dream" function
+ prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
+
+ filenames = []
+
+ with open("log/log.csv", "a", encoding="utf8", newline='') as file:
+ import time
+ import base64
+
+ at_start = file.tell() == 0
+ writer = csv.writer(file)
+ if at_start:
+ writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])
+
+ filename_base = str(int(time.time() * 1000))
+ for i, filedata in enumerate(images):
+ filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"
+
+ if filedata.startswith("data:image/png;base64,"):
+ filedata = filedata[len("data:image/png;base64,"):]
+
+ with open(filename, "wb") as imgfile:
+ imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
+
+ filenames.append(filename)
+
+ writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]])
+
+ print("Logged:", filenames[0])
+
dream_interface = gr.Interface(
dream,
@@ -252,10 +337,11 @@ dream_interface = gr.Interface(
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
+ gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
- gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1),
- gr.Slider(minimum=1, maximum=4, step=1, label='Samples per iteration', value=1),
- gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
+ gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
+ gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
+ gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly should the image follow the prompt)', value=7.0),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
@@ -267,7 +353,7 @@ dream_interface = gr.Interface(
],
title="Stable Diffusion Text-to-Image K",
description="Generate images from text with Stable Diffusion (using K-LMS)",
- allow_flagging="never"
+ flagging_callback=Flagging()
)
@@ -346,8 +432,8 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
x_sample = restored_img
image = Image.fromarray(x_sample)
+ image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"))
- image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
output_images.append(image)
base_count += 1