aboutsummaryrefslogtreecommitdiff
path: root/modules/realesrgan_model.py
blob: 458bf6780aa94052957b78c0cc820a97dae10a25 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import sys
import traceback
from collections import namedtuple

import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer

import modules.images
from modules.paths import models_path
from modules.shared import cmd_opts, opts

# Sub-folder of the shared models directory holding Real-ESRGAN weights;
# setup_model() may redirect model_path to a different directory.
model_dir = "RealESRGAN"
model_path = os.path.join(models_path, model_dir)
cmd_dir = None  # not referenced in this module's visible code; possibly used elsewhere

# Catalogue entry: display name, weights URL, zero-argument network factory,
# and the network's native upscale factor.
RealesrganModelInfo = namedtuple("RealesrganModelInfo", "name location model netscale")

# Filled in by setup_model().
realesrgan_models = []
have_realesrgan = False


def get_realesrgan_models():
    """Build the catalogue of downloadable Real-ESRGAN models.

    Each entry is a ``RealesrganModelInfo`` whose ``model`` field is a
    zero-argument factory constructing the network architecture that matches the
    checkpoint at ``location`` (the weights themselves are loaded separately),
    and whose ``netscale`` is that architecture's upscale factor.

    Returns:
        A list of ``RealesrganModelInfo`` on success, or ``None`` when the
        network architectures cannot be imported; in that case the error and
        traceback are printed to stderr.
    """
    try:
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
    except Exception:
        # Best-effort: report the failure and let the caller decide how to react.
        print("Error making Real-ESRGAN models list:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return None

    # NOTE: display names are matched against opts.realesrgan_enabled_models by
    # setup_model(), so they must stay stable (e.g. "x4x3" is kept even though the
    # checkpoint file is named "x4v3").
    return [
        RealesrganModelInfo(
            name="Real-ESRGAN General x4x3",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
            netscale=4,
            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN General WDN x4x3",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
            netscale=4,
            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN AnimeVideo",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
            netscale=4,
            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 4x plus",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            netscale=4,
            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 4x plus anime 6B",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
            netscale=4,
            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 2x plus",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            netscale=2,
            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        ),
    ]


class UpscalerRealESRGAN(modules.images.Upscaler):
    """Upscaler adapter that runs images through one Real-ESRGAN catalogue entry.

    The entry is identified by ``model_index`` into the module-level
    ``realesrgan_models`` list; its display name is copied onto ``self.name``
    when the instance is created.
    """

    def __init__(self, upscaling, model_index):
        # upscaling is forwarded as the output scale on every do_upscale() call.
        self.upscaling = upscaling
        self.model_index = model_index
        chosen = realesrgan_models[model_index]
        self.name = chosen.name

    def do_upscale(self, img):
        """Return ``img`` upscaled with this instance's model and scale factor."""
        return upscale_with_realesrgan(
            img,
            RealESRGAN_upscaling=self.upscaling,
            RealESRGAN_model_index=self.model_index,
        )


def setup_model(dirname):
    """Point this module at ``dirname`` and register the enabled Real-ESRGAN upscalers.

    Sets the module-level ``model_path`` to ``dirname`` (creating the directory
    if it does not exist), fills ``realesrgan_models`` from
    ``get_realesrgan_models()``, and appends an ``UpscalerRealESRGAN`` to
    ``modules.shared.sd_upscalers`` for every model whose name appears in
    ``opts.realesrgan_enabled_models``.

    On any failure the error is printed to stderr, ``have_realesrgan`` is set to
    False and ``realesrgan_models`` is replaced by a single placeholder entry.
    """
    global model_path, realesrgan_models, have_realesrgan

    # Adopt the requested directory *before* creating it, so the directory that
    # gets created is the one upscale_with_realesrgan() downloads into.
    model_path = dirname
    if model_path:
        os.makedirs(model_path, exist_ok=True)

    try:
        # Availability probe: these names are not used here, but an import
        # failure must land in the except branch below.
        from basicsr.archs.rrdbnet_arch import RRDBNet  # noqa: F401
        from realesrgan import RealESRGANer  # noqa: F401
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact  # noqa: F401

        models = get_realesrgan_models()
        if models is None:
            # get_realesrgan_models() prints its own error and returns None on
            # failure; fail explicitly instead of via a TypeError in enumerate().
            raise RuntimeError("Real-ESRGAN model list is unavailable")

        realesrgan_models = models
        have_realesrgan = True

        for i, model in enumerate(realesrgan_models):
            if model.name in opts.realesrgan_enabled_models:
                modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))

    except Exception:
        print("Error importing Real-ESRGAN:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

        # Keyword arguments: the former positional form ('None', '', 0, None)
        # stored 0 in `model` and None in `netscale`, swapping the two fields.
        realesrgan_models = [RealesrganModelInfo(name='None', location='', model=None, netscale=0)]
        have_realesrgan = False


def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
    """Upscale a PIL image with one of the catalogued Real-ESRGAN models.

    Args:
        image: PIL image to upscale.
        RealESRGAN_upscaling: output scale factor, passed to
            ``RealESRGANer.enhance`` as ``outscale``.
        RealESRGAN_model_index: index into the module-level ``realesrgan_models``.

    Returns:
        The upscaled PIL image, or ``image`` unchanged when Real-ESRGAN is
        unavailable or the model weights file could not be obtained.
    """
    if not have_realesrgan:
        return image

    info = realesrgan_models[RealESRGAN_model_index]

    # Fetch the weights into this module's model directory (presumably reusing an
    # already-downloaded file; see basicsr's load_file_from_url).
    model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
    if not os.path.exists(model_file):
        print("Unable to load RealESRGAN model: %s" % info.name)
        return image

    upsampler = RealESRGANer(
        scale=info.netscale,
        # Pass the local file, not the URL: given a URL, RealESRGANer downloads
        # its own second copy into the realesrgan package folder instead of
        # using the file fetched above into model_path.
        model_path=model_file,
        # Build the network only once the weights are known to exist.
        model=info.model(),
        half=not cmd_opts.no_half,
        tile=opts.ESRGAN_tile,
        tile_pad=opts.ESRGAN_tile_overlap,
    )

    upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
    return Image.fromarray(upsampled)