aboutsummaryrefslogtreecommitdiff
path: root/scripts/xy_grid.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-09-17 14:57:10 +0300
committerGitHub <noreply@github.com>2022-09-17 14:57:10 +0300
commit0d7fdb179104e48983d07e0175021f0e4bdc2d55 (patch)
treea183247f90049207a5af64b2882c0f92136ee6fe /scripts/xy_grid.py
parentac61e4663c21ea0f51a4319162d3877e00554a2a (diff)
parent1ef79f926e6314b3ef9308b12ff7ad482afd790a (diff)
Merge branch 'master' into image_info_tab
Diffstat (limited to 'scripts/xy_grid.py')
-rw-r--r--scripts/xy_grid.py25
1 files changed, 20 insertions, 5 deletions
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index eccfda87..6a157722 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -10,7 +10,9 @@ import gradio as gr
from modules import images
from modules.processing import process_images, Processed
from modules.shared import opts, cmd_opts, state
+import modules.shared as shared
import modules.sd_samplers
+import modules.sd_models
import re
@@ -41,6 +43,15 @@ def apply_sampler(p, x, xs):
p.sampler_index = sampler_index
+def apply_checkpoint(p, x, xs):
+ applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
+ assert len(applicable) > 0, f'Checkpoint {x} for found'
+
+ info = applicable[0]
+
+ modules.sd_models.reload_model_weights(shared.sd_model, info)
+
+
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@@ -74,15 +85,16 @@ axis_options = [
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Sampler", str, apply_sampler, format_value),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
-def draw_xy_grid(p, xs, ys, x_label, y_label, cell, draw_legend):
+def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
res = []
- ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
- hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
+ ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
+ hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
first_pocessed = None
@@ -206,8 +218,8 @@ class Script(scripts.Script):
p,
xs=xs,
ys=ys,
- x_label=lambda x: x_opt.format_value(p, x_opt, x),
- y_label=lambda y: y_opt.format_value(p, y_opt, y),
+ x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
+ y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
cell=cell,
draw_legend=draw_legend
)
@@ -215,4 +227,7 @@ class Script(scripts.Script):
if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+ # restore checkpoint in case it was changed by axes
+ modules.sd_models.reload_model_weights(shared.sd_model)
+
return processed