aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/LDSR
diff options
context:
space:
mode:
authorwywywywy <wywywywy@gmail.com>2022-12-10 18:57:18 +0000
committerwywywywy <wywywywy@gmail.com>2022-12-10 18:57:18 +0000
commit8bcdd50461090a2dd238082b33f4c1423378ebbd (patch)
tree0c9f0f3ea522bb8c6914bc0af77ca570163481b2 /extensions-builtin/LDSR
parent685f9631b56ff8bd43bce24ff5ce0f9a0e9af490 (diff)
Add safetensors support to LDSR
Diffstat (limited to 'extensions-builtin/LDSR')
-rw-r--r--extensions-builtin/LDSR/ldsr_model_arch.py10
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py8
2 files changed, 14 insertions, 4 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 8b048ae0..f5bd8ae4 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -1,3 +1,4 @@
+import os
import gc
import time
import warnings
@@ -8,6 +9,7 @@ import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
+import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
@@ -28,8 +30,12 @@ class LDSR:
model: torch.nn.Module = cached_ldsr_model
else:
print(f"Loading model from {self.modelPath}")
- pl_sd = torch.load(self.modelPath, map_location="cpu")
- sd = pl_sd["state_dict"]
+ _, extension = os.path.splitext(self.modelPath)
+ if extension.lower() == ".safetensors":
+ pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
+ else:
+ pl_sd = torch.load(self.modelPath, map_location="cpu")
+ sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath)
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model)
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index 29d5f94e..b8cff29b 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -25,6 +25,7 @@ class UpscalerLDSR(Upscaler):
yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt")
+ safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760:
@@ -33,8 +34,11 @@ class UpscalerLDSR(Upscaler):
if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path)
- model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
- file_name="model.ckpt", progress=True)
+ if os.path.exists(safetensors_model_path):
+ model = safetensors_model_path
+ else:
+ model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
file_name="project.yaml", progress=True)