From 247f58a5e740a7bd3980815961425b778d77ec28 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 17 Sep 2022 12:05:04 +0300 Subject: add support for switching model checkpoints at runtime --- modules/shared.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 4f877036..3c3aa9b6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,14 +13,15 @@ from modules.devices import get_optimal_device import modules.styles import modules.interrogate import modules.memmon +import modules.sd_models sd_model_file = os.path.join(script_path, 'model.ckpt') -if not os.path.exists(sd_model_file): - sd_model_file = "models/ldm/stable-diffusion-v1/model.ckpt" +default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) -parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, sd_model_file), help="path to checkpoint of model",) +parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) +parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth') parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") @@ -88,13 +89,17 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] +modules.sd_models.list_models() + + class Options: class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None): self.default = default self.label = label self.component = component self.component_args = component_args + self.onchange = onchange data = None hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None @@ -150,6 +155,7 @@ class Options: "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}), } def __init__(self): @@ -180,6 +186,10 @@ class Options: with open(filename, "r", encoding="utf8") as file: self.data = json.load(file) + def onchange(self, key, func): + item = self.data_labels.get(key) + item.onchange = func + opts = Options() if os.path.exists(config_filename): @@ -188,7 +198,6 @@ if os.path.exists(config_filename): sd_upscalers = [] sd_model = None -sd_model_hash = '' progress_print_out = sys.stdout -- cgit v1.2.1