aboutsummaryrefslogtreecommitdiff
path: root/modules/api/api.py
blob: 3f490ce26422f6cb53f2728aea0a4189532aea1f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
import modules.shared as shared
import uvicorn
from fastapi import APIRouter, HTTPException
import json
import io
import base64
from modules.api.models import *
from PIL import Image
from modules.extras import run_extras
from gradio import processing_utils

def upscaler_to_index(name: str):
    """Return the position of the upscaler called *name* in ``shared.sd_upscalers``.

    The comparison is case-insensitive.

    Raises:
        HTTPException: status 400 when no upscaler has that name, or when
            *name* is not a string (e.g. ``None``).
    """
    names = [x.name.lower() for x in shared.sd_upscalers]
    try:
        return names.index(name.lower())
    except (ValueError, AttributeError):
        # ValueError: list.index found no match; AttributeError: name has no .lower().
        raise HTTPException(status_code=400, detail="Upscaler not found") from None

def sampler_to_index(name: str):
    """Look up the sampler called *name* (case-insensitive) in ``all_samplers``.

    Returns:
        An ``(index, sampler)`` tuple for the first sampler whose name matches,
        or ``None`` when none does.
    """
    return next(
        ((i, sampler) for i, sampler in enumerate(all_samplers) if name.lower() == sampler.name.lower()),
        None,
    )

def img_to_base64(img: "Image.Image") -> bytes:
    """Encode *img* as PNG and return the base64 encoding of the PNG data.

    Args:
        img: the image to encode; it only needs a PIL-style
            ``save(fp, format=...)`` method.

    Returns:
        The base64-encoded PNG bytes (``bytes``, not ``str``).
    """
    with io.BytesIO() as buffer:
        img.save(buffer, format="png")
        return base64.b64encode(buffer.getvalue())

def base64_to_bytes(base64Img: str):
    """Decode a base64 string into an in-memory byte stream.

    If the string contains a comma (as in a ``data:image/png;base64,...`` URL),
    the header before the first comma is dropped and only the segment after it
    is decoded.

    Returns:
        An ``io.BytesIO`` positioned at the start of the decoded bytes.
    """
    payload = base64Img.split(",")[1] if "," in base64Img else base64Img
    decoded = base64.b64decode(payload)
    return io.BytesIO(decoded)

def base64_to_images(base64Imgs: list[str]):
    """Decode every base64 string in *base64Imgs* into a PIL image.

    Each entry may carry a data-URL prefix; ``base64_to_bytes`` strips it.
    The returned list preserves the input order.
    """
    return [Image.open(base64_to_bytes(encoded)) for encoded in base64Imgs]

# Response body returned by Api.img2imgapi. BaseModel, Field and Json are
# presumably supplied by `from modules.api.models import *` -- confirm.
class ImageToImageResponse(BaseModel):
    # Generated images; img2imgapi fills this with base64-encoded PNG data.
    # NOTE(review): title/description are singular although the field is a list.
    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
    # Request parameters; img2imgapi passes json.dumps(vars(request)).
    parameters: Json
    # Processing info; img2imgapi passes json.dumps(processed.info).
    info: Json


class Api:
    """HTTP endpoints for txt2img, img2img and the extras (upscaling) pipeline.

    Routes are registered directly on the ``app`` passed to the constructor.
    Every generation call runs while holding ``queue_lock``, so API requests are
    serialized with any other code that takes the same lock.
    """

    def __init__(self, app, queue_lock):
        # Nothing in this class adds routes to this router; launch() attaches it.
        self.router = APIRouter()
        self.app = app
        self.queue_lock = queue_lock
        self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
        # NOTE(review): unlike txt2img, img2img declares no response_model.
        self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
        self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
        self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)

    def __base64_to_image(self, base64_string):
        # Decode a base64 string (an optional "data:...;base64," prefix is
        # dropped) into a PIL image, reusing the module-level helper.
        return Image.open(base64_to_bytes(base64_string))

    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
        """Generate images from a text prompt.

        Returns the generated images as base64-encoded PNGs, together with the
        request parameters and generation info. Responds 404 if the requested
        sampler does not exist.
        """
        sampler_match = sampler_to_index(txt2imgreq.sampler_index)
        if sampler_match is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        # Override __init__ params: the request names the sampler, but the
        # processing object takes its index. Sample/grid saving flags are turned
        # off; the images are returned in the response instead.
        populate = txt2imgreq.copy(update={
            "sd_model": shared.sd_model,
            "sampler_index": sampler_match[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
        })
        p = StableDiffusionProcessingTxt2Img(**vars(populate))

        with self.queue_lock:
            processed = process_images(p)

        b64images = list(map(img_to_base64, processed.images))

        return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))

    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
        """Generate images from base64-encoded init image(s), a prompt and an optional mask.

        Responds 404 if the requested sampler does not exist or if no init
        image was supplied.
        """
        sampler_match = sampler_to_index(img2imgreq.sampler_index)
        if sampler_match is None:
            raise HTTPException(status_code=404, detail="Sampler not found")

        init_images = img2imgreq.init_images
        # An empty list is as unusable as a missing one; reject both here rather
        # than handing an empty image list to processing.
        if not init_images:
            raise HTTPException(status_code=404, detail="Init image not found")

        mask = img2imgreq.mask
        if mask:
            mask = self.__base64_to_image(mask)

        # Override __init__ params (see text2imgapi); the mask goes in decoded.
        populate = img2imgreq.copy(update={
            "sd_model": shared.sd_model,
            "sampler_index": sampler_match[0],
            "do_not_save_samples": True,
            "do_not_save_grid": True,
            "mask": mask,
        })
        p = StableDiffusionProcessingImg2Img(**vars(populate))

        # Every init image is decoded, so malformed data still raises. Only the
        # last one is used, repeated batch_size times.
        # NOTE(review): earlier init images are silently ignored -- confirm
        # whether multiple init images were meant to be supported.
        decoded = [self.__base64_to_image(img) for img in init_images]
        p.init_images = [decoded[-1]] * p.batch_size

        with self.queue_lock:
            processed = process_images(p)

        b64images = list(map(img_to_base64, processed.images))

        return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))

    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
        """Run the extras pipeline on a single base64-encoded image.

        Responds 400 if either upscaler name is unknown.
        """
        upscaler1Index = upscaler_to_index(req.upscaler_1)
        upscaler2Index = upscaler_to_index(req.upscaler_2)

        # vars() returns the request object's own __dict__; copy it so the pops
        # and overwrites below do not mutate the request model.
        reqDict = dict(vars(req))
        reqDict.pop('upscaler_1')
        reqDict.pop('upscaler_2')

        reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image'])

        with self.queue_lock:
            result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="")

        # NOTE(review): the batch endpoint iterates result[0], but here result[0]
        # is encoded directly -- confirm what run_extras returns in mode 0.
        return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2])

    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
        """Run the extras pipeline on a list of base64-encoded images.

        Responds 400 if either upscaler name is unknown.
        """
        upscaler1Index = upscaler_to_index(req.upscaler_1)
        upscaler2Index = upscaler_to_index(req.upscaler_2)

        # Work on a copy so the request model itself is left intact.
        reqDict = dict(vars(req))
        reqDict.pop('upscaler_1')
        reqDict.pop('upscaler_2')

        # run_extras receives the decoded images via its image_folder argument.
        reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict.pop('imageList')))

        with self.queue_lock:
            result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="")

        return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2])

    def extras_folder_processing_api(self):
        """Placeholder for folder-based extras processing; not implemented."""
        raise NotImplementedError

    def pnginfoapi(self):
        """Placeholder for PNG info extraction; not implemented."""
        raise NotImplementedError

    def launch(self, server_name, port):
        """Attach ``self.router`` to the app and serve it with uvicorn on *server_name*:*port*."""
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port)