aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py12
-rw-r--r--modules/api/models.py9
-rw-r--r--modules/generation_parameters_copypaste.py36
-rw-r--r--modules/processing.py12
-rw-r--r--modules/scripts.py24
-rw-r--r--modules/ui.py44
6 files changed, 88 insertions, 49 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 49c213ea..d0f488ca 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -5,7 +5,7 @@ import modules.shared as shared
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
-from modules.extras import run_extras
+from modules.extras import run_extras, run_pnginfo
def upscaler_to_index(name: str):
try:
@@ -32,6 +32,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+ self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -125,8 +126,13 @@ class Api:
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def pnginfoapi(self):
- raise NotImplementedError
+ def pnginfoapi(self, req: PNGInfoRequest):
+ if(not req.image.strip()):
+ return PNGInfoResponse(info="")
+
+ result = run_pnginfo(decode_base64_to_image(req.image.strip()))
+
+ return PNGInfoResponse(info=result[1])
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index dd122321..58e8e58b 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,4 +1,5 @@
import inspect
+from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
@@ -148,4 +149,10 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class PNGInfoRequest(BaseModel):
+ image: str = Field(title="Image", description="The base64 encoded PNG image")
+
+class PNGInfoResponse(BaseModel):
+ info: str = Field(title="Image info", description="A string with all the info the image had") \ No newline at end of file
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d590e9ee..bbaad42e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -6,7 +6,7 @@ import gradio as gr
from modules.shared import script_path
from modules import shared
import tempfile
-from PIL import Image, PngImagePlugin
+from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
@@ -61,6 +61,24 @@ def add_paste_fields(tabname, init_img, fields):
modules.ui.img2img_paste_fields = fields
+def integrate_settings_paste_fields(component_dict):
+ from modules import ui
+
+ settings_map = {
+ 'sd_hypernetwork': 'Hypernet',
+ 'CLIP_stop_at_last_layers': 'Clip skip',
+ 'sd_model_checkpoint': 'Model hash',
+ }
+ settings_paste_fields = [
+ (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
+ for k, v in settings_map.items()
+ ]
+
+ for tabname, info in paste_fields.items():
+ if info["fields"] is not None:
+ info["fields"] += settings_paste_fields
+
+
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
@@ -87,24 +105,22 @@ def run_bind():
)
else:
button.click(
- fn=lambda x:x,
+ fn=lambda x: x,
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
if send_generate_info and paste_fields[tab]["fields"] is not None:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2']
- if shared.opts.send_seed:
- paste_field_names += ["Seed"]
if send_generate_info in paste_fields:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
+
button.click(
- fn=lambda *x:x,
- inputs=[field for field,name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
- outputs=[field for field,name in paste_fields[tab]["fields"] if name in paste_field_names],
+ fn=lambda *x: x,
+ inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
)
-
else:
- connect_paste(button, [(field, name) for field, name in paste_fields[tab]["fields"] if name in paste_field_names], send_generate_info)
+ connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
button.click(
fn=None,
diff --git a/modules/processing.py b/modules/processing.py
index 04fdda7c..dd2320fa 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -479,7 +479,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
- p.scripts.run_alwayson_scripts(p)
+ p.scripts.process(p)
infotexts = []
output_images = []
@@ -502,7 +502,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if (len(prompts) == 0):
+ if len(prompts) == 0:
break
with devices.autocast():
@@ -591,7 +591,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ if p.scripts is not None:
+ p.scripts.postprocess(p, res)
+
+ return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
diff --git a/modules/scripts.py b/modules/scripts.py
index a7f36012..96e44bfd 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -64,7 +64,16 @@ class Script:
def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
- scripts. You can modify the processing object (p) here, inject hooks, etc.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
+ def postprocess(self, p, processed, *args):
+ """
+ This function is called after processing ends for AlwaysVisible scripts.
+ args contains all values returned by components from ui()
"""
pass
@@ -289,13 +298,22 @@ class ScriptRunner:
return processed
- def run_alwayson_scripts(self, p):
+ def process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(f"Error running process: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def postprocess(self, p, processed):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess(p, processed, *script_args)
+ except Exception:
+ print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
diff --git a/modules/ui.py b/modules/ui.py
index 1d7bf000..5055ca64 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -589,6 +589,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
)
return refresh_button
+
def create_output_panel(tabname, outdir):
def open_folder(f):
if not os.path.exists(f):
@@ -716,6 +717,7 @@ def create_ui(wrap_gradio_gpu_call):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -784,7 +786,7 @@ def create_ui(wrap_gradio_gpu_call):
]
)
- parameters_copypaste.add_paste_fields("txt2img", None, [
+ txt2img_paste_fields = [
(txt2img_prompt, "Prompt"),
(txt2img_negative_prompt, "Negative prompt"),
(steps, "Steps"),
@@ -805,7 +807,8 @@ def create_ui(wrap_gradio_gpu_call):
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
*modules.scripts.scripts_txt2img.infotext_fields
- ])
+ ]
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
txt2img_preview_params = [
txt2img_prompt,
@@ -893,6 +896,7 @@ def create_ui(wrap_gradio_gpu_call):
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
+ parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -1038,7 +1042,6 @@ def create_ui(wrap_gradio_gpu_call):
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
-
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -1050,12 +1053,8 @@ def create_ui(wrap_gradio_gpu_call):
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
with gr.TabItem('Batch from Directory'):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs,
- placeholder="A directory on the same machine where the server is running."
- )
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs,
- placeholder="Leave blank to save images to the default path."
- )
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
show_extras_results = gr.Checkbox(label='Show result images', value=True)
with gr.Tabs(elem_id="extras_resize_mode"):
@@ -1087,7 +1086,6 @@ def create_ui(wrap_gradio_gpu_call):
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
@@ -1121,7 +1119,6 @@ def create_ui(wrap_gradio_gpu_call):
)
parameters_copypaste.add_paste_fields("extras", extras_image, None)
-
extras_image.change(
fn=modules.extras.clear_cache,
inputs=[], outputs=[]
@@ -1587,9 +1584,6 @@ def create_ui(wrap_gradio_gpu_call):
if column is not None:
column.__exit__()
-
-
-
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"),
@@ -1599,10 +1593,6 @@ def create_ui(wrap_gradio_gpu_call):
(train_interface, "Train", "ti"),
]
- interfaces += script_callbacks.ui_tabs_callback()
-
- interfaces += [(settings_interface, "Settings", "settings")]
-
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
@@ -1619,6 +1609,9 @@ def create_ui(wrap_gradio_gpu_call):
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
+ interfaces += script_callbacks.ui_tabs_callback()
+ interfaces += [(settings_interface, "Settings", "settings")]
+
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list:
@@ -1627,6 +1620,9 @@ def create_ui(wrap_gradio_gpu_call):
settings_interface.gradio_ref = demo
+ parameters_copypaste.integrate_settings_paste_fields(component_dict)
+ parameters_copypaste.run_bind()
+
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
@@ -1681,16 +1677,6 @@ def create_ui(wrap_gradio_gpu_call):
]
)
-
- settings_map = {
- 'sd_hypernetwork': 'Hypernet',
- 'sd_hypernetwork_strength': 'Hypernetwork strength',
- 'CLIP_stop_at_last_layers': 'Clip skip',
- 'sd_model_checkpoint': 'Model hash',
- }
-
- parameters_copypaste.run_bind()
-
ui_config_file = cmd_opts.ui_config_file
ui_settings = {}
settings_count = len(ui_settings)
@@ -1709,7 +1695,7 @@ def create_ui(wrap_gradio_gpu_call):
def apply_field(obj, field, condition=None, init_field=None):
key = path + "/" + field
- if getattr(obj,'custom_script_source',None) is not None:
+ if getattr(obj, 'custom_script_source', None) is not None:
key = 'customscript/' + obj.custom_script_source + '/' + key
if getattr(obj, 'do_not_save_to_config', False):