aboutsummaryrefslogtreecommitdiff
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-08-26 11:16:57 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-08-26 11:16:57 +0300
commit155dd2fc0c5dfde0fe9736d3170c023496fb4c39 (patch)
treeb96d8adbe071359bd21f948efeaf7c371ed592ea /webui.py
parent055dd10aae6341e91cfe27ec297099a3273c19bf (diff)
Renamed GFPGAN to extras
Added Real-ESRGAN to extras tab
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py97
1 files changed, 76 insertions, 21 deletions
diff --git a/webui.py b/webui.py
index a6f3826d..03675032 100644
--- a/webui.py
+++ b/webui.py
@@ -76,6 +76,38 @@ samplers = [
SamplerData('PLMS', lambda model: PLMSSampler(model)),
]
+RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
+
+try:
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from realesrgan import RealESRGANer
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+
+ realesrgan_models = [
+ RealesrganModelInfo(
+ name="Real-ESRGAN 2x plus",
+ location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+ netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ ),
+ RealesrganModelInfo(
+ name="Real-ESRGAN 4x plus",
+ location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+ netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+ ),
+ RealesrganModelInfo(
+ name="Real-ESRGAN 4x plus anime 6B",
+ location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+ netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ ),
+ ]
+ have_realesrgan = True
+except:
+ print("Error loading Real-ESRGAN:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
+ have_realesrgan = False
+
class Options:
data = None
@@ -196,10 +228,6 @@ def torch_gc():
torch.cuda.ipc_collect()
-def sanitize_filename_part(text):
- return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
-
-
def save_image(image, path, basename, seed, prompt, extension, info=None, short_filename=False):
prompt = sanitize_filename_part(prompt)
@@ -208,7 +236,7 @@ def save_image(image, path, basename, seed, prompt, extension, info=None, short_
else:
filename = f"{basename}-{seed}-{prompt[:128]}.{extension}"
- if extension == 'png' and opts.enable_pnginfo:
+ if extension == 'png' and opts.enable_pnginfo and info is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("parameters", info)
else:
@@ -217,6 +245,10 @@ def save_image(image, path, basename, seed, prompt, extension, info=None, short_
image.save(os.path.join(path, filename), quality=opts.jpeg_quality, pnginfo=pnginfo)
+def sanitize_filename_part(text):
+ return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
+
+
def plaintext_to_html(text):
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
return text
@@ -835,7 +867,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
do_not_save_grid=True,
- extra_generation_params = {"Denoising Strength": denoising_strength},
+ extra_generation_params={"Denoising Strength": denoising_strength},
)
if initial_seed is None:
@@ -870,7 +902,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
- extra_generation_params = {"Denoising Strength": denoising_strength},
+ extra_generation_params={"Denoising Strength": denoising_strength},
)
del sampler
@@ -908,30 +940,56 @@ img2img_interface = gr.Interface(
)
-def run_GFPGAN(image, strength):
+def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
image = image.convert("RGB")
- cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
- res = Image.fromarray(restored_img)
+ outpath = opts.outdir or "outputs/extras-samples"
+
+ if GFPGAN is not None and GFPGAN_strength > 0:
+ cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
+ res = Image.fromarray(restored_img)
+
+ if GFPGAN_strength < 1.0:
+ res = Image.blend(image, res, GFPGAN_strength)
+
+ image = res
+
+ if have_realesrgan and RealESRGAN_upscaling != 1.0:
+ info = realesrgan_models[RealESRGAN_model_index]
+
+ model = info.model()
+ upsampler = RealESRGANer(
+ scale=info.netscale,
+ model_path=info.location,
+ model=model,
+ half=True
+ )
+
+ upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
+
+ image = Image.fromarray(upsampled)
- if strength < 1.0:
- res = Image.blend(image, res, strength)
+ os.makedirs(outpath, exist_ok=True)
+ base_count = len(os.listdir(outpath))
+
+ save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True)
- return res, 0, ''
+ return image, 0, ''
-gfpgan_interface = gr.Interface(
- run_GFPGAN,
+extras_interface = gr.Interface(
+ wrap_gradio_call(run_extras),
inputs=[
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
- gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Effect strength", value=100),
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=GFPGAN is not None),
+ gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan),
+ gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan),
],
outputs=[
gr.Image(label="Result"),
gr.Number(label='Seed', visible=False),
gr.HTML(),
],
- description="Fix faces on images",
allow_flagging="never",
)
@@ -989,7 +1047,7 @@ settings_interface = gr.Interface(
interfaces = [
(txt2img_interface, "txt2img"),
(img2img_interface, "img2img"),
- (gfpgan_interface, "GFPGAN"),
+ (extras_interface, "Extras"),
(settings_interface, "Settings"),
]
@@ -1003,9 +1061,6 @@ text_inversion_embeddings = TextInversionEmbeddings()
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.hijack(model)
-if GFPGAN is None:
- interfaces = [x for x in interfaces if x[0] != gfpgan_interface]
-
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],
tab_names=[x[1] for x in interfaces],