Diffstat (limited to 'webui.py')
-rw-r--r--  webui.py  168
1 file changed, 127 insertions(+), 41 deletions(-)
diff --git a/webui.py b/webui.py
index bb53e5ff..5b990a5f 100644
--- a/webui.py
+++ b/webui.py
@@ -8,12 +8,12 @@ from omegaconf import OmegaConf
from PIL import Image
from itertools import islice
from einops import rearrange, repeat
-from torchvision.utils import make_grid
from torch import autocast
from contextlib import contextmanager, nullcontext
import mimetypes
import random
import math
+import csv
import k_diffusion as K
from ldm.util import instantiate_from_config
@@ -28,6 +28,8 @@ mimetypes.add_type('application/javascript', '.js')
opt_C = 4
opt_f = 8
+invalid_filename_chars = '<>:"/\\|?*'  # characters not allowed in filenames on Windows
+
parser = argparse.ArgumentParser()
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
@@ -127,13 +129,14 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.half().to(device)
-def image_grid(imgs, batch_size):
+def image_grid(imgs, batch_size, round_down=False):
if opt.n_rows > 0:
rows = opt.n_rows
elif opt.n_rows == 0:
rows = batch_size
else:
- rows = round(math.sqrt(len(imgs)))
+ rows = math.sqrt(len(imgs))
+ rows = int(rows) if round_down else round(rows)
cols = math.ceil(len(imgs) / rows)
@@ -146,7 +149,7 @@ def image_grid(imgs, batch_size):
return grid
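
An aside on the new round_down flag: a prompt matrix always yields a power-of-two number of images, and truncating the square root of a power of two gives a row count that divides the total evenly, so the grid tiles with no ragged last row. A quick standalone check:

import math

# For prompt-matrix image counts (powers of two), int(sqrt(n)) always
# divides n evenly, while round(sqrt(n)) can leave a ragged last row.
for count in (2, 4, 8, 16):
    rows_round = round(math.sqrt(count))   # 1, 2, 3, 4 -> 8 images leave a partial row
    rows_floor = int(math.sqrt(count))     # 1, 2, 2, 4 -> always divides count evenly
    print(count, rows_round, rows_floor, count % rows_floor == 0)
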
-def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
+def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
torch.cuda.empty_cache()
outpath = opt.outdir or "outputs/txt2img-samples"
@@ -155,6 +158,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
seed = random.randrange(4294967294)
seed = int(seed)
+ keep_same_seed = False
is_PLMS = sampler_name == 'PLMS'
is_DDIM = sampler_name == 'DDIM'
@@ -177,59 +181,99 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
batch_size = n_samples
assert prompt is not None
- data = [batch_size * [prompt]]
+ prompts = batch_size * [prompt]
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
+ prompt_matrix_prompts = []
+ comment = ""
+ if prompt_matrix:
+ keep_same_seed = True
+ comment = "Image prompts:\n\n"
+
+ items = prompt.split("|")
+ combination_count = 2 ** (len(items)-1)
+ for combination_num in range(combination_count):
+ current = items[0]
+ label = 'A'
+
+ for n, text in enumerate(items[1:]):
+ if combination_num & (2**n) > 0:
+ current += ("" if text.strip().startswith(",") else ", ") + text
+ label += chr(ord('B') + n)
+
+ comment += " - " + label + "\n"
+
+ prompt_matrix_prompts.append(current)
+ n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size)
+
+ comment += "\nwhere:\n"
+ for n, text in enumerate(items):
+ comment += " " + chr(ord('A') + n) + " = " + text + "\n"
+
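
To make the bitmask expansion above concrete, here is a minimal standalone sketch of the same logic (the example prompt is invented):

# Sketch of the prompt-matrix expansion above; bit n of combination_num
# decides whether optional part n is appended to the base prompt.
prompt = "a castle|oil painting|at night"   # invented example input
items = prompt.split("|")

prompt_matrix_prompts = []
for combination_num in range(2 ** (len(items) - 1)):
    current = items[0]
    for n, text in enumerate(items[1:]):
        if combination_num & (2 ** n):
            current += ("" if text.strip().startswith(",") else ", ") + text
    prompt_matrix_prompts.append(current)

print(prompt_matrix_prompts)
# ['a castle', 'a castle, oil painting', 'a castle, at night',
#  'a castle, oil painting, at night']
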
precision_scope = autocast if opt.precision == "autocast" else nullcontext
output_images = []
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
for n in range(n_iter):
- for batch_index, prompts in enumerate(data):
- uc = None
- if cfg_scale != 1.0:
- uc = model.get_learned_conditioning(batch_size * [""])
- if isinstance(prompts, tuple):
- prompts = list(prompts)
- c = model.get_learned_conditioning(prompts)
- shape = [opt_C, height // opt_f, width // opt_f]
-
- current_seed = seed + n * len(data) + batch_index
+ if prompt_matrix:
+ prompts = prompt_matrix_prompts[n*batch_size:(n+1)*batch_size]
+
+ uc = None
+ if cfg_scale != 1.0:
+ uc = model.get_learned_conditioning(len(prompts) * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ c = model.get_learned_conditioning(prompts)
+ shape = [opt_C, height // opt_f, width // opt_f]
+
+ batch_seed = seed if keep_same_seed else seed + n * len(prompts)
+
+ # generate each input noise tensor manually so that every image gets its own reproducible seed
+ xs = []
+ for i in range(len(prompts)):
+ current_seed = seed if keep_same_seed else batch_seed + i
torch.manual_seed(current_seed)
+ xs.append(torch.randn(shape, device=device))
+ x = torch.stack(xs)
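
The loop above replaces a single batched `torch.randn` call with one seeded call per image, so any individual image can be regenerated later from its own seed. A rough sketch of the property relied on (shape and seed values are illustrative, CPU tensors for simplicity):

import torch

shape = [4, 64, 64]                 # stands in for [opt_C, height // opt_f, width // opt_f]
batch_seed, batch_size = 1000, 3    # illustrative values

xs = []
for i in range(batch_size):
    torch.manual_seed(batch_seed + i)   # one seed per image, not per batch
    xs.append(torch.randn(shape))
x = torch.stack(xs)                     # [batch_size, *shape]

# re-seeding reproduces image 1's noise exactly, independent of the batch
torch.manual_seed(batch_seed + 1)
assert torch.equal(torch.randn(shape), x[1])
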
- if is_Kdif:
- sigmas = model_wrap.get_sigmas(ddim_steps)
- x = torch.randn([n_samples, *shape], device=device) * sigmas[0] # for GPU draw
- model_wrap_cfg = CFGDenoiser(model_wrap)
- samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
+ if is_Kdif:
+ sigmas = model_wrap.get_sigmas(ddim_steps)
+ x = x * sigmas[0]
+ model_wrap_cfg = CFGDenoiser(model_wrap)
+ samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
- elif sampler is not None:
- samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=None)
+ elif sampler is not None:
+ samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x)
- x_samples_ddim = model.decode_first_stage(samples_ddim)
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
- if not opt.skip_save or not opt.skip_grid:
- for x_sample in x_samples_ddim:
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
- x_sample = x_sample.astype(np.uint8)
+ if not opt.skip_save or not opt.skip_grid:
+ for i, x_sample in enumerate(x_samples_ddim):
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = x_sample.astype(np.uint8)
+
+ if use_GFPGAN and GFPGAN is not None:
+ cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
+ x_sample = restored_img
+
+ image = Image.fromarray(x_sample)
+ filename = f"{base_count:05}-{seed if keep_same_seed else batch_seed + i}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
+
+ image.save(os.path.join(sample_path, filename))
+
+ output_images.append(image)
+ base_count += 1
- if use_GFPGAN and GFPGAN is not None:
- cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
- x_sample = restored_img
- image = Image.fromarray(x_sample)
- image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
- output_images.append(image)
- base_count += 1
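
On the filename change: `str.translate` takes a mapping from code points to replacements, so `{ord(x): '' for x in invalid_filename_chars}` deletes every forbidden character before the name is truncated to 128 characters. A minimal sketch with a made-up prompt:

invalid_filename_chars = '<>:"/\\|?*'

prompt = 'portrait: "a knight", 4k | detailed'   # made-up example
table = {ord(ch): '' for ch in invalid_filename_chars}
print(prompt.replace(' ', '_').translate(table)[:128])
# portrait_a_knight,_4k__detailed
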
if not opt.skip_grid:
# additionally, save as grid
- grid = image_grid(output_images, batch_size)
+ grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
@@ -242,8 +286,49 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip()
+ if len(comment) > 0:
+ info += "\n\n" + comment
+
return output_images, seed, info
+class Flagging(gr.FlaggingCallback):
+
+ def setup(self, components, flagging_dir: str):
+ pass
+
+ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
+ os.makedirs("log/images", exist_ok=True)
+
+ # these must match the inputs and outputs of the "dream" function, in order
+ prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
+
+ filenames = []
+
+ import time
+ import base64
+
+ with open("log/log.csv", "a", encoding="utf8", newline='') as file:
+
+ at_start = file.tell() == 0
+ writer = csv.writer(file)
+ if at_start:
+ writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])
+
+ filename_base = str(int(time.time() * 1000))
+ for i, filedata in enumerate(images):
+ filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"
+
+ if filedata.startswith("data:image/png;base64,"):
+ filedata = filedata[len("data:image/png;base64,"):]
+
+ with open(filename, "wb") as imgfile:
+ imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
+
+ filenames.append(filename)
+
+ writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]])
+
+ print("Logged:", filenames[0])
+
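
For context on the decode step in `flag` above: the image data arrives as a base64 data URI, so the prefix is stripped and the remainder base64-decoded before being written out as a PNG. A minimal sketch with a placeholder payload (not a real PNG):

import base64

filedata = "data:image/png;base64,aGVsbG8="   # placeholder payload
prefix = "data:image/png;base64,"
if filedata.startswith(prefix):
    filedata = filedata[len(prefix):]

print(base64.decodebytes(filedata.encode("utf-8")))   # b'hello'
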
dream_interface = gr.Interface(
dream,
@@ -252,10 +337,11 @@ dream_interface = gr.Interface(
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
+ gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
- gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1),
- gr.Slider(minimum=1, maximum=4, step=1, label='Samples per iteration', value=1),
- gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
+ gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
+ gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
+ gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly should the image follow the prompt)', value=7.0),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
@@ -267,7 +353,7 @@ dream_interface = gr.Interface(
],
title="Stable Diffusion Text-to-Image K",
description="Generate images from text with Stable Diffusion (using K-LMS)",
- allow_flagging="never"
+ flagging_callback=Flagging()
)
@@ -346,8 +432,8 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
x_sample = restored_img
image = Image.fromarray(x_sample)
+ image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"))
- image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
output_images.append(image)
base_count += 1