aboutsummaryrefslogtreecommitdiff
path: root/modules/processing_scripts/refiner.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing_scripts/refiner.py')
-rw-r--r--modules/processing_scripts/refiner.py55
1 files changed, 55 insertions, 0 deletions
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py
new file mode 100644
index 00000000..5a82991a
--- /dev/null
+++ b/modules/processing_scripts/refiner.py
@@ -0,0 +1,55 @@
+import gradio as gr
+
+from modules import scripts, sd_models
+from modules.ui_common import create_refresh_button
+from modules.ui_components import InputAccordion
+
+
+class ScriptRefiner(scripts.Script):
+ section = "accordions"
+ create_group = False
+
+ def __init__(self):
+ pass
+
+ def title(self):
+ return "Refiner"
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def ui(self, is_img2img):
+ with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
+ with gr.Row():
+ refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
+ create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
+
+ refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
+
+ def lookup_checkpoint(title):
+ info = sd_models.get_closet_checkpoint_match(title)
+ return None if info is None else info.title
+
+ self.infotext_fields = [
+ (enable_refiner, lambda d: 'Refiner' in d),
+ (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
+ (refiner_switch_at, 'Refiner switch at'),
+ ]
+
+ return enable_refiner, refiner_checkpoint, refiner_switch_at
+
+ def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
+ # the actual implementation is in sd_samplers_common.py, apply_refiner
+
+ p.refiner_checkpoint_info = None
+ p.refiner_switch_at = None
+
+ if not enable_refiner or refiner_checkpoint in (None, "", "None"):
+ return
+
+ refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
+ if refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')
+
+ p.refiner_checkpoint_info = refiner_checkpoint_info
+ p.refiner_switch_at = refiner_switch_at