aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-08-23 18:04:13 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-08-23 18:04:13 +0300
commitcb118c4036176c3d534fc408fed1a95332a85f26 (patch)
treeca9c6fc0478982bd838dc7df8e4dfefd4d35c15e /webui.py
parent61bfa6c16b651b9c72fbe01f2a0899ff4ba1f027 (diff)
Prompt matrix now draws text like in demo.
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py129
1 files changed, 84 insertions, 45 deletions
diff --git a/webui.py b/webui.py
index 6f8efa84..95dcc751 100644
--- a/webui.py
+++ b/webui.py
@@ -1,11 +1,10 @@
-import PIL
import argparse, os, sys, glob
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
-from PIL import Image
+from PIL import Image, ImageFont, ImageDraw
from itertools import islice
from einops import rearrange, repeat
from torch import autocast
@@ -76,23 +75,6 @@ def load_model_from_config(config, ckpt, verbose=False):
return model
-def load_img_pil(img_pil):
- image = img_pil.convert("RGB")
- w, h = image.size
- print(f"loaded input image of size ({w}, {h})")
- w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
- print(f"cropped image to size ({w}, {h})")
- image = np.array(image).astype(np.float32) / 255.0
- image = image[None].transpose(0, 3, 1, 2)
- image = torch.from_numpy(image)
- return 2. * image - 1.
-
-
-def load_img(path):
- return load_img_pil(Image.open(path))
-
-
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
@@ -179,6 +161,71 @@ def image_grid(imgs, batch_size, round_down=False):
return grid
+def draw_prompt_matrix(im, width, height, all_prompts):
+ def wrap(text, d, font, line_length):
+ lines = ['']
+ for word in text.split():
+ line = f'{lines[-1]} {word}'.strip()
+ if d.textlength(line, font=font) <= line_length:
+ lines[-1] = line
+ else:
+ lines.append(word)
+ return '\n'.join(lines)
+
+ def draw_texts(pos, x, y, texts, sizes):
+ for i, (text, size) in enumerate(zip(texts, sizes)):
+ active = pos & (1 << i) != 0
+
+ if not active:
+ text = '\u0336'.join(text) + '\u0336'
+
+ d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
+
+ y += size[1] + line_spacing
+
+ fontsize = (width + height) // 25
+ line_spacing = fontsize // 2
+ fnt = ImageFont.truetype("arial.ttf", fontsize)
+ color_active = (0, 0, 0)
+ color_inactive = (153, 153, 153)
+
+ pad_top = height // 4
+ pad_left = width * 3 // 4
+
+ cols = im.width // width
+ rows = im.height // height
+
+ prompts = all_prompts[1:]
+
+ result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
+ result.paste(im, (pad_left, pad_top))
+
+ d = ImageDraw.Draw(result)
+
+ boundary = math.ceil(len(prompts) / 2)
+ prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]]
+ prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]]
+
+ sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
+ sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
+ hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing
+ ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing
+
+ for col in range(cols):
+ x = pad_left + width * col + width / 2
+ y = pad_top / 2 - hor_text_height / 2
+
+ draw_texts(col, x, y, prompts_horiz, sizes_hor)
+
+ for row in range(rows):
+ x = pad_left / 2
+ y = pad_top + height * row + height / 2 - ver_text_height / 2
+
+ draw_texts(row, x, y, prompts_vert, sizes_ver)
+
+ return result
+
+
def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
torch.cuda.empty_cache()
@@ -212,30 +259,23 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
grid_count = len(os.listdir(outpath)) - 1
prompt_matrix_prompts = []
- comment = ""
+ prompt_matrix_parts = []
if prompt_matrix:
keep_same_seed = True
- comment = "Image prompts:\n\n"
- items = prompt.split("|")
- combination_count = 2 ** (len(items)-1)
+ prompt_matrix_parts = prompt.split("|")
+ combination_count = 2 ** (len(prompt_matrix_parts)-1)
for combination_num in range(combination_count):
- current = items[0]
- label = 'A'
+ current = prompt_matrix_parts[0]
- for n, text in enumerate(items[1:]):
+ for n, text in enumerate(prompt_matrix_parts[1:]):
if combination_num & (2**n) > 0:
current += ("" if text.strip().startswith(",") else ", ") + text
- label += chr(ord('B') + n)
-
- comment += " - " + label + "\n"
prompt_matrix_prompts.append(current)
n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size)
- comment += "\nwhere:\n"
- for n, text in enumerate(items):
- comment += " " + chr(ord('A') + n) + " = " + items[n] + "\n"
+ print(f"Prompt matrix will create {len(prompt_matrix_prompts)} images using a total of {n_iter} batches.")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
output_images = []
@@ -262,7 +302,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
- if not opt.skip_save or not opt.skip_grid:
+ if prompt_matrix or not opt.skip_save or not opt.skip_grid:
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
@@ -279,24 +319,23 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
output_images.append(image)
base_count += 1
- if not opt.skip_grid:
- # additionally, save as grid
+ if prompt_matrix or not opt.skip_grid:
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
+
+ if prompt_matrix:
+ grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
+ output_images.insert(0, grid)
+
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
-
- if sampler is not None:
- del sampler
+ del sampler
info = f"""
{prompt}
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip()
- if len(comment) > 0:
- info += "\n\n" + comment
-
return output_images, seed, info
class Flagging(gr.FlaggingCallback):
@@ -350,7 +389,7 @@ dream_interface = gr.Interface(
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
- gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
+ gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
@@ -389,7 +428,7 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
grid_count = len(os.listdir(outpath)) - 1
image = init_img.convert("RGB")
- image = image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS)
+ image = image.resize((width, height), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -466,7 +505,7 @@ img2img_interface = gr.Interface(
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
- gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
+ gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
gr.Number(label='Seed', value=-1),
@@ -494,7 +533,7 @@ def run_GFPGAN(image, strength):
res = Image.fromarray(restored_img)
if strength < 1.0:
- res = PIL.Image.blend(image, res, strength)
+ res = Image.blend(image, res, strength)
return res