aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-23 10:26:00 +0300
committerAUTOMATIC <16777216c@gmail.com>2022-09-23 10:26:00 +0300
commitd4205e66fa50e6bed4ace11bc2236e834b7c560f (patch)
treee7f272f117b0b07a653d5c6be4f79d8d0e60030b /modules
parentd6fd71f36f33763f3a8d1d98f815e1e6a979e13e (diff)
gfpgan: just download the damn model
Diffstat (limited to 'modules')
-rw-r--r--modules/gfpgan_model.py19
-rw-r--r--modules/shared.py3
2 files changed, 14 insertions, 8 deletions
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 0af97123..44c5dc6c 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,6 +1,7 @@
import os
import sys
import traceback
+from glob import glob
from modules import shared, devices
from modules.shared import cmd_opts
@@ -11,14 +12,20 @@ import modules.face_restoration
def gfpgan_model_path():
from modules.shared import cmd_opts
+ filemask = 'GFPGAN*.pth'
+
+ if cmd_opts.gfpgan_model is not None:
+ return cmd_opts.gfpgan_model
+
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
- files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
- found = [x for x in files if os.path.exists(x)]
- if len(found) == 0:
- raise Exception("GFPGAN model not found in paths: " + ", ".join(files))
+ filename = None
+ for place in places:
+ filename = next(iter(glob(os.path.join(place, filemask))), None)
+ if filename is not None:
+ break
- return found[0]
+ return filename
loaded_gfpgan_model = None
@@ -34,7 +41,7 @@ def gfpgan():
if gfpgan_constructor is None:
return None
- model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+ model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
diff --git a/modules/shared.py b/modules/shared.py
index f30bef02..39681ed0 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -2,7 +2,6 @@ import sys
import argparse
import json
import os
-from glob import glob
import gradio as gr
import tqdm
@@ -22,7 +21,7 @@ parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
-parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=next(iter(glob('GFPGAN*.pth')), ''))
+parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")