From e179b6098ac1b1ce9645fef5bd9fd0bc9b918f30 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Wed, 25 Jan 2023 08:48:40 -0800 Subject: allow symlinks in the textual inversion embeddings folder --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4e90f690..6cf00e65 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -194,7 +194,7 @@ class EmbeddingDatabase: if not os.path.isdir(embdir.path): return - for root, dirs, fns in os.walk(embdir.path): + for root, dirs, fns in os.walk(embdir.path, followlinks=True): for fn in fns: try: fullfn = os.path.join(root, fn) -- cgit v1.2.1 From d1d6ce29831d1b067801c3206f314258de88f683 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 23:25:25 +0300 Subject: add edit_image_conditioning from my earlier edits in case there's an attempt to inegrate pix2pix properly this allows to use pix2pix model in img2img though it won't work well this way --- modules/processing.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 9e5a2f38..cb41288a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -185,7 +185,12 @@ class StableDiffusionProcessing: conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1. return conditioning - def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None): + def edit_image_conditioning(self, source_image): + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + + return conditioning_image + + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): self.is_using_inpainting_conditioning = True # Handle the different mask inputs @@ -228,6 +233,9 @@ class StableDiffusionProcessing: if isinstance(self.sd_model, LatentDepth2ImageDiffusion): return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) + if self.sd_model.cond_stage_key == "edit": + return self.edit_image_conditioning(source_image) + if self.sampler.conditioning_key in {'hybrid', 'concat'}: return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) -- cgit v1.2.1 From 6cff4401824299a983c8e13424018efc347b4a2b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 23:25:40 +0300 Subject: fix prompt editing break after first batch in img2img --- modules/sd_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6261d1f7..a7910b56 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -454,7 +454,7 @@ class KDiffusionSampler: def initialize(self, p): self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap.step = 0 + self.model_wrap_cfg.step = 0 self.eta = p.eta or opts.eta_ancestral k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) -- cgit v1.2.1 From 10421f93c3f7f7ce88cb40391b46d4e6664eff74 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 26 Jan 2023 00:34:38 -0500 Subject: Fix full previews, --no-half-vae --- modules/processing.py | 8 ++++---- modules/sd_hijack_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index cb41288a..92894d67 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image)) conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), @@ -217,7 +217,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -417,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x) + x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x) return x @@ -1001,7 +1001,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) + image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index f81b169a..f8684475 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -5,7 +5,7 @@ class CondFunc: self = super(CondFunc, cls).__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') - for i in range(len(func_path)-2, -1, -1): + for i in range(len(func_path)-1, -1, -1): try: resolved_obj = importlib.import_module('.'.join(func_path[:i])) break -- cgit v1.2.1 From 7a14c8ab45da8a681792a6331d48a88dd684a0a9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 26 Jan 2023 23:29:27 +0300 Subject: add an option to enable sections from extras tab in txt2img/img2img fix some style inconsistenices --- modules/processing.py | 7 +++++- modules/scripts.py | 32 ++++++++++++++++++++++---- modules/scripts_auto_postprocessing.py | 42 ++++++++++++++++++++++++++++++++++ modules/scripts_postprocessing.py | 11 ++++++--- modules/shared.py | 15 ++++-------- modules/shared_items.py | 10 ++++++++ modules/ui_components.py | 8 +++++++ 7 files changed, 107 insertions(+), 18 deletions(-) create mode 100644 modules/scripts_auto_postprocessing.py create mode 100644 modules/shared_items.py (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 92894d67..262806a1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -658,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image = Image.fromarray(x_sample) + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image(p, pp) + image = pp.image + if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) diff --git a/modules/scripts.py b/modules/scripts.py index 03907a63..6e9dc0c0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,12 +6,16 @@ from collections import namedtuple import gradio as gr -from modules.processing import StableDiffusionProcessing from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing AlwaysVisible = object() +class PostprocessImageArgs: + def __init__(self, image): + self.image = image + + class Script: filename = None args_from = None @@ -65,7 +69,7 @@ class Script: args contains all values returned by components from ui() """ - raise NotImplementedError() + pass def process(self, p, *args): """ @@ -100,6 +104,13 @@ class Script: pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): + """ + Called for every image after it has been generated. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -247,11 +258,15 @@ class ScriptRunner: self.infotext_fields = [] def initialize_scripts(self, is_img2img): + from modules import scripts_auto_postprocessing + self.scripts.clear() self.alwayson_scripts.clear() self.selectable_scripts.clear() - for script_class, path, basedir, script_module in scripts_data: + auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() + + for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data: script = script_class() script.filename = path script.is_txt2img = not is_img2img @@ -332,7 +347,7 @@ class ScriptRunner: return inputs - def run(self, p: StableDiffusionProcessing, *args): + def run(self, p, *args): script_index = args[0] if script_index == 0: @@ -386,6 +401,15 @@ class ScriptRunner: print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def postprocess_image(self, p, pp: PostprocessImageArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_image(p, pp, *script_args) + except Exception: + print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def before_component(self, component, **kwargs): for script in self.scripts: try: diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py new file mode 100644 index 00000000..30d6d658 --- /dev/null +++ b/modules/scripts_auto_postprocessing.py @@ -0,0 +1,42 @@ +from modules import scripts, scripts_postprocessing, shared + + +class ScriptPostprocessingForMainUI(scripts.Script): + def __init__(self, script_postproc): + self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc + self.postprocessing_controls = None + + def title(self): + return self.script.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + self.postprocessing_controls = self.script.ui() + return self.postprocessing_controls.values() + + def postprocess_image(self, p, script_pp, *args): + args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)} + + pp = scripts_postprocessing.PostprocessedImage(script_pp.image) + pp.info = {} + self.script.process(pp, **args_dict) + p.extra_generation_params.update(pp.info) + script_pp.image = pp.image + + +def create_auto_preprocessing_script_data(): + from modules import scripts + + res = [] + + for name in shared.opts.postprocessing_enable_in_main_ui: + script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None) + if script is None: + continue + + constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) + res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module)) + + return res diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 25de02d0..ce0ebb61 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -46,6 +46,8 @@ class ScriptPostprocessing: pass + + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: res = func(*args, **kwargs) @@ -68,6 +70,9 @@ class ScriptPostprocessingRunner: script: ScriptPostprocessing = script_class() script.filename = path + if script.name == "Simple Upscale": + continue + self.scripts.append(script) def create_script_ui(self, script, inputs): @@ -87,12 +92,11 @@ class ScriptPostprocessingRunner: import modules.scripts self.initialize_scripts(modules.scripts.postprocessing_scripts_data) - scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] + scripts_order = shared.opts.postprocessing_operation_order def script_score(name): - name = name.lower() for i, possible_match in enumerate(scripts_order): - if possible_match in name: + if possible_match == name: return i return len(self.scripts) @@ -145,3 +149,4 @@ class ScriptPostprocessingRunner: def image_changed(self): for script in self.scripts_in_preferred_order(): script.image_changed() + diff --git a/modules/shared.py b/modules/shared.py index 6a0b96cb..cdeed55d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,8 +13,8 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors, ui_components -from modules.paths import models_path, script_path, sd_path +from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items +from modules.paths import models_path, script_path demo = None @@ -264,12 +264,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] - -def realesrgan_models_names(): - import modules.realesrgan_model - return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] - - class OptionInfo: def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): self.default = default @@ -360,7 +354,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), })) @@ -483,7 +477,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" })) options_templates.update(options_section(('postprocessing', "Postprocessing"), { - 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"), + 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), })) diff --git a/modules/shared_items.py b/modules/shared_items.py new file mode 100644 index 00000000..b5d480c9 --- /dev/null +++ b/modules/shared_items.py @@ -0,0 +1,10 @@ + + +def realesrgan_models_names(): + import modules.realesrgan_model + return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + +def postprocessing_scripts(): + import modules.scripts + + return modules.scripts.scripts_postproc.scripts \ No newline at end of file diff --git a/modules/ui_components.py b/modules/ui_components.py index 9aec3097..284ca0cf 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" + +class DropdownMulti(gr.Dropdown): + """Same as gr.Dropdown but always multiselect""" + def __init__(self, **kwargs): + super().__init__(multiselect=True, **kwargs) + + def get_block_name(self): + return "dropdown" -- cgit v1.2.1 From d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:28:12 +0300 Subject: remove the need to place configs near models --- modules/api/api.py | 5 +- modules/devices.py | 12 ++- modules/sd_hijack_inpainting.py | 9 -- modules/sd_models.py | 228 ++++++++++++++++++++-------------------- modules/sd_models_config.py | 65 ++++++++++++ modules/shared.py | 7 +- modules/shared_items.py | 15 ++- modules/timer.py | 35 ++++++ 8 files changed, 242 insertions(+), 134 deletions(-) create mode 100644 modules/sd_models_config.py create mode 100644 modules/timer.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 25c65e57..eb7b1da5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, find_checkpoint_config +from modules.sd_models import checkpoints_list +from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List @@ -387,7 +388,7 @@ class Api: ] def get_sd_models(self): - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..2d5f797a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,14 +34,18 @@ def get_cuda_device_string(): return "cuda" -def get_optimal_device(): +def get_optimal_device_name(): if torch.cuda.is_available(): - return torch.device(get_cuda_device_string()) + return get_cuda_device_string() if has_mps(): - return torch.device("mps") + return "mps" + + return "cpu" - return cpu + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) def get_device_for(task): diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 31d2c898..478cd499 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0, e_t -def should_hijack_inpainting(checkpoint_info): - from modules import sd_models - - ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() - - return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename - - def do_inpainting_hijack(): # p_sample_plms is needed because PLMS can't work with dicts as conditionings diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072eb2e..fa208728 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,8 +2,6 @@ import collections import os.path import sys import gc -import time -from collections import namedtuple import torch import re import safetensors.torch @@ -14,10 +12,10 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path -from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting -from modules.sd_hijack_ip2p import should_hijack_ip2p +from modules.sd_hijack_inpainting import do_inpainting_hijack +from modules.timer import Timer model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -99,17 +97,6 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) -def find_checkpoint_config(info): - if info is None: - return shared.cmd_opts.config - - config = os.path.splitext(info.filename)[0] + ".yaml" - if os.path.exists(config): - return config - - return shared.cmd_opts.config - - def list_models(): checkpoints_list.clear() checkpoint_alisases.clear() @@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - device = map_location or shared.weight_load_location - if device is None: - device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) @@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo): +def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): + sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + + if checkpoint_info in checkpoints_loaded: + # use checkpoint cache + print(f"Loading weights [{sd_model_hash}] from cache") + return checkpoints_loaded[checkpoint_info] + + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + res = read_state_dict(checkpoint_info.filename) + timer.record("load weights from disk") + + return res + + +def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + if checkpoint_info.title != title: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title - cache_enabled = shared.opts.sd_checkpoint_cache > 0 + if state_dict is None: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if cache_enabled and checkpoint_info in checkpoints_loaded: - # use checkpoint cache - print(f"Loading weights [{sd_model_hash}] from cache") - model.load_state_dict(checkpoints_loaded[checkpoint_info]) - else: - # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + model.load_state_dict(state_dict, strict=False) + del state_dict + timer.record("apply weights to model") - sd = read_state_dict(checkpoint_info.filename) - model.load_state_dict(sd, strict=False) - del sd - - if cache_enabled: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + if shared.opts.sd_checkpoint_cache > 0: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) + timer.record("apply channels_last") - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + if not shared.cmd_opts.no_half: + vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) - if not shared.cmd_opts.no_half: - vae = model.first_stage_model - depth_model = getattr(model, 'depth_model', None) + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None - # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.cmd_opts.no_half_vae: - model.first_stage_model = None - # with --upcast-sampling, don't convert the depth model weights to float16 - if shared.cmd_opts.upcast_sampling and depth_model: - model.depth_model = None + model.half() + model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model - model.half() - model.first_stage_model = vae - if depth_model: - model.depth_model = depth_model + timer.record("apply half()") - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - devices.dtype_unet = model.model.diffusion_model.dtype - devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 - model.first_stage_model.to(devices.dtype_vae) + model.first_stage_model.to(devices.dtype_vae) + timer.record("apply dtype to VAE") # clean up cache if limit is reached - if cache_enabled: - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model - checkpoints_loaded.popitem(last=False) # LRU + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename @@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_vae.clear_loaded_vae() vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) sd_vae.load_vae(model, vae_file, vae_source) + timer.record("load VAE") def enable_midas_autodownload(): @@ -340,24 +340,20 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper -class Timer: - def __init__(self): - self.start = time.time() +def repair_config(sd_config): - def elapsed(self): - end = time.time() - res = end - self.start - self.start = end - return res + if not hasattr(sd_config.model.params, "use_ema"): + sd_config.model.params.use_ema = False + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True -def load_model(checkpoint_info=None): + +def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() - checkpoint_config = find_checkpoint_config(checkpoint_info) - - if checkpoint_config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -365,38 +361,27 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_config) - - if should_hijack_inpainting(checkpoint_info): - # Hardcoded config for now... - sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.unet_config.params.in_channels = 9 - sd_config.model.params.finetune_keys = None - - if should_hijack_ip2p(checkpoint_info): - sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.first_stage_key = "edited" - sd_config.model.params.cond_stage_key = "edit" - sd_config.model.params.image_size = 16 - sd_config.model.params.unet_config.params.in_channels = 8 - sd_config.model.params.unet_config.params.out_channels = 4 + do_inpainting_hijack() - if not hasattr(sd_config.model.params, "use_ema"): - sd_config.model.params.use_ema = False + timer = Timer() - do_inpainting_hijack() + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - timer = Timer() + timer.record("find config") - sd_model = None + sd_config = OmegaConf.load(checkpoint_config) + repair_config(sd_config) + + timer.record("load config") + + print(f"Creating model from config: {checkpoint_config}") + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) @@ -407,29 +392,35 @@ def load_model(checkpoint_info=None): print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) - elapsed_create = timer.elapsed() + sd_model.used_config = checkpoint_config - load_model_weights(sd_model, checkpoint_info) + timer.record("create model") - elapsed_load_weights = timer.elapsed() + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) else: sd_model.to(shared.device) + timer.record("move model to device") + sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + sd_model.eval() shared.sd_model = sd_model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + timer.record("load textual inversion embeddings") + script_callbacks.model_loaded_callback(sd_model) - elapsed_the_rest = timer.elapsed() + timer.record("scripts callbacks") - print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") + print(f"Model loaded in {timer.summary()}.") return sd_model @@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model + if sd_model is None: # previous model load failed current_checkpoint_info = None else: @@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - checkpoint_config = find_checkpoint_config(current_checkpoint_info) - - if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info): - del sd_model - checkpoints_loaded.clear() - load_model(checkpoint_info) - return shared.sd_model - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None): timer = Timer() + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + + timer.record("find config") + + if sd_model is None or checkpoint_config != sd_model.used_config: + del sd_model + checkpoints_loaded.clear() + load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"]) + return shared.sd_model + try: - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, state_dict, timer) except Exception as e: print("Failed to load checkpoint, restoring previous") - load_model_weights(sd_model, current_checkpoint_info) + load_model_weights(sd_model, current_checkpoint_info, None, timer) raise finally: sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) + timer.record("move model to device") - elapsed = timer.elapsed() - - print(f"Weights loaded in {elapsed:.1f}s.") + print(f"Weights loaded in {timer.summary()}.") return sd_model diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py new file mode 100644 index 00000000..ea773a10 --- /dev/null +++ b/modules/sd_models_config.py @@ -0,0 +1,65 @@ +import re +import os + +from modules import shared, paths + +sd_configs_path = shared.sd_configs_path +sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") + + +config_default = shared.sd_default_config +config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") +config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") +config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") +config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") + +re_parametrization_v = re.compile(r'-v\b') + + +def guess_model_config_from_state_dict(sd, filename): + fn = os.path.basename(filename) + + sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) + diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) + roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) + + if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: + if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: + return config_sd2v + else: + return config_sd2 + + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return config_inpainting + if diffusion_model_input.shape[1] == 8: + return config_instruct_pix2pix + + if roberta_weight is not None: + return config_alt_diffusion + + return config_default + + +def find_checkpoint_config(state_dict, info): + if info is None: + return guess_model_config_from_state_dict(state_dict, "") + + config = find_checkpoint_config_near_filename(info) + if config is not None: + return config + + return guess_model_config_from_state_dict(state_dict, info.filename) + + +def find_checkpoint_config_near_filename(info): + if info is None: + return None + + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return None + diff --git a/modules/shared.py b/modules/shared.py index cdeed55d..14be993d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,13 +13,14 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items +from modules import localization, extensions, script_loading, errors, ui_components, shared_items from modules.paths import models_path, script_path demo = None -sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") +sd_configs_path = os.path.join(script_path, "configs") +sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file @@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), diff --git a/modules/shared_items.py b/modules/shared_items.py index b5d480c9..8b5ec96d 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -4,7 +4,20 @@ def realesrgan_models_names(): import modules.realesrgan_model return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + def postprocessing_scripts(): import modules.scripts - return modules.scripts.scripts_postproc.scripts \ No newline at end of file + return modules.scripts.scripts_postproc.scripts + + +def sd_vae_items(): + import modules.sd_vae + + return ["Automatic", "None"] + list(modules.sd_vae.vae_dict) + + +def refresh_vae_list(): + import modules.sd_vae + + return modules.sd_vae.refresh_vae_list diff --git a/modules/timer.py b/modules/timer.py new file mode 100644 index 00000000..57a4f17a --- /dev/null +++ b/modules/timer.py @@ -0,0 +1,35 @@ +import time + + +class Timer: + def __init__(self): + self.start = time.time() + self.records = {} + self.total = 0 + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def record(self, category, extra_time=0): + e = self.elapsed() + if category not in self.records: + self.records[category] = 0 + + self.records[category] += e + extra_time + self.total += e + extra_time + + def summary(self): + res = f"{self.total:.1f}s" + + additions = [x for x in self.records.items() if x[1] >= 0.1] + if not additions: + return res + + res += " (" + res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) + res += ")" + + return res -- cgit v1.2.1 From 6f31d2210c189f8db118e6f95add7ba2a64f0238 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:54:19 +0300 Subject: support detecting midas model fix broken api for checkpoint list --- modules/api/models.py | 2 +- modules/sd_models.py | 10 +++++----- modules/sd_models_config.py | 7 +++++-- 3 files changed, 11 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index 805bd8f7..cba43d3b 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -228,7 +228,7 @@ class SDModelItem(BaseModel): hash: Optional[str] = Field(title="Short hash") sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") - config: str = Field(title="Config file") + config: Optional[str] = Field(title="Config file") class HypernetworkItem(BaseModel): name: str = Field(title="Name") diff --git a/modules/sd_models.py b/modules/sd_models.py index fa208728..37dad18d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -439,12 +439,12 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(sd_model) + sd_hijack.model_hijack.undo_hijack(sd_model) timer = Timer() diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index ea773a10..4d1e92e1 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") @@ -22,7 +23,9 @@ def guess_model_config_from_state_dict(sd, filename): sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) - roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) + + if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + return config_depth_model if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: @@ -36,7 +39,7 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix - if roberta_weight is not None: + if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: return config_alt_diffusion return config_default -- cgit v1.2.1 From 9beb794e0b0dc1a0f9e89d8e38bd789a8c608397 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 13:08:00 +0300 Subject: clarify the option to disable NaN check. --- modules/devices.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 2d5f797a..4687944e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -143,6 +143,8 @@ def test_for_nans(x, where): else: message = "A tensor with all NaNs was produced." + message += " Use --disable-nan-check commandline argument to disable this check." + raise NansException(message) -- cgit v1.2.1