From 6df316c881b533731faa77494ea01533e35f8dc7 Mon Sep 17 00:00:00 2001 From: wywywywy Date: Sat, 10 Dec 2022 13:54:29 +0000 Subject: LDSR cache / optimization / opt_channelslast --- extensions-builtin/LDSR/ldsr_model_arch.py | 40 +++++++++++++++++++-------- extensions-builtin/LDSR/scripts/ldsr_model.py | 1 + 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index a87d1ef9..9ec4e67e 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -11,25 +11,41 @@ from omegaconf import OmegaConf from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap +from modules import shared, sd_hijack warnings.filterwarnings("ignore", category=UserWarning) +cached_ldsr_model: torch.nn.Module = None + # Create LDSR Class class LDSR: def load_model_from_config(self, half_attention): - print(f"Loading model from {self.modelPath}") - pl_sd = torch.load(self.modelPath, map_location="cpu") - sd = pl_sd["state_dict"] - config = OmegaConf.load(self.yamlPath) - config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" - model = instantiate_from_config(config.model) - model.load_state_dict(sd, strict=False) - model.cuda() - if half_attention: - model = model.half() - - model.eval() + global cached_ldsr_model + + if shared.opts.ldsr_cached and cached_ldsr_model is not None: + print(f"Loading model from cache") + model: torch.nn.Module = cached_ldsr_model + else: + print(f"Loading model from {self.modelPath}") + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] + config = OmegaConf.load(self.yamlPath) + config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" + model: torch.nn.Module = instantiate_from_config(config.model) + model.load_state_dict(sd, strict=False) + model = model.to(shared.device) + if half_attention: + model = model.half() + if shared.cmd_opts.opt_channelslast: + model = model.to(memory_format=torch.channels_last) + + sd_hijack.model_hijack.hijack(model) # apply optimization + model.eval() + + if shared.opts.ldsr_cached: + cached_ldsr_model = model + return {"model": model} def __init__(self, model_path, yaml_path): diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index 5c96037d..29d5f94e 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -59,6 +59,7 @@ def on_ui_settings(): import gradio as gr shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) + shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling"))) script_callbacks.on_ui_settings(on_ui_settings) -- cgit v1.2.1 From 1581d5a1674fbbeaf047b79f3a138781d6322e6e Mon Sep 17 00:00:00 2001 From: wywywywy Date: Sat, 10 Dec 2022 14:07:27 +0000 Subject: Made device agnostic --- extensions-builtin/LDSR/ldsr_model_arch.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 9ec4e67e..8b048ae0 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -110,7 +110,8 @@ class LDSR: down_sample_method = 'Lanczos' gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available: + torch.cuda.empty_cache() im_og = image width_og, height_og = im_og.size @@ -147,7 +148,9 @@ class LDSR: del model gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available: + torch.cuda.empty_cache() + return a @@ -162,7 +165,7 @@ def get_cond(selected_path): c = rearrange(c, '1 c h w -> 1 h w c') c = 2. * c - 1. - c = c.to(torch.device("cuda")) + c = c.to(shared.device) example["LR_image"] = c example["image"] = c_up -- cgit v1.2.1