aboutsummaryrefslogtreecommitdiff
path: root/modules/codeformer_model.py
diff options
context:
space:
mode:
authord8ahazard <d8ahazard@gmail.com>2022-09-26 09:29:50 -0500
committerd8ahazard <d8ahazard@gmail.com>2022-09-26 09:29:50 -0500
commit740070ea9cdb254209f66417418f2a4af8b099d6 (patch)
tree52896a6159b706024af9520c855c10091162372c /modules/codeformer_model.py
parentbfb7f15d46048f27338eeac3a591a5943d03c5f1 (diff)
Re-implement universal model loading
Diffstat (limited to 'modules/codeformer_model.py')
-rw-r--r--modules/codeformer_model.py35
1 files changed, 25 insertions, 10 deletions
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 8fbdea24..dc0a5eee 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -5,22 +5,28 @@ import traceback
import cv2
import torch
-from modules import shared, devices
-from modules.paths import script_path
+from modules import shared, devices, modelloader
+from modules.paths import script_path, models_path
import modules.shared
import modules.face_restoration
from importlib import reload
-# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
-# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
+# codeformer people made a choice to include modified basicsr library to their project, which makes
+# it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
-
-pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_dir = "Codeformer"
+model_path = os.path.join(models_path, model_dir)
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
codeformer = None
-def setup_codeformer():
+
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+
path = modules.paths.paths.get("CodeFormer", None)
if path is None:
return
@@ -44,16 +50,22 @@ def setup_codeformer():
def name(self):
return "CodeFormer"
- def __init__(self):
+ def __init__(self, dirname):
self.net = None
self.face_helper = None
+ self.cmd_dir = dirname
def create_models(self):
if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
-
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir)
+ if len(model_paths) != 0:
+ ckpt_path = model_paths[0]
+ else:
+ print("Unable to load codeformer model.")
+ return None, None
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
checkpoint = torch.load(ckpt_path)['params_ema']
@@ -74,6 +86,9 @@ def setup_codeformer():
original_resolution = np_image.shape[0:2]
self.create_models()
+ if self.net is None or self.face_helper is None:
+ return np_image
+
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -114,7 +129,7 @@ def setup_codeformer():
have_codeformer = True
global codeformer
- codeformer = FaceRestorerCodeFormer()
+ codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception: