aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-08-27 16:13:33 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-08-27 16:13:33 +0300
commit4e0fdca2f4bd333d8eae5b8cf4a36caba61efc86 (patch)
treea584c2d1025766cb87f51d6b7941b41b9a0edb8d /webui.py
parent9597b265ec07e8ec6dab7487152459046585c1f9 (diff)
Implementation for SD upscale.
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py178
1 files changed, 154 insertions, 24 deletions
diff --git a/webui.py b/webui.py
index a0fa23c4..13e5112a 100644
--- a/webui.py
+++ b/webui.py
@@ -86,11 +86,6 @@ try:
realesrgan_models = [
RealesrganModelInfo(
- name="Real-ESRGAN 2x plus",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
- netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
- ),
- RealesrganModelInfo(
name="Real-ESRGAN 4x plus",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
@@ -100,6 +95,11 @@ try:
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
),
+ RealesrganModelInfo(
+ name="Real-ESRGAN 2x plus",
+ location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+ netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ ),
]
have_realesrgan = True
except:
@@ -124,6 +124,7 @@ class Options:
"verify_input": (True, "Check input, and produce warning if it's too long"),
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
"prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
+ "sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
}
def __init__(self):
@@ -289,6 +290,73 @@ def image_grid(imgs, batch_size, force_n_rows=None):
return grid
+Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+
+
+def split_grid(image, tile_w=512, tile_h=512, overlap=64):
+ w = image.width
+ h = image.height
+
+ now = tile_w - overlap # non-overlap width
+ noh = tile_h - overlap
+
+ cols = math.ceil((w - overlap) / now)
+ rows = math.ceil((h - overlap) / noh)
+
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
+ for row in range(rows):
+ row_images = []
+
+ y = row * noh
+
+ if y + tile_h >= h:
+ y = h - tile_h
+
+ for col in range(cols):
+ x = col * now
+
+ if x+tile_w >= w:
+ x = w - tile_w
+
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
+
+ row_images.append([x, tile_w, tile])
+
+ grid.tiles.append([y, tile_h, row_images])
+
+ return grid
+
+
+def combine_grid(grid):
+ def make_mask_image(r):
+ r = r * 255 / grid.overlap
+ r = r.astype(np.uint8)
+ return Image.fromarray(r, 'L')
+
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
+
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
+ for y, h, row in grid.tiles:
+ combined_row = Image.new("RGB", (grid.image_w, h))
+ for x, w, tile in row:
+ if x == 0:
+ combined_row.paste(tile, (0, 0))
+ continue
+
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
+
+ if y == 0:
+ combined_image.paste(combined_row, (0, 0))
+ continue
+
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
+
+ return combined_image
+
+
def draw_prompt_matrix(im, width, height, all_prompts):
def wrap(text, d, font, line_length):
lines = ['']
@@ -491,6 +559,7 @@ class StableDiffuionModelHijack:
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, embeddings):
super().__init__()
@@ -740,8 +809,6 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
grid_count += 1
-
-
torch_gc()
return output_images, seed, infotext()
@@ -847,7 +914,7 @@ txt2img_interface = gr.Interface(
)
-def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
+def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples"
sampler = samplers_for_img2img[sampler_index].constructor(model)
@@ -894,7 +961,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
func_sample=sample,
prompt=prompt,
seed=seed,
- sampler_index=0,
+ sampler_index=sampler_index,
batch_size=1,
n_iter=1,
steps=ddim_steps,
@@ -923,6 +990,59 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
output_images = history
seed = initial_seed
+ elif sd_upscale:
+ initial_seed = None
+ initial_info = None
+
+ img = upscale_with_realesrgan(init_img, RealESRGAN_upscaling=2, RealESRGAN_model_index=0)
+
+ torch_gc()
+
+ grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
+
+
+ print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.")
+
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ init_img = tiledata[2]
+
+ output_images, seed, info = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_index=sampler_index,
+ batch_size=1, # since process_images can't work with multiple different images we have to do this for now
+ n_iter=1,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=prompt_matrix,
+ use_GFPGAN=use_GFPGAN,
+ do_not_save_grid=True,
+ extra_generation_params={"Denoising Strength": denoising_strength},
+ )
+
+ if initial_seed is None:
+ initial_seed = seed
+ initial_info = info
+
+ seed += 1
+
+ tiledata[2] = output_images[0]
+
+ combined_image = combine_grid(grid)
+
+ grid_count = len(os.listdir(outpath)) - 1
+ save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)
+
+ output_images = [combined_image]
+ seed = initial_seed
+ info = initial_info
+
else:
output_images, seed, info = process_images(
outpath=outpath,
@@ -930,7 +1050,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
func_sample=sample,
prompt=prompt,
seed=seed,
- sampler_index=0,
+ sampler_index=sampler_index,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
@@ -960,6 +1080,7 @@ img2img_interface = gr.Interface(
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
+ gr.Checkbox(label='Stable Diffusion upscale', value=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
@@ -978,7 +1099,26 @@ img2img_interface = gr.Interface(
)
+def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
+ info = realesrgan_models[RealESRGAN_model_index]
+
+ model = info.model()
+ upsampler = RealESRGANer(
+ scale=info.netscale,
+ model_path=info.location,
+ model=model,
+ half=True
+ )
+
+ upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
+
+ image = Image.fromarray(upsampled)
+ return image
+
+
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
+ torch_gc()
+
image = image.convert("RGB")
outpath = opts.outdir or "outputs/extras-samples"
@@ -993,19 +1133,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
image = res
if have_realesrgan and RealESRGAN_upscaling != 1.0:
- info = realesrgan_models[RealESRGAN_model_index]
-
- model = info.model()
- upsampler = RealESRGANer(
- scale=info.netscale,
- model_path=info.location,
- model=model,
- half=True
- )
-
- upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
-
- image = Image.fromarray(upsampled)
+ image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
os.makedirs(outpath, exist_ok=True)
base_count = len(os.listdir(outpath))
@@ -1058,7 +1186,9 @@ def create_setting_component(key):
if t == str:
item = gr.Textbox(label=label, value=fun, lines=1)
elif t == int:
- if len(labelinfo) == 4:
+ if len(labelinfo) == 5:
+ item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
+ elif len(labelinfo) == 4:
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
else:
item = gr.Number(label=label, value=fun)