From 501993ebf210bf3b55173ec1910f0c84c7e75424 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 19:31:06 +0300 Subject: added a button to run hires fix on selected image in the gallery --- modules/processing.py | 46 ++++++++++++++---- modules/txt2img.py | 19 +++++++- modules/ui.py | 108 +++++++++++++++++++++++-------------------- modules/ui_common.py | 57 ++++++++++++++--------- modules/ui_postprocessing.py | 8 ++-- 5 files changed, 152 insertions(+), 86 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 213a2879..045c7d79 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -179,6 +179,7 @@ class StableDiffusionProcessing: token_merging_ratio = 0 token_merging_ratio_hr = 0 disable_extra_networks: bool = False + firstpass_image: Image = None scripts_value: scripts.ScriptRunner = field(default=None, init=False) script_args_value: list = field(default=None, init=False) @@ -1238,18 +1239,45 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - x = self.rng.next() - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) - del x + if self.firstpass_image is not None and self.enable_hr: + # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix - if not self.enable_hr: - return samples - devices.torch_gc() + if self.latent_scale_mode is None: + image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0 + image = np.moveaxis(image, 2, 0) + + samples = None + decoded_samples = torch.asarray(np.expand_dims(image, 0)) + + else: + image = np.array(self.firstpass_image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + image = torch.from_numpy(np.expand_dims(image, axis=0)) + image = image.to(shared.device, dtype=devices.dtype_vae) + + if opts.sd_vae_encode_method != 'Full': + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + + samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) + decoded_samples = None + devices.torch_gc() - if self.latent_scale_mode is None: - decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: - decoded_samples = None + # here we generate an image normally + + x = self.rng.next() + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + del x + + if not self.enable_hr: + return samples + + devices.torch_gc() + + if self.latent_scale_mode is None: + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + else: + decoded_samples = None with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) diff --git a/modules/txt2img.py b/modules/txt2img.py index 49660e89..4a6fe72a 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,7 +1,7 @@ from contextlib import closing import modules.scripts -from modules import processing +from modules import processing, infotext_utils from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared @@ -9,9 +9,23 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args): + assert len(gallery) > 0, 'No image to upscale' + + image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] + image = infotext_utils.image_from_url_text(image_info) + + return txt2img(id_task, request, *args, firstpass_image=image) + + +def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None): override_settings = create_override_settings_dict(override_settings_texts) + if firstpass_image is not None: + enable_hr = True + batch_size = 1 + n_iter = 1 + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -38,6 +52,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, + firstpass_image=firstpass_image, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 52b15646..3d548430 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -375,50 +375,60 @@ def create_ui(): show_progress=False, ) - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) + output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) + + txt2img_inputs = [ + dummy_component, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, + steps, + sampler_name, + batch_count, + batch_size, + cfg_scale, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + hr_checkpoint_name, + hr_sampler_name, + hr_prompt, + hr_negative_prompt, + override_settings, + ] + custom_inputs + + txt2img_outputs = [ + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, + ] txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", - inputs=[ - dummy_component, - toprow.prompt, - toprow.negative_prompt, - toprow.ui_styles.dropdown, - steps, - sampler_name, - batch_count, - batch_size, - cfg_scale, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - hr_checkpoint_name, - hr_sampler_name, - hr_prompt, - hr_negative_prompt, - override_settings, - - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], + inputs=txt2img_inputs, + outputs=txt2img_outputs, show_progress=False, ) toprow.prompt.submit(**txt2img_args) toprow.submit.click(**txt2img_args) + output_panel.button_upscale.click( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']), + _js="submit_txt2img_upscale", + inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:], + outputs=txt2img_outputs, + show_progress=False, + ) + res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False) toprow.restore_progress_button.click( @@ -426,10 +436,10 @@ def create_ui(): _js="restoreProgressTxt2img", inputs=[dummy_component], outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -479,7 +489,7 @@ def create_ui(): toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') - ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery) extra_tabs.__exit__() @@ -710,7 +720,7 @@ def create_ui(): outputs=[inpaint_controls, mask_alpha], ) - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) + output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), @@ -755,10 +765,10 @@ def create_ui(): img2img_batch_png_info_dir, ] + custom_inputs, outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -796,10 +806,10 @@ def create_ui(): _js="restoreProgressImg2img", inputs=[dummy_component], outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -839,7 +849,7 @@ def create_ui(): )) extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') - ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery) extra_tabs.__exit__() diff --git a/modules/ui_common.py b/modules/ui_common.py index f48ad426..ff84197c 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -1,3 +1,4 @@ +import dataclasses import json import html import os @@ -104,7 +105,17 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") +@dataclasses.dataclass +class OutputPanel: + gallery = None + infotext = None + html_info = None + html_log = None + button_upscale = None + + def create_output_panel(tabname, outdir, toprow=None): + res = OutputPanel() def open_folder(f): if not os.path.exists(f): @@ -136,9 +147,8 @@ Requested path was: {f} with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"): with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) - generation_info = None with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.") @@ -152,6 +162,9 @@ Requested path was: {f} 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.") } + if tabname == 'txt2img': + res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.") + open_folder_button.click( fn=lambda: open_folder(shared.opts.outdir_samples or outdir), inputs=[], @@ -162,17 +175,17 @@ Requested path was: {f} download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") + res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[generation_info, html_info, html_info], - outputs=[html_info, html_info], + inputs=[res.infotext, res.html_info, res.html_info], + outputs=[res.html_info, res.html_info], show_progress=False, ) @@ -180,14 +193,14 @@ Requested path was: {f} fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ - generation_info, - result_gallery, - html_info, - html_info, + res.infotext, + res.gallery, + res.html_info, + res.html_info, ], outputs=[ download_files, - html_log, + res.html_log, ], show_progress=False, ) @@ -196,21 +209,21 @@ Requested path was: {f} fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ - generation_info, - result_gallery, - html_info, - html_info, + res.infotext, + res.gallery, + res.html_info, + res.html_info, ], outputs=[ download_files, - html_log, + res.html_log, ] ) else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}') + res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}') + res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.html_log = gr.HTML(elem_id=f'html_log_{tabname}') paste_field_names = [] if tabname == "txt2img": @@ -220,11 +233,11 @@ Requested path was: {f} for paste_tabname, paste_button in buttons.items(): parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery, + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery, paste_field_names=paste_field_names )) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + return res def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 1edb68c5..8f09e658 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -28,7 +28,7 @@ def create_ui(): toprow.create_inline_toprow_image() submit = toprow.submit - result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + output_panel = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) @@ -48,9 +48,9 @@ def create_ui(): *script_inputs ], outputs=[ - result_images, - html_info_x, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_log, ], show_progress=False, ) -- cgit v1.2.1