aboutsummaryrefslogtreecommitdiff
path: root/modules/ui.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/ui.py')
-rw-r--r--modules/ui.py40
1 files changed, 27 insertions, 13 deletions
diff --git a/modules/ui.py b/modules/ui.py
index a18b9007..6c765262 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
-from modules.ui_components import FormRow, FormGroup, ToolButton
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -255,12 +255,20 @@ def add_style(name: str, prompt: str, negative_prompt: str):
return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
-def calc_resolution_hires(x, y, scale):
- #final res can only be a multiple of 8
- scaled_x = int(x * scale // 8) * 8
- scaled_y = int(y * scale // 8) * 8
-
- return str(scaled_x)+"x"+str(scaled_y)
+
+def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+ from modules import processing, devices
+
+ if not enable:
+ return ""
+
+ p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+
+ with devices.autocast():
+ p.init([""], [0], [0])
+
+ return f"resize to: <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
+
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
@@ -712,6 +720,7 @@ def create_ui():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
@@ -724,9 +733,6 @@ def create_ui():
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
-
- with FormRow(elem_id="txt2img_hires_fix_row3"):
- hr_final_resolution = gr.Textbox(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
elif category == "batch":
if not opts.dimensions_and_batch_together:
@@ -738,9 +744,16 @@ def create_ui():
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
- hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False)
- width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False)
- height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False)
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+ hr_resolution_preview_args = dict(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False
+ )
+
+ for input in hr_resolution_preview_inputs:
+ input.change(**hr_resolution_preview_args)
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -803,6 +816,7 @@ def create_ui():
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
+ show_progress = False,
)
txt2img_paste_fields = [