From 304222ef94d1c3c60fab466a96c448868f391bce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 17 Sep 2022 13:49:36 +0300 Subject: X/Y plot support for switching checkpoints. --- scripts/xy_grid.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index eccfda87..680dd702 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,9 @@ import gradio as gr from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state +import modules.shared as shared import modules.sd_samplers +import modules.sd_models import re @@ -41,6 +43,15 @@ def apply_sampler(p, x, xs): p.sampler_index = sampler_index +def apply_checkpoint(p, x, xs): + applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title] + assert len(applicable) > 0, f'Checkpoint {x} for found' + + info = applicable[0] + + modules.sd_models.reload_model_weights(shared.sd_model, info) + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -74,6 +85,7 @@ axis_options = [ AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Sampler", str, apply_sampler, format_value), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -215,4 +227,7 @@ class Script(scripts.Script): if opts.grid_save: images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) + # restore checkpoint in case it was changed by axes + modules.sd_models.reload_model_weights(shared.sd_model) + return processed -- cgit v1.2.1 From 140f89315380dbcc541f6e18e3d355a06ea3e2f0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 17 Sep 2022 14:55:40 +0300 Subject: process all values for x/y plot right away to error out if any are bad before any processing begins --- scripts/xy_grid.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 680dd702..6a157722 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -90,11 +90,11 @@ axis_options = [ ] -def draw_xy_grid(p, xs, ys, x_label, y_label, cell, draw_legend): +def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): res = [] - ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys] - hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs] + ver_texts = [[images.GridAnnotation(y)] for y in y_labels] + hor_texts = [[images.GridAnnotation(x)] for x in x_labels] first_pocessed = None @@ -218,8 +218,8 @@ class Script(scripts.Script): p, xs=xs, ys=ys, - x_label=lambda x: x_opt.format_value(p, x_opt, x), - y_label=lambda y: y_opt.format_value(p, y_opt, y), + x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], + y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], cell=cell, draw_legend=draw_legend ) -- cgit v1.2.1