aboutsummaryrefslogtreecommitdiff
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-08 15:49:47 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-08 15:49:47 +0300
commitad02b249f5bf8e494c35a313f44515b7b1e6739d (patch)
tree118bf2514a0de97eaadfe2281c3eb5cdf487503c /modules/esrgan_model.py
parent62ce77e24568113f9a19836bf90741dba4166db5 (diff)
add a helpful message when user puts RealESRGAN model into ESRGAN directory.
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 2ed1d273..e86ad775 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -14,17 +14,20 @@ import modules.images
def load_model(filename):
# this code is adapted from https://github.com/xinntao/ESRGAN
- if torch.has_mps:
- map_l = 'cpu'
- else:
- map_l = None
- pretrained_net = torch.load(filename, map_location=map_l)
+ pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
if 'conv_first.weight' in pretrained_net:
crt_model.load_state_dict(pretrained_net)
return crt_model
+ if 'model.0.weight' not in pretrained_net:
+ is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
+ if is_realesrgan:
+ raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
+ else:
+ raise Exception("The file is not a ESRGAN model.")
+
crt_net = crt_model.state_dict()
load_net_clean = {}
for k, v in pretrained_net.items():