From 1b4d04737ac513cbd55958bb60a4f85166f3484b Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 20:13:16 -0300 Subject: Remove unused imports --- modules/api/api.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5b0c934e..a5136b4b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,11 +1,9 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images from modules.sd_samplers import all_samplers -from modules.extras import run_pnginfo import modules.shared as shared import uvicorn -from fastapi import Body, APIRouter, HTTPException -from fastapi.responses import JSONResponse +from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field, Json import json import io @@ -18,7 +16,6 @@ class TextToImageResponse(BaseModel): parameters: Json info: Json - class Api: def __init__(self, app, queue_lock): self.router = APIRouter() -- cgit v1.2.1 From b02926df1393df311db734af149fb9faf4389cbe Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 20:24:04 -0300 Subject: Moved moodels to their own file and extracted base64 conversion to its own function --- modules/api/api.py | 17 ++++++----------- modules/api/models.py | 8 ++++++++ 2 files changed, 14 insertions(+), 11 deletions(-) create mode 100644 modules/api/models.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index a5136b4b..c17d7580 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,17 +4,17 @@ from modules.sd_samplers import all_samplers import modules.shared as shared import uvicorn from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field, Json import json import io import base64 +from modules.api.models import * sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -class TextToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json +def img_to_base64(img): + buffer = io.BytesIO() + img.save(buffer, format="png") + return base64.b64encode(buffer.getvalue()) class Api: def __init__(self, app, queue_lock): @@ -41,15 +41,10 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = [] - for i in processed.images: - buffer = io.BytesIO() - i.save(buffer, format="png") - b64images.append(base64.b64encode(buffer.getvalue())) + b64images = list(map(img_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - def img2imgapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py new file mode 100644 index 00000000..a7d247d8 --- /dev/null +++ b/modules/api/models.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, Field, Json + +class TextToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + + \ No newline at end of file -- cgit v1.2.1 From 28e26c2bef217ae82eb9e980cceb3f67ef22e109 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 22 Oct 2022 23:13:32 -0300 Subject: Add "extra" single image operation - Separate extra modes into 3 endpoints so the user ddoesn't ahve to handle so many unused parameters. - Add response model for codumentation --- modules/api/api.py | 43 ++++++++++++++++++++++++++++++++++++++----- modules/api/models.py | 26 +++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index c17d7580..3b804373 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -8,20 +8,42 @@ import json import io import base64 from modules.api.models import * +from PIL import Image +from modules.extras import run_extras + +def upscaler_to_index(name: str): + try: + return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) + except: + raise HTTPException(status_code=400, detail="Upscaler not found") sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -def img_to_base64(img): +def img_to_base64(img: str): buffer = io.BytesIO() img.save(buffer, format="png") return base64.b64encode(buffer.getvalue()) +def base64_to_bytes(base64Img: str): + if "," in base64Img: + base64Img = base64Img.split(",")[1] + return io.BytesIO(base64.b64decode(base64Img)) + +def base64_to_images(base64Imgs: list[str]): + imgs = [] + for img in base64Imgs: + img = Image.open(base64_to_bytes(img)) + imgs.append(img) + return imgs + + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() self.app = app self.queue_lock = queue_lock - self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) + self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -45,12 +67,23 @@ class Api: return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - def img2imgapi(self): raise NotImplementedError - def extrasapi(self): - raise NotImplementedError + def extras_single_image_api(self, req: ExtrasSingleImageRequest): + upscaler1Index = upscaler_to_index(req.upscaler_1) + upscaler2Index = upscaler_to_index(req.upscaler_2) + + reqDict = vars(req) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + + reqDict['image'] = base64_to_images([reqDict['image']])[0] + + with self.queue_lock: + result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") + + return ExtrasSingleImageResponse(image="data:image/png;base64,"+img_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index a7d247d8..dcf1ab54 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,8 +1,32 @@ from pydantic import BaseModel, Field, Json +from typing_extensions import Literal +from modules.shared import sd_upscalers class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json info: Json - \ No newline at end of file +class ExtrasBaseRequest(BaseModel): + resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.") + show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?") + gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") + codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") + codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") + upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") + upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?") + upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") + upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") + extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") + +class ExtraBaseResponse(BaseModel): + html_info_x: str + html_info: str + +class ExtrasSingleImageRequest(ExtrasBaseRequest): + image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") + +class ExtrasSingleImageResponse(ExtraBaseResponse): + image: str = Field(default=None, title="Image", description="The generated image in base64 format.") \ No newline at end of file -- cgit v1.2.1 From 0523704dade0508bf3ae0c8cb799b1ae332d449b Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 12:27:50 -0300 Subject: Update run_extras to use the temp filename In batch mode run_extras tries to preserve the original file name of the images. The problem is that this makes no sense since the user only gets a list of images in the UI, trying to manually save them shows that this images have random temp names. Also, trying to keep "orig_name" in the API is a hassle that adds complexity to the consuming UI since the client has to use (or emulate) an input (type=file) element in a form. Using the normal file name not only doesn't change the output and functionality in the original UI but also helps keep the API simple. --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 22c5a1c1..29ac312e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -33,7 +33,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ for img in image_folder: image = Image.open(img) imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) + imageNameArr.append(os.path.splitext(img.name)[0]) elif extras_mode == 2: assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' -- cgit v1.2.1 From 4ff852ffb50859f2eae75375cab94dd790a46886 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 13:07:59 -0300 Subject: Add batch processing "extras" endpoint --- modules/api/api.py | 25 +++++++++++++++++++++++-- modules/api/models.py | 15 ++++++++++++++- 2 files changed, 37 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3b804373..528134a8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -10,6 +10,7 @@ import base64 from modules.api.models import * from PIL import Image from modules.extras import run_extras +from gradio import processing_utils def upscaler_to_index(name: str): try: @@ -44,6 +45,7 @@ class Api: self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) + self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -78,12 +80,31 @@ class Api: reqDict.pop('upscaler_1') reqDict.pop('upscaler_2') - reqDict['image'] = base64_to_images([reqDict['image']])[0] + reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image']) with self.queue_lock: result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") - return ExtrasSingleImageResponse(image="data:image/png;base64,"+img_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + + def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): + upscaler1Index = upscaler_to_index(req.upscaler_1) + upscaler2Index = upscaler_to_index(req.upscaler_2) + + reqDict = vars(req) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + + reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict['imageList'])) + reqDict.pop('imageList') + + with self.queue_lock: + result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="") + + return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + + def extras_folder_processing_api(self): + raise NotImplementedError def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index dcf1ab54..bbd0ef53 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -29,4 +29,17 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") class ExtrasSingleImageResponse(ExtraBaseResponse): - image: str = Field(default=None, title="Image", description="The generated image in base64 format.") \ No newline at end of file + image: str = Field(default=None, title="Image", description="The generated image in base64 format.") + +class SerializableImage(BaseModel): + path: str = Field(title="Path", description="The image's path ()") + +class ImageItem(BaseModel): + data: str = Field(title="image data") + name: str = Field(title="filename") + +class ExtrasBatchImagesRequest(ExtrasBaseRequest): + imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + +class ExtrasBatchImagesResponse(ExtraBaseResponse): + images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file -- cgit v1.2.1 From e0ca4dfbc10e0af8dfc4185e5e758f33fd2f0d81 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 15:13:37 -0300 Subject: Update endpoints to use gradio's own utils functions --- modules/api/api.py | 75 +++++++++++++++++++++++++-------------------------- modules/api/models.py | 4 +-- 2 files changed, 38 insertions(+), 41 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3f490ce2..3acb1f36 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -20,27 +20,27 @@ def upscaler_to_index(name: str): sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -def img_to_base64(img: str): - buffer = io.BytesIO() - img.save(buffer, format="png") - return base64.b64encode(buffer.getvalue()) - -def base64_to_bytes(base64Img: str): - if "," in base64Img: - base64Img = base64Img.split(",")[1] - return io.BytesIO(base64.b64decode(base64Img)) - -def base64_to_images(base64Imgs: list[str]): - imgs = [] - for img in base64Imgs: - img = Image.open(base64_to_bytes(img)) - imgs.append(img) - return imgs +# def img_to_base64(img: str): +# buffer = io.BytesIO() +# img.save(buffer, format="png") +# return base64.b64encode(buffer.getvalue()) + +# def base64_to_bytes(base64Img: str): +# if "," in base64Img: +# base64Img = base64Img.split(",")[1] +# return io.BytesIO(base64.b64decode(base64Img)) + +# def base64_to_images(base64Imgs: list[str]): +# imgs = [] +# for img in base64Imgs: +# img = Image.open(base64_to_bytes(img)) +# imgs.append(img) +# return imgs class ImageToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json + parameters: dict + info: str class Api: @@ -49,17 +49,17 @@ class Api: self.app = app self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) - self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - def __base64_to_image(self, base64_string): - # if has a comma, deal with prefix - if "," in base64_string: - base64_string = base64_string.split(",")[1] - imgdata = base64.b64decode(base64_string) - # convert base64 to PIL image - return Image.open(io.BytesIO(imgdata)) + # def __base64_to_image(self, base64_string): + # # if has a comma, deal with prefix + # if "," in base64_string: + # base64_string = base64_string.split(",")[1] + # imgdata = base64.b64decode(base64_string) + # # convert base64 to PIL image + # return Image.open(io.BytesIO(imgdata)) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -79,11 +79,9 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(img_to_base64, processed.images)) - - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) - + b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): sampler_index = sampler_to_index(img2imgreq.sampler_index) @@ -98,7 +96,7 @@ class Api: mask = img2imgreq.mask if mask: - mask = self.__base64_to_image(mask) + mask = processing_utils.decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params @@ -113,7 +111,7 @@ class Api: imgs = [] for img in init_images: - img = self.__base64_to_image(img) + img = processing_utils.decode_base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs @@ -121,13 +119,12 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = [] - for i in processed.images: - buffer = io.BytesIO() - i.save(buffer, format="png") - b64images.append(base64.b64encode(buffer.getvalue())) - - return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info)) + b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + # for i in processed.images: + # buffer = io.BytesIO() + # i.save(buffer, format="png") + # b64images.append(base64.b64encode(buffer.getvalue())) + return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): upscaler1Index = upscaler_to_index(req.upscaler_1) diff --git a/modules/api/models.py b/modules/api/models.py index bbd0ef53..209f8af5 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -4,8 +4,8 @@ from modules.shared import sd_upscalers class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: Json - info: Json + parameters: str + info: str class ExtrasBaseRequest(BaseModel): resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.") -- cgit v1.2.1 From 866b36d705a338d299aba385788729d60f7d48c8 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 15:35:49 -0300 Subject: Move processing's models into models.py It didn't make sense to have two differente files for the same and "models" is a more descriptive name. --- modules/api/api.py | 57 ++++------------------- modules/api/models.py | 112 +++++++++++++++++++++++++++++++++++++++++++++- modules/api/processing.py | 106 ------------------------------------------- 3 files changed, 119 insertions(+), 156 deletions(-) delete mode 100644 modules/api/processing.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3acb1f36..20e85e82 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,16 +1,11 @@ -from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI -from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers -import modules.shared as shared import uvicorn +from gradio import processing_utils from fastapi import APIRouter, HTTPException -import json -import io -import base64 +import modules.shared as shared from modules.api.models import * -from PIL import Image +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.sd_samplers import all_samplers from modules.extras import run_extras -from gradio import processing_utils def upscaler_to_index(name: str): try: @@ -20,29 +15,6 @@ def upscaler_to_index(name: str): sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) -# def img_to_base64(img: str): -# buffer = io.BytesIO() -# img.save(buffer, format="png") -# return base64.b64encode(buffer.getvalue()) - -# def base64_to_bytes(base64Img: str): -# if "," in base64Img: -# base64Img = base64Img.split(",")[1] -# return io.BytesIO(base64.b64decode(base64Img)) - -# def base64_to_images(base64Imgs: list[str]): -# imgs = [] -# for img in base64Imgs: -# img = Image.open(base64_to_bytes(img)) -# imgs.append(img) -# return imgs - -class ImageToImageResponse(BaseModel): - images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: dict - info: str - - class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -51,15 +23,7 @@ class Api: self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) - self.app.add_api_route("/sdapi/v1/extra-batch-image", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - - # def __base64_to_image(self, base64_string): - # # if has a comma, deal with prefix - # if "," in base64_string: - # base64_string = base64_string.split(",")[1] - # imgdata = base64.b64decode(base64_string) - # # convert base64 to PIL image - # return Image.open(io.BytesIO(imgdata)) + self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -81,7 +45,7 @@ class Api: b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.info) + return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): sampler_index = sampler_to_index(img2imgreq.sampler_index) @@ -120,10 +84,7 @@ class Api: processed = process_images(p) b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) - # for i in processed.images: - # buffer = io.BytesIO() - # i.save(buffer, format="png") - # b64images.append(base64.b64encode(buffer.getvalue())) + return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): @@ -134,12 +95,12 @@ class Api: reqDict.pop('upscaler_1') reqDict.pop('upscaler_2') - reqDict['image'] = processing_utils.decode_base64_to_file(reqDict['image']) + reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image']) with self.queue_lock: result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") - return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): upscaler1Index = upscaler_to_index(req.upscaler_1) diff --git a/modules/api/models.py b/modules/api/models.py index 209f8af5..362e6277 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,10 +1,118 @@ -from pydantic import BaseModel, Field, Json +import inspect +from pydantic import BaseModel, Field, Json, create_model +from typing import Any, Optional from typing_extensions import Literal +from inflection import underscore +from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.shared import sd_upscalers +API_NOT_ALLOWED = [ + "self", + "kwargs", + "sd_model", + "outpath_samples", + "outpath_grids", + "sampler_index", + "do_not_save_samples", + "do_not_save_grid", + "extra_generation_params", + "overlay_images", + "do_not_reload_embeddings", + "seed_enable_extras", + "prompt_for_display", + "sampler_noise_scheduler_override", + "ddim_discretize" +] + +class ModelDef(BaseModel): + """Assistance Class for Pydantic Dynamic Model Generation""" + + field: str + field_alias: str + field_type: Any + field_value: Any + + +class PydanticModelGenerator: + """ + Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: + source_data is a snapshot of the default values produced by the class + params are the names of the actual keys required by __init__ + """ + + def __init__( + self, + model_name: str = None, + class_instance = None, + additional_fields = None, + ): + def field_type_generator(k, v): + # field_type = str if not overrides.get(k) else overrides[k]["type"] + # print(k, v.annotation, v.default) + field_type = v.annotation + + return Optional[field_type] + + def merge_class_params(class_): + all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) + parameters = {} + for classes in all_classes: + parameters = {**parameters, **inspect.signature(classes.__init__).parameters} + return parameters + + + self._model_name = model_name + self._class_data = merge_class_params(class_instance) + self._model_def = [ + ModelDef( + field=underscore(k), + field_alias=k, + field_type=field_type_generator(k, v), + field_value=v.default + ) + for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED + ] + + for fields in additional_fields: + self._model_def.append(ModelDef( + field=underscore(fields["key"]), + field_alias=fields["key"], + field_type=fields["type"], + field_value=fields["default"])) + + def generate_model(self): + """ + Creates a pydantic BaseModel + from the json and overrides provided at initialization + """ + fields = { + d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def + } + DynamicModel = create_model(self._model_name, **fields) + DynamicModel.__config__.allow_population_by_field_name = True + DynamicModel.__config__.allow_mutation = True + return DynamicModel + +StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingTxt2Img", + StableDiffusionProcessingTxt2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() + +StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingImg2Img", + StableDiffusionProcessingImg2Img, + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] +).generate_model() + class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") - parameters: str + parameters: dict + info: str + +class ImageToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: dict info: str class ExtrasBaseRequest(BaseModel): diff --git a/modules/api/processing.py b/modules/api/processing.py deleted file mode 100644 index f551fa35..00000000 --- a/modules/api/processing.py +++ /dev/null @@ -1,106 +0,0 @@ -from array import array -from inflection import underscore -from typing import Any, Dict, Optional -from pydantic import BaseModel, Field, create_model -from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img -import inspect - - -API_NOT_ALLOWED = [ - "self", - "kwargs", - "sd_model", - "outpath_samples", - "outpath_grids", - "sampler_index", - "do_not_save_samples", - "do_not_save_grid", - "extra_generation_params", - "overlay_images", - "do_not_reload_embeddings", - "seed_enable_extras", - "prompt_for_display", - "sampler_noise_scheduler_override", - "ddim_discretize" -] - -class ModelDef(BaseModel): - """Assistance Class for Pydantic Dynamic Model Generation""" - - field: str - field_alias: str - field_type: Any - field_value: Any - - -class PydanticModelGenerator: - """ - Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: - source_data is a snapshot of the default values produced by the class - params are the names of the actual keys required by __init__ - """ - - def __init__( - self, - model_name: str = None, - class_instance = None, - additional_fields = None, - ): - def field_type_generator(k, v): - # field_type = str if not overrides.get(k) else overrides[k]["type"] - # print(k, v.annotation, v.default) - field_type = v.annotation - - return Optional[field_type] - - def merge_class_params(class_): - all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) - parameters = {} - for classes in all_classes: - parameters = {**parameters, **inspect.signature(classes.__init__).parameters} - return parameters - - - self._model_name = model_name - self._class_data = merge_class_params(class_instance) - self._model_def = [ - ModelDef( - field=underscore(k), - field_alias=k, - field_type=field_type_generator(k, v), - field_value=v.default - ) - for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED - ] - - for fields in additional_fields: - self._model_def.append(ModelDef( - field=underscore(fields["key"]), - field_alias=fields["key"], - field_type=fields["type"], - field_value=fields["default"])) - - def generate_model(self): - """ - Creates a pydantic BaseModel - from the json and overrides provided at initialization - """ - fields = { - d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def - } - DynamicModel = create_model(self._model_name, **fields) - DynamicModel.__config__.allow_population_by_field_name = True - DynamicModel.__config__.allow_mutation = True - return DynamicModel - -StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( - "StableDiffusionProcessingTxt2Img", - StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}] -).generate_model() - -StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( - "StableDiffusionProcessingImg2Img", - StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}] -).generate_model() \ No newline at end of file -- cgit v1.2.1 From 1e625624ba6ab3dfc167f0a5226780bb9b50fb58 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 16:01:16 -0300 Subject: Add folder processing endpoint Also minor refactor --- modules/api/api.py | 56 +++++++++++++++++++++++++++------------------------ modules/api/models.py | 6 +++++- 2 files changed, 35 insertions(+), 27 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 20e85e82..7b4fbe29 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,5 @@ import uvicorn -from gradio import processing_utils +from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, HTTPException import modules.shared as shared from modules.api.models import * @@ -11,10 +11,18 @@ def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) except: - raise HTTPException(status_code=400, detail="Upscaler not found") + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def setUpscalers(req: dict): + reqDict = vars(req) + reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) + reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) + reqDict.pop('upscaler_1') + reqDict.pop('upscaler_2') + return reqDict + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -24,6 +32,7 @@ class Api: self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.app.add_api_route("/sdapi/v1/extra-folder-images", self.extras_folder_processing_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -43,7 +52,7 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.info) @@ -60,7 +69,7 @@ class Api: mask = img2imgreq.mask if mask: - mask = processing_utils.decode_base64_to_image(mask) + mask = decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params @@ -75,7 +84,7 @@ class Api: imgs = [] for img in init_images: - img = processing_utils.decode_base64_to_image(img) + img = decode_base64_to_image(img) imgs = [img] * p.batch_size p.init_images = imgs @@ -83,43 +92,38 @@ class Api: with self.queue_lock: processed = process_images(p) - b64images = list(map(processing_utils.encode_pil_to_base64, processed.images)) + b64images = list(map(encode_pil_to_base64, processed.images)) return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.info) def extras_single_image_api(self, req: ExtrasSingleImageRequest): - upscaler1Index = upscaler_to_index(req.upscaler_1) - upscaler2Index = upscaler_to_index(req.upscaler_2) - - reqDict = vars(req) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict = setUpscalers(req) - reqDict['image'] = processing_utils.decode_base64_to_image(reqDict['image']) + reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=0, image_folder="", input_dir="", output_dir="") + result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict) - return ExtrasSingleImageResponse(image=processing_utils.encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): - upscaler1Index = upscaler_to_index(req.upscaler_1) - upscaler2Index = upscaler_to_index(req.upscaler_2) + reqDict = setUpscalers(req) - reqDict = vars(req) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') - - reqDict['image_folder'] = list(map(processing_utils.decode_base64_to_file, reqDict['imageList'])) + reqDict['image_folder'] = list(map(decode_base64_to_file, reqDict['imageList'])) reqDict.pop('imageList') with self.queue_lock: - result = run_extras(**reqDict, extras_upscaler_1=upscaler1Index, extras_upscaler_2=upscaler2Index, extras_mode=1, image="", input_dir="", output_dir="") + result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) - return ExtrasBatchImagesResponse(images=list(map(processing_utils.encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) - def extras_folder_processing_api(self): - raise NotImplementedError + def extras_folder_processing_api(self, req:ExtrasFoldersRequest): + reqDict = setUpscalers(req) + + with self.queue_lock: + result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict) + + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 362e6277..6f096807 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -150,4 +150,8 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest): imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file + images: list[str] = Field(title="Images", description="The generated images in base64 format.") + +class ExtrasFoldersRequest(ExtrasBaseRequest): + input_dir: str = Field(title="Input directory", description="Directory path from where to take the images") + output_dir: str = Field(title="Output directory", description="Directory path to put the processsed images into") -- cgit v1.2.1 From 90f02c75220d187e075203a4e3b450bfba392c4d Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sun, 23 Oct 2022 16:03:30 -0300 Subject: Remove unused field and class --- modules/api/api.py | 6 +++--- modules/api/models.py | 6 +----- 2 files changed, 4 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 7b4fbe29..799e3701 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -104,7 +104,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict) - return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info_x=result[1], html_info=result[2]) + return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): reqDict = setUpscalers(req) @@ -115,7 +115,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def extras_folder_processing_api(self, req:ExtrasFoldersRequest): reqDict = setUpscalers(req) @@ -123,7 +123,7 @@ class Api: with self.queue_lock: result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict) - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info_x=result[1], html_info=result[2]) + return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 6f096807..e461d397 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -130,8 +130,7 @@ class ExtrasBaseRequest(BaseModel): extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") class ExtraBaseResponse(BaseModel): - html_info_x: str - html_info: str + html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.") class ExtrasSingleImageRequest(ExtrasBaseRequest): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") @@ -139,9 +138,6 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest): class ExtrasSingleImageResponse(ExtraBaseResponse): image: str = Field(default=None, title="Image", description="The generated image in base64 format.") -class SerializableImage(BaseModel): - path: str = Field(title="Path", description="The image's path ()") - class ImageItem(BaseModel): data: str = Field(title="image data") name: str = Field(title="filename") -- cgit v1.2.1 From 994aaadf0861366b9e6f219e1a3c25a233fbb63c Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 16:44:36 +0800 Subject: a strange bug --- modules/ui.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index a73b9ff0..8c6dc026 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -55,6 +55,7 @@ mimetypes.init() mimetypes.add_type('application/javascript', '.js') txt2img_paste_fields = [] img2img_paste_fields = [] +init_img_components = {} if not cmd_opts.share and not cmd_opts.listen: @@ -1174,6 +1175,9 @@ def create_ui(wrap_gradio_gpu_call): outputs=[init_img_with_mask], ) + global init_img_components + init_img_components = {"img2img":init_img, "inpaint":init_img_with_mask, "extras":extras_image} + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): -- cgit v1.2.1 From 595dca85af9e26b5d76cd64659a5bdd9da4f2b89 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Mon, 24 Oct 2022 08:32:18 -0300 Subject: Reverse run_extras change Update serialization on the batch images endpoint --- modules/api/api.py | 7 ++++++- modules/api/models.py | 8 ++++---- modules/extras.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 799e3701..67b783de 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -109,7 +109,12 @@ class Api: def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): reqDict = setUpscalers(req) - reqDict['image_folder'] = list(map(decode_base64_to_file, reqDict['imageList'])) + def prepareFiles(file): + file = decode_base64_to_file(file.data, file_path=file.name) + file.orig_name = file.name + return file + + reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList'])) reqDict.pop('imageList') with self.queue_lock: diff --git a/modules/api/models.py b/modules/api/models.py index e461d397..fca2f991 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -138,12 +138,12 @@ class ExtrasSingleImageRequest(ExtrasBaseRequest): class ExtrasSingleImageResponse(ExtraBaseResponse): image: str = Field(default=None, title="Image", description="The generated image in base64 format.") -class ImageItem(BaseModel): - data: str = Field(title="image data") - name: str = Field(title="filename") +class FileData(BaseModel): + data: str = Field(title="File data", description="Base64 representation of the file") + name: str = Field(title="File name") class ExtrasBatchImagesRequest(ExtrasBaseRequest): - imageList: list[str] = Field(title="Images", description="List of images to work on. Must be Base64 strings") + imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): images: list[str] = Field(title="Images", description="The generated images in base64 format.") diff --git a/modules/extras.py b/modules/extras.py index 29ac312e..22c5a1c1 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -33,7 +33,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ for img in image_folder: image = Image.open(img) imageArr.append(image) - imageNameArr.append(os.path.splitext(img.name)[0]) + imageNameArr.append(os.path.splitext(img.orig_name)[0]) elif extras_mode == 2: assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' -- cgit v1.2.1 From ff305acd51cc71c5eea8aee0f537a26a6d1ba2a1 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 25 Oct 2022 15:33:43 +0800 Subject: some rights for extensions --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 76cbb1bd..7b1fadf2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -82,6 +82,7 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) +parser.add_argument("--administrator", type=str, help="Administrator rights", default=None) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.1 From 9ba439b53313ef78984dd8e39f25b34501188ee2 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 25 Oct 2022 18:48:07 +0800 Subject: need some rights for extensions --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 7b1fadf2..b5975707 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -82,7 +82,7 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -parser.add_argument("--administrator", type=str, help="Administrator rights", default=None) +parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.1 From f9549d1cbb3f1d7d1f0fb70375a06e31f9c5dd9d Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Tue, 25 Oct 2022 11:14:12 -0700 Subject: Added option to use unmasked conditioning image. --- modules/processing.py | 6 +++++- modules/shared.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index c61bbfbd..96f56b0d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -768,7 +768,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # Create another latent image, this time with a masked version of the original input. conditioning_mask = conditioning_mask.to(image.device) - conditioning_image = image * (1.0 - conditioning_mask) + + conditioning_image = image + if shared.opts.inpainting_mask_image: + conditioning_image = conditioning_image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` diff --git a/modules/shared.py b/modules/shared.py index 308fccce..1d0ff1a1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -320,6 +320,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), + "inpainting_mask_image": OptionInfo(True, "Mask original image for conditioning used by inpainting model."), })) -- cgit v1.2.1 From 605d27687f433c0fefb9025aace7dc94f0ebd454 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Tue, 25 Oct 2022 12:20:54 -0700 Subject: Added conditioning image masking to xy_grid. Use `True` and `False` to select values. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 96f56b0d..23ee5e02 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -770,7 +770,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): conditioning_mask = conditioning_mask.to(image.device) conditioning_image = image - if shared.opts.inpainting_mask_image: + if getattr(self, "inpainting_mask_image", shared.opts.inpainting_mask_image): conditioning_image = conditioning_image * (1.0 - conditioning_mask) conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) -- cgit v1.2.1 From 8b4f32779f28010fc8077e8fcfb85a3205b36bc2 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Tue, 25 Oct 2022 13:15:08 -0700 Subject: Switch to a continous blend for cond. image. --- modules/processing.py | 9 ++++++--- modules/shared.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 23ee5e02..02292bdc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -769,9 +769,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # Create another latent image, this time with a masked version of the original input. conditioning_mask = conditioning_mask.to(image.device) - conditioning_image = image - if getattr(self, "inpainting_mask_image", shared.opts.inpainting_mask_image): - conditioning_image = conditioning_image * (1.0 - conditioning_mask) + # Smoothly interpolate between the masked and unmasked latent conditioning image. + conditioning_image = torch.lerp( + image, + image * (1.0 - conditioning_mask), + getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) + ) conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) diff --git a/modules/shared.py b/modules/shared.py index 1d0ff1a1..e0ffb824 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -320,7 +320,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), - "inpainting_mask_image": OptionInfo(True, "Mask original image for conditioning used by inpainting model."), + "inpainting_mask_weight": OptionInfo(1.0, "Blend betweeen an unmasked and masked conditioning image for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) -- cgit v1.2.1 From b2e0d8ba789b345145436f6e960a3f0a896a6643 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Wed, 26 Oct 2022 09:54:26 -0300 Subject: Remove folder endpoint --- modules/api/api.py | 9 --------- modules/api/models.py | 6 +----- 2 files changed, 1 insertion(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index ca289d9f..49c213ea 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -32,7 +32,6 @@ class Api: self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) - self.app.add_api_route("/sdapi/v1/extra-folder-images", self.extras_folder_processing_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -126,14 +125,6 @@ class Api: return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) - def extras_folder_processing_api(self, req:ExtrasFoldersRequest): - reqDict = setUpscalers(req) - - with self.queue_lock: - result = run_extras(extras_mode=2, image=None, image_folder=None, **reqDict) - - return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) - def pnginfoapi(self): raise NotImplementedError diff --git a/modules/api/models.py b/modules/api/models.py index 00406368..dd122321 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -148,8 +148,4 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest): imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.") - -class ExtrasFoldersRequest(ExtrasBaseRequest): - input_dir: str = Field(title="Input directory", description="Directory path from where to take the images") - output_dir: str = Field(title="Output directory", description="Directory path to put the processsed images into") + images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file -- cgit v1.2.1 From fddb4883f4a408b3464076465e1b0949ebe0fc30 Mon Sep 17 00:00:00 2001 From: evshiron Date: Wed, 26 Oct 2022 22:33:45 +0800 Subject: prototype progress api --- modules/api/api.py | 89 +++++++++++++++++++++++++++++++++++++++++++++--------- modules/shared.py | 13 ++++++++ 2 files changed, 88 insertions(+), 14 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 6e9d6097..c038f674 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,8 +1,11 @@ +import time + from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_pnginfo import modules.shared as shared +from modules import devices import uvicorn from fastapi import Body, APIRouter, HTTPException from fastapi.responses import JSONResponse @@ -25,6 +28,37 @@ class ImageToImageResponse(BaseModel): parameters: Json info: Json +class ProgressResponse(BaseModel): + progress: float + eta_relative: float + state: Json + +# copy from wrap_gradio_gpu_call of webui.py +# because queue lock will be acquired in api handlers +# and time start needs to be set +# the function has been modified into two parts + +def before_gpu_call(): + devices.torch_gc() + + shared.state.sampling_step = 0 + shared.state.job_count = -1 + shared.state.job_no = 0 + shared.state.job_timestamp = shared.state.get_job_timestamp() + shared.state.current_latent = None + shared.state.current_image = None + shared.state.current_image_sampling_step = 0 + shared.state.skipped = False + shared.state.interrupted = False + shared.state.textinfo = None + shared.state.time_start = time.time() + + +def after_gpu_call(): + shared.state.job = "" + shared.state.job_count = 0 + + devices.torch_gc() class Api: def __init__(self, app, queue_lock): @@ -33,6 +67,7 @@ class Api: self.queue_lock = queue_lock self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) + self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"]) def __base64_to_image(self, base64_string): # if has a comma, deal with prefix @@ -44,12 +79,12 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) - + if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - + raise HTTPException(status_code=404, detail="Sampler not found") + populate = txt2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, + "sd_model": shared.sd_model, "sampler_index": sampler_index[0], "do_not_save_samples": True, "do_not_save_grid": True @@ -57,9 +92,11 @@ class Api: ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param + before_gpu_call() with self.queue_lock: processed = process_images(p) - + after_gpu_call() + b64images = [] for i in processed.images: buffer = io.BytesIO() @@ -67,30 +104,30 @@ class Api: b64images.append(base64.b64encode(buffer.getvalue())) return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js()) - - + + def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): sampler_index = sampler_to_index(img2imgreq.sampler_index) - + if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") + raise HTTPException(status_code=404, detail="Sampler not found") init_images = img2imgreq.init_images if init_images is None: - raise HTTPException(status_code=404, detail="Init image not found") + raise HTTPException(status_code=404, detail="Init image not found") mask = img2imgreq.mask if mask: mask = self.__base64_to_image(mask) - + populate = img2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, + "sd_model": shared.sd_model, "sampler_index": sampler_index[0], "do_not_save_samples": True, - "do_not_save_grid": True, + "do_not_save_grid": True, "mask": mask } ) @@ -103,9 +140,11 @@ class Api: p.init_images = imgs # Override object param + before_gpu_call() with self.queue_lock: processed = process_images(p) - + after_gpu_call() + b64images = [] for i in processed.images: buffer = io.BytesIO() @@ -118,6 +157,28 @@ class Api: return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js()) + def progressapi(self): + # copy from check_progress_call of ui.py + + if shared.state.job_count == 0: + return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js()) + + # avoid dividing zero + progress = 0.01 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + + progress = min(progress, 1) + + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js()) + def extrasapi(self): raise NotImplementedError diff --git a/modules/shared.py b/modules/shared.py index 1a9b8289..00f61898 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -146,6 +146,19 @@ class State: def get_job_timestamp(self): return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? + def js(self): + obj = { + "skipped": self.skipped, + "interrupted": self.skipped, + "job": self.job, + "job_count": self.job_count, + "job_no": self.job_no, + "sampling_step": self.sampling_step, + "sampling_steps": self.sampling_steps, + } + + return json.dumps(obj) + state = State() -- cgit v1.2.1 From 3de036514138d7cdcba9729c975f1683a8e06b16 Mon Sep 17 00:00:00 2001 From: xmodar Date: Wed, 26 Oct 2022 23:56:11 +0300 Subject: Add id access to scripts list in the css --- modules/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index 9323af3e..a7f36012 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -236,7 +236,7 @@ class ScriptRunner: with gr.Group(): create_script_ui(script, inputs, inputs_alwayson) - dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") + dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") dropdown.save_to_config = True inputs[0] = dropdown -- cgit v1.2.1 From 4a4647e0dfc812783db7fa993d486b031f098ef8 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 27 Oct 2022 13:36:11 +0800 Subject: create send to buttons in one module --- modules/generation_parameters_copypaste.py | 86 +++++++- modules/shared.py | 1 + modules/ui.py | 344 +++++++++-------------------- 3 files changed, 184 insertions(+), 247 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index f73647da..2b80737a 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -3,13 +3,16 @@ import re import gradio as gr from modules.shared import script_path from modules import shared +import tempfile +from PIL import Image, PngImagePlugin re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") type_of_gr_update = type(gr.update()) - +paste_fields = {} +bind_list = [] def quote(text): if ',' not in str(text): @@ -20,6 +23,81 @@ def quote(text): text = text.replace('"', '\\"') return f'"{text}"' +def image_from_url_text(filedata): + if type(filedata) == dict and filedata["is_file"]: + filename = filedata["name"] + tempdir = os.path.normpath(tempfile.gettempdir()) + normfn = os.path.normpath(filename) + assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory' + + return Image.open(filename) + + if type(filedata) == list: + if len(filedata) == 0: + return None + + filedata = filedata[0] + + if filedata.startswith("data:image/png;base64,"): + filedata = filedata[len("data:image/png;base64,"):] + + filedata = base64.decodebytes(filedata.encode('utf-8')) + image = Image.open(io.BytesIO(filedata)) + return image + +def add_paste_fields(tabname, init_img, fields): + paste_fields[tabname] = {"init_img":init_img, "fields": fields} + +def create_buttons(tabs_list): + buttons = {} + for tab in tabs_list: + buttons[tab] = gr.Button(f"Send to {tab}") + return buttons + +#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab +def bind_buttons(buttons, send_image, send_generate_info): + bind_list.append([buttons, send_image, send_generate_info]) + +def run_bind(): + for buttons, send_image, send_generate_info in bind_list: + for tab in buttons: + button = buttons[tab] + if send_image and paste_fields[tab]["init_img"]: + if type(send_image) == gr.Gallery: + button.click( + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", + inputs=[send_image], + outputs=[paste_fields[tab]["init_img"]], + ) + else: + button.click( + fn=lambda x:x, + inputs=[send_image], + outputs=[paste_fields[tab]["init_img"]], + ) + + if send_generate_info and paste_fields[tab]["fields"] is not None: + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + if shared.opts.send_seed: + paste_field_names += ["Seed"] + if send_generate_info in paste_fields: + button.click( + fn=lambda *x:x, + inputs=[field for field,name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], + outputs=[field for field,name in paste_fields[tab]["fields"] if name in paste_field_names], + ) + + else: + connect_paste(button, [(field, name) for field, name in paste_fields[tab]["fields"] if name in paste_field_names], send_generate_info) + + button.click( + fn=None, + _js=f"switch_to_{tab}", + inputs=None, + outputs=None, + ) + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -67,8 +145,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model return res - -def connect_paste(button, paste_fields, input_comp, js=None): +def connect_paste(button, paste_fields, input_comp): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(script_path, "params.txt") @@ -106,7 +183,8 @@ def connect_paste(button, paste_fields, input_comp, js=None): button.click( fn=paste_func, - _js=js, inputs=[input_comp], outputs=[x[0] for x in paste_fields], ) + + diff --git a/modules/shared.py b/modules/shared.py index f8b13b06..3ade2afa 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -279,6 +279,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), + "send_seed": OptionInfo(False, "Send seed when sending prompt or image to other interface"), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index 3e5b84d2..ccba14b6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -35,7 +35,7 @@ if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags import modules.codeformer_model -import modules.generation_parameters_copypaste +import modules.generation_parameters_copypaste as parameters_copypaste import modules.gfpgan_model import modules.hypernetworks.ui import modules.ldsr_model @@ -49,14 +49,11 @@ from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img import modules.textual_inversion.ui import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') -txt2img_paste_fields = [] -img2img_paste_fields = [] -init_img_components = {} - if not cmd_opts.share and not cmd_opts.listen: # fix gradio phoning home @@ -99,37 +96,11 @@ def plaintext_to_html(text): text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" return text - -def image_from_url_text(filedata): - if type(filedata) == dict and filedata["is_file"]: - filename = filedata["name"] - tempdir = os.path.normpath(tempfile.gettempdir()) - normfn = os.path.normpath(filename) - assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory' - - return Image.open(filename) - - if type(filedata) == list: - if len(filedata) == 0: - return None - - filedata = filedata[0] - - if filedata.startswith("data:image/png;base64,"): - filedata = filedata[len("data:image/png;base64,"):] - - filedata = base64.decodebytes(filedata.encode('utf-8')) - image = Image.open(io.BytesIO(filedata)) - return image - - def send_gradio_gallery_to_image(x): if len(x) == 0: return None - return image_from_url_text(x[0]) - def save_files(js_data, images, do_make_zip, index): import csv filenames = [] @@ -193,7 +164,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") - def save_pil_to_file(pil_image, dir=None): use_metadata = False metadata = PngImagePlugin.PngInfo() @@ -626,6 +596,83 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele ) return refresh_button +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(): + if tabname != "extras": + save = gr.Button('Save') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' + open_folder = gr.Button(folder_symbol, elem_id=button_id) + + open_folder.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) + + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + do_make_zip, + html_info, + ], + outputs=[ + download_files, + html_info, + html_info, + html_info, + ] + ) + else: + html_info_x = gr.HTML() + html_info = gr.HTML() + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info + + def create_ui(wrap_gradio_gpu_call): import modules.img2img @@ -676,31 +723,10 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) - with gr.Column(variant='panel'): - - with gr.Group(): - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - txt2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='txt2img_gallery').style(grid=4) - - with gr.Column(): - with gr.Row(): - save = gr.Button('Save') - send_to_img2img = gr.Button('Send to img2img') - send_to_inpaint = gr.Button('Send to inpaint') - send_to_extras = gr.Button('Send to extras') - button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' - open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id) - - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) + + txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -755,24 +781,7 @@ def create_ui(wrap_gradio_gpu_call): fn=lambda x: gr_show(x), inputs=[enable_hr], outputs=[hr_options], - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", - inputs=[ - generation_info, - txt2img_gallery, - do_make_zip, - html_info, - ], - outputs=[ - download_files, - html_info, - html_info, - html_info, - ] - ) + ) roll.click( fn=roll_artist, @@ -785,8 +794,7 @@ def create_ui(wrap_gradio_gpu_call): ] ) - global txt2img_paste_fields - txt2img_paste_fields = [ + parameters_copypaste.add_paste_fields("txt2img", None, [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), (steps, "Steps"), @@ -807,7 +815,7 @@ def create_ui(wrap_gradio_gpu_call): (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), *modules.scripts.scripts_txt2img.infotext_fields - ] + ]) txt2img_preview_params = [ txt2img_prompt, @@ -894,30 +902,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) - with gr.Column(variant='panel'): - - with gr.Group(): - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - img2img_gallery = gr.Gallery(label='Output', show_label=False, elem_id='img2img_gallery').style(grid=4) - - with gr.Column(): - with gr.Row(): - save = gr.Button('Save') - img2img_send_to_img2img = gr.Button('Send to img2img') - img2img_send_to_inpaint = gr.Button('Send to inpaint') - img2img_send_to_extras = gr.Button('Send to extras') - button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' - open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id) - - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) - - with gr.Group(): - html_info = gr.HTML() - generation_info = gr.Textbox(visible=False) + img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -1004,24 +989,8 @@ def create_ui(wrap_gradio_gpu_call): fn=interrogate_deepbooru, inputs=[init_img], outputs=[img2img_prompt], - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", - inputs=[ - generation_info, - img2img_gallery, - do_make_zip, - html_info, - ], - outputs=[ - download_files, - html_info, - html_info, - html_info, - ] ) + roll.click( fn=roll_artist, @@ -1056,7 +1025,8 @@ def create_ui(wrap_gradio_gpu_call): outputs=[prompt, negative_prompt, style1, style2], ) - global img2img_paste_fields + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + img2img_paste_fields = [ (img2img_prompt, "Prompt"), (img2img_negative_prompt, "Negative prompt"), @@ -1075,7 +1045,9 @@ def create_ui(wrap_gradio_gpu_call): (denoising_strength, "Denoising strength"), *modules.scripts.scripts_img2img.infotext_fields ] - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False): @@ -1122,15 +1094,8 @@ def create_ui(wrap_gradio_gpu_call): submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - with gr.Column(variant='panel'): - result_images = gr.Gallery(label="Result", show_label=False) - html_info_x = gr.HTML() - html_info = gr.HTML() - extras_send_to_img2img = gr.Button('Send to img2img') - extras_send_to_inpaint = gr.Button('Send to inpaint') - button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else '' - open_extras_folder = gr.Button('Open output directory', elem_id=button_id) + result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) submit.click( fn=wrap_gradio_gpu_call(modules.extras.run_extras), @@ -1160,23 +1125,8 @@ def create_ui(wrap_gradio_gpu_call): html_info, ] ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) - extras_send_to_img2img.click( - fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_img2img", - inputs=[result_images], - outputs=[init_img], - ) - - extras_send_to_inpaint.click( - fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_inpaint", - inputs=[result_images], - outputs=[init_img_with_mask], - ) - - global init_img_components - init_img_components = {"img2img":init_img, "inpaint":init_img_with_mask, "extras":extras_image} with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): @@ -1187,11 +1137,10 @@ def create_ui(wrap_gradio_gpu_call): html = gr.HTML() generation_info = gr.Textbox(visible=False) html2 = gr.HTML() - with gr.Row(): - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + image.change( fn=wrap_gradio_call(modules.extras.run_pnginfo), inputs=[image], @@ -1475,28 +1424,6 @@ def create_ui(wrap_gradio_gpu_call): script_callbacks.ui_settings_callback() opts.reorder() - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - else: - sp.Popen(["xdg-open", path]) - def run_settings(*args): changed = 0 @@ -1641,6 +1568,8 @@ Requested path was: {f} if column is not None: column.__exit__() + parameters_copypaste.run_bind() + interfaces = [ (txt2img_interface, "txt2img", "txt2img"), (img2img_interface, "img2img", "img2img"), @@ -1731,85 +1660,14 @@ Requested path was: {f} component_dict['sd_model_checkpoint'], ] ) - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2'] - txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names] - img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names] - send_to_img2img.click( - fn=lambda img, *args: (image_from_url_text(img),*args), - _js="(gallery, ...args) => [extract_image_from_gallery_img2img(gallery), ...args]", - inputs=[txt2img_gallery] + txt2img_fields, - outputs=[init_img] + img2img_fields, - ) - - send_to_inpaint.click( - fn=lambda x, *args: (image_from_url_text(x), *args), - _js="(gallery, ...args) => [extract_image_from_gallery_inpaint(gallery), ...args]", - inputs=[txt2img_gallery] + txt2img_fields, - outputs=[init_img_with_mask] + img2img_fields, - ) - - img2img_send_to_img2img.click( - fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_img2img", - inputs=[img2img_gallery], - outputs=[init_img], - ) - - img2img_send_to_inpaint.click( - fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_inpaint", - inputs=[img2img_gallery], - outputs=[init_img_with_mask], - ) - - send_to_extras.click( - fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_extras", - inputs=[txt2img_gallery], - outputs=[extras_image], - ) - - open_txt2img_folder.click( - fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples), - inputs=[], - outputs=[], - ) - - open_img2img_folder.click( - fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples), - inputs=[], - outputs=[], - ) - - open_extras_folder.click( - fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples), - inputs=[], - outputs=[], - ) - - img2img_send_to_extras.click( - fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery_extras", - inputs=[img2img_gallery], - outputs=[extras_image], - ) + settings_map = { 'sd_hypernetwork': 'Hypernet', 'CLIP_stop_at_last_layers': 'Clip skip', 'sd_model_checkpoint': 'Model hash', } - - settings_paste_fields = [ - (component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None))) - for k, v in settings_map.items() - ] - - modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt) - modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt) - - modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img') - modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img') + ui_config_file = cmd_opts.ui_config_file ui_settings = {} -- cgit v1.2.1 From 462e6ba6675bd14c0f82e465423a0eedfff82372 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Thu, 27 Oct 2022 15:40:24 +0900 Subject: Disable unavailable or duplicate options --- modules/hypernetworks/ui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 2c6c0470..c2d4b51c 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -8,7 +8,8 @@ import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork -keys = list(hypernetwork.HypernetworkModule.activation_dict.keys()) +not_available = ["hardswish", "multiheadattention"] +keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): # Remove illegal characters from name. -- cgit v1.2.1 From e0cbf53f451f45ea73dafab654eb6596cbd67ec2 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 27 Oct 2022 18:00:51 +0800 Subject: create send to buttons by extensions --- modules/ui.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index ccba14b6..922a2163 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1568,7 +1568,8 @@ def create_ui(wrap_gradio_gpu_call): if column is not None: column.__exit__() - parameters_copypaste.run_bind() + + interfaces = [ (txt2img_interface, "txt2img", "txt2img"), @@ -1581,7 +1582,7 @@ def create_ui(wrap_gradio_gpu_call): interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] + interfaces += [(settings_interface, "Settings", "settings")] css = "" @@ -1667,7 +1668,8 @@ def create_ui(wrap_gradio_gpu_call): 'CLIP_stop_at_last_layers': 'Clip skip', 'sd_model_checkpoint': 'Model hash', } - + + parameters_copypaste.run_bind() ui_config_file = cmd_opts.ui_config_file ui_settings = {} -- cgit v1.2.1 From 0995e879cea8ce871489ea8e393bb0eba6edc09c Mon Sep 17 00:00:00 2001 From: Florian Horn Date: Thu, 27 Oct 2022 16:20:01 +0200 Subject: added save button and shortcut (s) to Modal View --- modules/ui.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 0a63e357..1332e265 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -630,7 +630,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img - + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) @@ -683,7 +683,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): with gr.Row(): - save = gr.Button('Save') + saveButtonId = 'save_txt2img' + save = gr.Button('Save', elem_id=saveButtonId) send_to_img2img = gr.Button('Send to img2img') send_to_inpaint = gr.Button('Send to inpaint') send_to_extras = gr.Button('Send to extras') @@ -901,7 +902,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): with gr.Row(): - save = gr.Button('Save') + saveButtonId = 'save_img2img' + save = gr.Button('Save', elem_id=saveButtonId) img2img_send_to_img2img = gr.Button('Send to img2img') img2img_send_to_inpaint = gr.Button('Send to inpaint') img2img_send_to_extras = gr.Button('Send to extras') -- cgit v1.2.1 From 268159cfe3231743c554a1a9bf15d090c758f920 Mon Sep 17 00:00:00 2001 From: Florian Horn Date: Thu, 27 Oct 2022 16:32:10 +0200 Subject: fixed indentation --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 1332e265..d49b10b2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -630,7 +630,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img - + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) -- cgit v1.2.1 From a0a7024c679056dd66beb1832e52041b10143130 Mon Sep 17 00:00:00 2001 From: FlameLaw <116745066+FlameLaw@users.noreply.github.com> Date: Fri, 28 Oct 2022 02:13:48 +0900 Subject: Fix random dataset shuffle on TI --- modules/textual_inversion/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 5b1c5002..8bb00d27 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -86,12 +86,12 @@ class PersonalizedBase(Dataset): assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset) * repeats // batch_size - self.initial_indexes = np.arange(len(self.dataset)) + self.dataset_length = len(self.dataset) self.indexes = None self.shuffle() def shuffle(self): - self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()] + self.indexes = np.random.permutation(self.dataset_length) def create_text(self, filename_text): text = random.choice(self.lines) -- cgit v1.2.1 From 26a3fd2fe9314330336fb0e28d1e9ca7d2abe10e Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 27 Oct 2022 11:27:59 -0700 Subject: Highres fix works with unmasked latent. Also refactor the mask creation to make it more accesible. --- modules/processing.py | 134 ++++++++++++++++++++++++++++---------------------- 1 file changed, 76 insertions(+), 58 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index f72185ac..548eec29 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -129,6 +129,73 @@ class StableDiffusionProcessing(): self.all_seeds = None self.all_subseeds = None + def txt2img_image_conditioning(self, x, width=None, height=None): + if self.sampler.conditioning_key not in {'hybrid', 'concat'}: + # Dummy zero conditioning if we're not using inpainting model. + # Still takes up a bit of memory, but no encoder call. + # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. + return torch.zeros( + x.shape[0], 5, 1, 1, + dtype=x.dtype, + device=x.device + ) + + height = height or self.height + width = width or self.width + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + + def img2img_image_conditioning(self, source_image, latent_image, image_mask = None): + if self.sampler.conditioning_key not in {'hybrid', 'concat'}: + # Dummy zero conditioning if we're not using inpainting model. + return torch.zeros( + latent_image.shape[0], 5, 1, 1, + dtype=latent_image.dtype, + device=latent_image.device + ) + + # Handle the different mask inputs + if image_mask is not None: + if torch.is_tensor(image_mask): + conditioning_mask = image_mask + else: + conditioning_mask = np.array(image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. + conditioning_mask = conditioning_mask.to(source_image.device) + conditioning_image = torch.lerp( + source_image, + source_image * (1.0 - conditioning_mask), + getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) + ) + + # Encode the new masked image using first stage of network. + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype) + + return image_conditioning + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -571,37 +638,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def create_dummy_mask(self, x, width=None, height=None): - if self.sampler.conditioning_key in {'hybrid', 'concat'}: - height = height or self.height - width = width or self.width - - # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) - - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) - - else: - # Dummy zero conditioning if we're not using inpainting model. - # Still takes up a bit of memory, but no encoder call. - # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. - image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) - - return image_conditioning - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height)) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] @@ -638,7 +684,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples)) + image_conditioning = self.img2img_image_conditioning( + decoded_samples, + samples, + decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3]) + ) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning) return samples @@ -770,40 +821,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - if self.sampler.conditioning_key in {'hybrid', 'concat'}: - if self.image_mask is not None: - conditioning_mask = np.array(self.image_mask.convert("L")) - conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 - conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) - - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) - else: - conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) - - # Create another latent image, this time with a masked version of the original input. - conditioning_mask = conditioning_mask.to(image.device) - - # Smoothly interpolate between the masked and unmasked latent conditioning image. - conditioning_image = torch.lerp( - image, - image * (1.0 - conditioning_mask), - getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) - ) - - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) - - # Create the concatenated conditioning tensor to be fed to `c_concat` - conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) - conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) - self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) - self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) - else: - self.image_conditioning = torch.zeros( - self.init_latent.shape[0], 5, 1, 1, - dtype=self.init_latent.dtype, - device=self.init_latent.device - ) + self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): -- cgit v1.2.1 From a38496c1deef12f56f74f8abce2034bef8bdaccb Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 27 Oct 2022 11:31:31 -0700 Subject: Moved mask weight config to SD section --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index d47378e8..9c2fa0d4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -267,6 +267,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), + "inpainting_mask_weight": OptionInfo(1.0, "Strength of img2img conditioning mask for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), @@ -320,7 +321,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), - "inpainting_mask_weight": OptionInfo(1.0, "Blend betweeen an unmasked and masked conditioning image for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) -- cgit v1.2.1 From b68c7c437eda2840a304539dd2acd0b0894e920c Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 27 Oct 2022 11:45:35 -0700 Subject: Updated name and hover text. --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 9c2fa0d4..7c428d90 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -267,7 +267,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), - "inpainting_mask_weight": OptionInfo(1.0, "Strength of img2img conditioning mask for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), -- cgit v1.2.1 From b50ff4f4e4d4d6bf31e222832d3fe4cfde4703c9 Mon Sep 17 00:00:00 2001 From: Josh Watzman Date: Thu, 27 Oct 2022 21:59:16 +0100 Subject: Reduce peak memory usage when changing models A few tweaks to reduce peak memory usage, the biggest being that if we aren't using the checkpoint cache, we shouldn't duplicate the model state dict just to immediately throw it away. On my machine with 16GB of RAM, this change means I can typically change models, whereas before it would typically OOM. --- modules/sd_models.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index e697bb72..203e99a8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info): print(f"Global Step: {pl_sd['global_step']}") sd = get_state_dict_from_checkpoint(pl_sd) - missing, extra = model.load_state_dict(sd, strict=False) + del pl_sd + model.load_state_dict(sd, strict=False) + del sd if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) @@ -194,9 +196,10 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.to(devices.dtype_vae) - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: - checkpoints_loaded.popitem(last=False) # LRU + if shared.opts.sd_checkpoint_cache > 0: + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") checkpoints_loaded.move_to_end(checkpoint_info) -- cgit v1.2.1 From 5d5dc64064d8ca399a76fe44dbb62bdef6c4b7c4 Mon Sep 17 00:00:00 2001 From: Antonio Date: Fri, 28 Oct 2022 05:49:39 +0200 Subject: Natural sorting for dropdown checkpoint list Example: Before After 11.ckpt 11.ckpt ab.ckpt ab.ckpt ade_pablo_step_1000.ckpt ade_pablo_step_500.ckpt ade_pablo_step_500.ckpt ade_pablo_step_1000.ckpt ade_step_1000.ckpt ade_step_500.ckpt ade_step_1500.ckpt ade_step_1000.ckpt ade_step_2000.ckpt ade_step_1500.ckpt ade_step_2500.ckpt ade_step_2000.ckpt ade_step_3000.ckpt ade_step_2500.ckpt ade_step_500.ckpt ade_step_3000.ckpt atp_step_5500.ckpt atp_step_5500.ckpt model1.ckpt model1.ckpt model10.ckpt model10.ckpt model1000.ckpt model33.ckpt model33.ckpt model50.ckpt model400.ckpt model400.ckpt model50.ckpt model1000.ckpt moo44.ckpt moo44.ckpt v1-4-pruned-emaonly.ckpt v1-4-pruned-emaonly.ckpt v1-5-pruned-emaonly.ckpt v1-5-pruned-emaonly.ckpt v1-5-pruned.ckpt v1-5-pruned.ckpt v1-5-vae.ckpt v1-5-vae.ckpt --- modules/sd_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index e697bb72..64d5ee0d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -3,6 +3,7 @@ import os.path import sys from collections import namedtuple import torch +import re from omegaconf import OmegaConf from ldm.util import instantiate_from_config @@ -35,8 +36,10 @@ def setup_model(): list_models() -def checkpoint_tiles(): - return sorted([x.title for x in checkpoints_list.values()]) +def checkpoint_tiles(): + convert = lambda name: int(name) if name.isdigit() else name.lower() + alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] + return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) def list_models(): -- cgit v1.2.1 From b2a8b263b2f09bd772f75502c5a83656580f34ec Mon Sep 17 00:00:00 2001 From: benkyoujouzu Date: Thu, 27 Oct 2022 13:00:47 +0800 Subject: Add missing support for linear activation in hypernetwork --- modules/hypernetworks/hypernetwork.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8113b35b..87cf3cf3 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -25,6 +25,7 @@ from statistics import stdev, mean class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { + "linear": torch.nn.Identity, "relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, -- cgit v1.2.1 From 9e465c8aa5616df4c6723bee007ffd3910404f12 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 27 Oct 2022 23:03:34 -0700 Subject: Add strength to textinfo. --- modules/processing.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 4efba946..93066522 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -329,6 +329,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), + "Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.1 From d4a069a23cb19104b4e58a33d0d1670fadaefb7a Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 27 Oct 2022 23:16:27 -0700 Subject: Read hypernet strength from PNG info. --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 0a63e357..62a2f4f3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1812,6 +1812,7 @@ Requested path was: {f} settings_map = { 'sd_hypernetwork': 'Hypernet', + 'sd_hypernetwork_strength': 'Hypernetwork strength', 'CLIP_stop_at_last_layers': 'Clip skip', 'sd_model_checkpoint': 'Model hash', } -- cgit v1.2.1 From c0677b33161f04c3ed1a7a78f4c7288fb95787b7 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 27 Oct 2022 23:31:45 -0700 Subject: Explicitly state when Hypernet is none. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 93066522..74a0cd64 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -328,7 +328,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), + "Hypernet": ("None" if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), "Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), -- cgit v1.2.1 From db5a354c489bfd1c95e0bbf9af12ab8b5d6fe170 Mon Sep 17 00:00:00 2001 From: timntorres Date: Fri, 28 Oct 2022 01:41:57 -0700 Subject: Always ignore "None.pt" in the hypernet directory. --- modules/hypernetworks/hypernetwork.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8113b35b..cd920df5 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -208,13 +208,16 @@ def list_hypernetworks(path): res = {} for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename + # Prevent a hypothetical "None.pt" from being listed. + if name != "None": + res[name] = filename return res def load_hypernetwork(filename): path = shared.hypernetworks.get(filename, None) - if path is not None: + # Prevent any file named "None.pt" from being loaded. + if path is not None and filename != "None": print(f"Loading hypernetwork {filename}") try: shared.loaded_hypernetwork = Hypernetwork() -- cgit v1.2.1 From 9ceef81f77ecce89f0c8f412c4d849210d852e82 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Fri, 28 Oct 2022 20:48:08 +0700 Subject: Fix log off by 1 --- modules/hypernetworks/hypernetwork.py | 12 +++++++----- modules/textual_inversion/learn_schedule.py | 2 +- modules/textual_inversion/textual_inversion.py | 24 ++++++++++++------------ 3 files changed, 20 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8113b35b..a0297997 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -428,7 +428,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer.step() - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): + steps_done = hypernetwork.step + 1 + + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): raise RuntimeError("Loss diverged.") if len(previous_mean_losses) > 1: @@ -438,9 +440,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" pbar.set_description(dataset_loss_info) - if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}' + hypernetwork.name = f'{hypernetwork_name}-{steps_done}' last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') hypernetwork.save(last_saved_file) @@ -449,8 +451,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log "learn_rate": scheduler.learn_rate }) - if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - forced_filename = f'{hypernetwork_name}-{hypernetwork.step}' + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) optimizer.zero_grad() diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 2062726a..3a736065 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -52,7 +52,7 @@ class LearnRateScheduler: self.finished = False def apply(self, optimizer, step_number): - if step_number <= self.end_step: + if step_number < self.end_step: return try: diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ff002d3e..17dfb223 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return - if step % shared.opts.training_write_csv_every != 0: + if (step + 1) % shared.opts.training_write_csv_every != 0: return - write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True with open(os.path.join(log_directory, filename), "a+", newline='') as fout: @@ -196,11 +195,11 @@ def write_loss(log_directory, filename, step, epoch_len, values): csv_writer.writeheader() epoch = step // epoch_len - epoch_step = step - epoch * epoch_len + epoch_step = step % epoch_len csv_writer.writerow({ "step": step + 1, - "epoch": epoch + 1, + "epoch": epoch, "epoch_step": epoch_step + 1, **values, }) @@ -282,15 +281,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc loss.backward() optimizer.step() + steps_done = embedding.step + 1 epoch_num = embedding.step // len(ds) - epoch_step = embedding.step - (epoch_num * len(ds)) + 1 + epoch_step = embedding.step % len(ds) - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}") + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: + if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{embedding.step}' + embedding.name = f'{embedding_name}-{steps_done}' last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') embedding.save(last_saved_file) embedding_yet_to_be_embedded = True @@ -300,8 +300,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc "learn_rate": scheduler.learn_rate }) - if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: - forced_filename = f'{embedding_name}-{embedding.step}' + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -334,7 +334,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') info = PngImagePlugin.PngInfo() data = torch.load(last_saved_file) @@ -350,7 +350,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, embedding.step) + footer_right = '{}v {}s'.format(vectorSize, steps_done) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) captioned_image = insert_image_data_embed(captioned_image, data) -- cgit v1.2.1 From 26d08193848568b06105a1ee7b76f338ebf0f0ee Mon Sep 17 00:00:00 2001 From: Chris OBryan <13701027+cobryan05@users.noreply.github.com> Date: Fri, 28 Oct 2022 13:24:11 -0500 Subject: extras: Add option to run upscaling before face fixing Face restoration can look much better if ran after upscaling, as it allows the restoration to fix upscaling artifacts. This patch adds an option to choose which order to run upscaling/face fixing in. --- modules/extras.py | 145 +++++++++++++++++++++++++++++++++++------------------- modules/ui.py | 4 ++ 2 files changed, 99 insertions(+), 50 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 22c5a1c1..79047f3a 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,6 +7,10 @@ from PIL import Image import torch import tqdm +from typing import Callable, List, Tuple +from functools import partial +from dataclasses import dataclass + from modules import processing, shared, images, devices, sd_models from modules.shared import opts import modules.gfpgan_model @@ -20,7 +24,7 @@ import gradio as gr cached_images = {} -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ): devices.torch_gc() imageArr = [] @@ -56,68 +60,109 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ else: outpath = opts.outdir_samples or opts.outdir_extras_samples - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - existing_pnginfo = image.info or {} - image = image.convert("RGB") - info = "" + # Extra operation definitions + def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(image, res, gfpgan_visibility) + + info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" + return (res, info) - if gfpgan_visibility > 0: - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) + def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) + if codeformer_visibility < 1.0: + res = Image.blend(image, res, codeformer_visibility) - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - image = res + info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" + return (res, info) - if codeformer_visibility > 0: - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) + def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): + small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) + pixels = tuple(np.array(small).flatten().tolist()) + key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, + resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - image = res + c = cached_images.get(key) + if c is None: + upscaler = shared.sd_upscalers[scaler_index] + c = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2)) + c = cropped + cached_images[key] = c + return c + + def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text + nonlocal upscaling_resize if resize_mode == 1: upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) crop_info = " (crop)" if upscaling_crop else "" info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" + return (image, info) + + @dataclass + class UpscaleParams: + upscaler_idx: int + blend_alpha: float + + def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: + blended_result: Image.Image = None + for upscaler in params: + res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + if blended_result is None: + blended_result = res + else: + blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) + return (blended_result, info) + + # Build a list of operations to run + facefix_ops: List[Callable] = [] + if gfpgan_visibility > 0: + facefix_ops.append(run_gfpgan) + if codeformer_visibility > 0: + facefix_ops.append(run_codeformer) + + upscale_ops: List[Callable] = [] + if resize_mode == 1: + upscale_ops.append(run_prepare_crop) + + if upscaling_resize != 0: + step_params: List[UpscaleParams] = [] + step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_1, blend_alpha=1.0 )) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility ) ) + + upscale_ops.append( partial(run_upscalers_blend, step_params) ) + + + extras_ops: List[Callable] = [] + if upscale_first: + extras_ops = upscale_ops + facefix_ops + else: + extras_ops = facefix_ops + upscale_ops + + + for image, image_name in zip(imageArr, imageNameArr): + if image is None: + return outputs, "Please select an input image.", '' + existing_pnginfo = image.info or {} - if upscaling_resize != 1.0: - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) - pixels = tuple(np.array(small).flatten().tolist()) - key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, - resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels - - c = cached_images.get(key) - if c is None: - upscaler = shared.sd_upscalers[scaler_index] - c = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2)) - c = cropped - cached_images[key] = c - - return c - - info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" - res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) - - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" - res = Image.blend(res, res2, extras_upscaler_2_visibility) - - image = res + image = image.convert("RGB") + info = "" + # Run each operation on each image + for op in extras_ops: + image, info = op(image, info) while len(cached_images) > 2: del cached_images[next(iter(cached_images.keys()))] diff --git a/modules/ui.py b/modules/ui.py index 0a63e357..16b6ac49 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1119,6 +1119,9 @@ def create_ui(wrap_gradio_gpu_call): codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') with gr.Column(variant='panel'): @@ -1152,6 +1155,7 @@ def create_ui(wrap_gradio_gpu_call): extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, + upscale_before_face_fix, ], outputs=[ result_images, -- cgit v1.2.1 From bde4731f1d3ddf30c46f86c9f6e71e6c0644089d Mon Sep 17 00:00:00 2001 From: Chris OBryan <13701027+cobryan05@users.noreply.github.com> Date: Fri, 28 Oct 2022 14:30:04 -0500 Subject: extras: Rework image cache Bit of a refactor to the image cache to make it easier to extend. Also takes into account the entire image instead of just a cropped portion. --- modules/extras.py | 52 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 20 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 79047f3a..cffe0381 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ from PIL import Image import torch import tqdm -from typing import Callable, List, Tuple +from typing import Callable, Dict, List, Tuple from functools import partial from dataclasses import dataclass @@ -21,7 +21,18 @@ import piexif.helper import gradio as gr -cached_images = {} +@dataclass(frozen=True) +class CacheKey: + image_hash: int + info_hash: int + args_hash: int + +@dataclass +class CacheEntry: + image: Image.Image + info: str + +cached_images: Dict[CacheKey, CacheEntry] = {} def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ): @@ -84,22 +95,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) - pixels = tuple(np.array(small).flatten().tolist()) - key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, - resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels - - c = cached_images.get(key) - if c is None: - upscaler = shared.sd_upscalers[scaler_index] - c = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2)) - c = cropped - cached_images[key] = c - return c - + upscaler = shared.sd_upscalers[scaler_index] + res = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) + res = cropped + return res def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text @@ -118,8 +120,18 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: blended_result: Image.Image = None for upscaler in params: - res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()), + info_hash = hash(info), + args_hash = hash(upscale_args) ) + cached_entry = cached_images.get(cache_key) + if cached_entry is None: + res = upscale(image, *upscale_args) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + cached_images[cache_key] = CacheEntry(image=res, info=info) + else: + res, info = cached_entry.image, cached_entry.info + if blended_result is None: blended_result = res else: -- cgit v1.2.1 From 1f1b327959b546b5e6f995905a1699c5fe4a0c35 Mon Sep 17 00:00:00 2001 From: Chris OBryan <13701027+cobryan05@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:11:16 -0500 Subject: extras: Make image cache LRU This changes the extras image cache into a Least-Recently-Used cache. This allows more experimentation with different upscalers without missing the cache. Max cache size is increased to 5 and is cleared on source image update. --- modules/extras.py | 67 +++++++++++++++++++++++++++++++------------------------ modules/ui.py | 5 +++++ 2 files changed, 43 insertions(+), 29 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index cffe0381..72cc6d1d 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,3 +1,4 @@ +from __future__ import annotations import math import os @@ -7,7 +8,7 @@ from PIL import Image import torch import tqdm -from typing import Callable, Dict, List, Tuple +from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass @@ -21,18 +22,34 @@ import piexif.helper import gradio as gr -@dataclass(frozen=True) -class CacheKey: - image_hash: int - info_hash: int - args_hash: int +class LruCache(OrderedDict): + @dataclass(frozen=True) + class Key: + image_hash: int + info_hash: int + args_hash: int -@dataclass -class CacheEntry: - image: Image.Image - info: str + @dataclass + class Value: + image: Image.Image + info: str + + def __init__(self, max_size:int = 5, *args, **kwargs): + super().__init__(*args, **kwargs) + self._max_size = max_size + + def get(self, key: LruCache.Key) -> LruCache.Value: + ret = super().get(key) + if ret is not None: + self.move_to_end(key) # Move to end of eviction list + return ret + + def put(self, key: LruCache.Key, value: LruCache.Value) -> None: + self[key] = value + while len(self) > self._max_size: + self.popitem(last=False) -cached_images: Dict[CacheKey, CacheEntry] = {} +cached_images: LruCache = LruCache(max_size = 5) def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ): @@ -121,14 +138,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ blended_result: Image.Image = None for upscaler in params: upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()), + cache_key = LruCache.Key( image_hash = hash(np.array(image.getdata()).tobytes()), info_hash = hash(info), - args_hash = hash(upscale_args) ) + args_hash = hash(upscale_args + (upscaler.blend_alpha,)) ) cached_entry = cached_images.get(cache_key) if cached_entry is None: res = upscale(image, *upscale_args) info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images[cache_key] = CacheEntry(image=res, info=info) + cached_images.put(cache_key, LruCache.Value(image=res, info=info)) else: res, info = cached_entry.image, cached_entry.info @@ -140,14 +157,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ # Build a list of operations to run facefix_ops: List[Callable] = [] - if gfpgan_visibility > 0: - facefix_ops.append(run_gfpgan) - if codeformer_visibility > 0: - facefix_ops.append(run_codeformer) + facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] + facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] upscale_ops: List[Callable] = [] - if resize_mode == 1: - upscale_ops.append(run_prepare_crop) + upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] if upscaling_resize != 0: step_params: List[UpscaleParams] = [] @@ -157,12 +171,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ upscale_ops.append( partial(run_upscalers_blend, step_params) ) - - extras_ops: List[Callable] = [] - if upscale_first: - extras_ops = upscale_ops + facefix_ops - else: - extras_ops = facefix_ops + upscale_ops + extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) for image, image_name in zip(imageArr, imageNameArr): @@ -176,9 +185,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ for op in extras_ops: image, info = op(image, info) - while len(cached_images) > 2: - del cached_images[next(iter(cached_images.keys()))] - if opts.use_original_name_batch and image_name != None: basename = os.path.splitext(os.path.basename(image_name))[0] else: @@ -198,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ return outputs, plaintext_to_html(info), '' +def clear_cache(): + cached_images.clear() + def run_pnginfo(image): if image is None: diff --git a/modules/ui.py b/modules/ui.py index 16b6ac49..b7c36c55 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1178,6 +1178,11 @@ def create_ui(wrap_gradio_gpu_call): outputs=[init_img_with_mask], ) + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): -- cgit v1.2.1 From 5732c0282d529ef2e0591c76e16959e97240dad8 Mon Sep 17 00:00:00 2001 From: Chris OBryan <13701027+cobryan05@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:36:25 -0500 Subject: extras-tweaks: autoformat changed lines --- modules/extras.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 72cc6d1d..50026a25 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -34,14 +34,14 @@ class LruCache(OrderedDict): image: Image.Image info: str - def __init__(self, max_size:int = 5, *args, **kwargs): + def __init__(self, max_size: int = 5, *args, **kwargs): super().__init__(*args, **kwargs) self._max_size = max_size def get(self, key: LruCache.Key) -> LruCache.Value: ret = super().get(key) if ret is not None: - self.move_to_end(key) # Move to end of eviction list + self.move_to_end(key) # Move to end of eviction list return ret def put(self, key: LruCache.Key, value: LruCache.Value) -> None: @@ -49,10 +49,11 @@ class LruCache(OrderedDict): while len(self) > self._max_size: self.popitem(last=False) -cached_images: LruCache = LruCache(max_size = 5) +cached_images: LruCache = LruCache(max_size=5) -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ): + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool): devices.torch_gc() imageArr = [] @@ -88,8 +89,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ else: outpath = opts.outdir_samples or opts.outdir_extras_samples - # Extra operation definitions + def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) res = Image.fromarray(restored_img) @@ -110,7 +111,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" return (res, info) - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): upscaler = shared.sd_upscalers[scaler_index] res = upscaler.scaler.upscale(image, resize, upscaler.data_path) @@ -134,13 +134,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ upscaler_idx: int blend_alpha: float - def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: + def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: blended_result: Image.Image = None for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key( image_hash = hash(np.array(image.getdata()).tobytes()), - info_hash = hash(info), - args_hash = hash(upscale_args + (upscaler.blend_alpha,)) ) + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, + upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()), + info_hash=hash(info), + args_hash=hash(upscale_args + (upscaler.blend_alpha,))) cached_entry = cached_images.get(cache_key) if cached_entry is None: res = upscale(image, *upscale_args) @@ -165,15 +166,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if upscaling_resize != 0: step_params: List[UpscaleParams] = [] - step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_1, blend_alpha=1.0 )) + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility ) ) + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - upscale_ops.append( partial(run_upscalers_blend, step_params) ) + upscale_ops.append(partial(run_upscalers_blend, step_params)) extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - for image, image_name in zip(imageArr, imageNameArr): if image is None: return outputs, "Please select an input image.", '' -- cgit v1.2.1 From d8b366146748555a18b595af400c8cb222ea0ec9 Mon Sep 17 00:00:00 2001 From: Chris OBryan <13701027+cobryan05@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:55:02 -0500 Subject: extras: upscaler blending should not be considered in cache key --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 50026a25..681d8d5a 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ upscaling_resize_w, upscaling_resize_h, upscaling_crop) cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()), info_hash=hash(info), - args_hash=hash(upscale_args + (upscaler.blend_alpha,))) + args_hash=hash(upscale_args)) cached_entry = cached_images.get(cache_key) if cached_entry is None: res = upscale(image, *upscale_args) -- cgit v1.2.1 From 539c0f51e436beeb0ca2b8b8d52b24f4b59ad56a Mon Sep 17 00:00:00 2001 From: Yaiol <38218161+Yaiol@users.noreply.github.com> Date: Sat, 29 Oct 2022 01:07:01 +0200 Subject: Update images.py Filename tags [height] and [width] are wrongly referencing to process size instead of resulting image size. Making all upscale files named wrongly. --- modules/images.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 7870b5b7..a0728553 100644 --- a/modules/images.py +++ b/modules/images.py @@ -300,8 +300,8 @@ class FilenameGenerator: 'seed': lambda self: self.seed if self.seed is not None else '', 'steps': lambda self: self.p and self.p.steps, 'cfg': lambda self: self.p and self.p.cfg_scale, - 'width': lambda self: self.p and self.p.width, - 'height': lambda self: self.p and self.p.height, + 'width': lambda self: self.image.width, + 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), @@ -315,10 +315,11 @@ class FilenameGenerator: } default_time_format = '%Y%m%d%H%M%S' - def __init__(self, p, seed, prompt): + def __init__(self, p, seed, prompt, image): self.p = p self.seed = seed self.prompt = prompt + self.image = image def prompt_no_style(self): if self.p is None or self.prompt is None: @@ -449,7 +450,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i txt_fullfn (`str` or None): If a text file is saved for this image, this will be its full path. Otherwise None. """ - namegen = FilenameGenerator(p, seed, prompt) + namegen = FilenameGenerator(p, seed, prompt, image) if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) -- cgit v1.2.1 From f361e804ebaa5af4a10711ece2522869fb64a4c6 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Sat, 29 Oct 2022 08:36:50 +0900 Subject: Re enable linear --- modules/hypernetworks/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index c2d4b51c..aad09ffc 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork not_available = ["hardswish", "multiheadattention"] -keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) +keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): # Remove illegal characters from name. -- cgit v1.2.1 From bce5adcd6de1ad608df8d813a92e1167b7a1d3f2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 07:37:06 +0300 Subject: change default hypernet activation function to linear --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 0a63e357..2541970d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1238,7 +1238,7 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys) + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys) new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") -- cgit v1.2.1 From a1e5e0d7669def010ecf31d801d6f0667bcf8061 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 08:11:03 +0300 Subject: skip filenames starting with . for img2img and extras batch modes --- modules/extras.py | 2 +- modules/img2img.py | 2 +- modules/shared.py | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 681d8d5a..4d51088b 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -72,7 +72,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if input_dir == '': return outputs, "Please select an input directory.", '' - image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)] + image_list = shared.listfiles(input_dir) for img in image_list: try: image = Image.open(img) diff --git a/modules/img2img.py b/modules/img2img.py index 9c0cf23e..efda26e1 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -19,7 +19,7 @@ import modules.scripts def process_batch(p, input_dir, output_dir, args): processing.fix_seed(p) - images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] + images = shared.listfiles(input_dir) print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") diff --git a/modules/shared.py b/modules/shared.py index 7c428d90..7e634423 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -450,3 +450,8 @@ total_tqdm = TotalTQDM() mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) mem_mon.start() + + +def listfiles(dirname): + filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] + return [file for file in filenames if os.path.isfile(file)] -- cgit v1.2.1 From 2d220afb24bd9812d5124814f670ec2a1ff5b0fe Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 08:26:12 +0300 Subject: fix open folder button not working --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 922a2163..20cc10cf 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -631,9 +631,9 @@ Requested path was: {f} buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' - open_folder = gr.Button(folder_symbol, elem_id=button_id) + open_folder_button = gr.Button(folder_symbol, elem_id=button_id) - open_folder.click( + open_folder_button.click( fn=lambda: open_folder(opts.outdir_samples or outdir), inputs=[], outputs=[], -- cgit v1.2.1 From a33d0a9a65189be35038be5765f2b4d1b19bab0f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 08:28:48 +0300 Subject: remove weird spaces added to ui.py over time --- modules/ui.py | 51 +++++++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 26 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 20cc10cf..46657dd4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -596,7 +596,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele ) return refresh_button -def create_output_panel(tabname, outdir): +def create_output_panel(tabname, outdir): def open_folder(f): if not os.path.exists(f): print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') @@ -618,11 +618,11 @@ Requested path was: {f} sp.Popen(["open", path]) else: sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): + + with gr.Column(variant='panel'): + with gr.Group(): result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - + generation_info = None with gr.Column(): with gr.Row(): @@ -639,7 +639,7 @@ Requested path was: {f} outputs=[], ) - if tabname != "extras": + if tabname != "extras": with gr.Row(): do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) @@ -671,8 +671,7 @@ Requested path was: {f} html_info = gr.HTML() parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info - - + def create_ui(wrap_gradio_gpu_call): import modules.img2img @@ -723,10 +722,10 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) - - txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) - + + txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -781,7 +780,7 @@ def create_ui(wrap_gradio_gpu_call): fn=lambda x: gr_show(x), inputs=[enable_hr], outputs=[hr_options], - ) + ) roll.click( fn=roll_artist, @@ -902,7 +901,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) - img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) + img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -990,7 +989,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[init_img], outputs=[img2img_prompt], ) - + roll.click( fn=roll_artist, @@ -1045,7 +1044,7 @@ def create_ui(wrap_gradio_gpu_call): (denoising_strength, "Denoising strength"), *modules.scripts.scripts_img2img.infotext_fields ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) @@ -1077,9 +1076,9 @@ def create_ui(wrap_gradio_gpu_call): upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) - + with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") @@ -1125,7 +1124,7 @@ def create_ui(wrap_gradio_gpu_call): html_info, ] ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) + parameters_copypaste.add_paste_fields("extras", extras_image, None) with gr.Blocks(analytics_enabled=False) as pnginfo_interface: @@ -1139,14 +1138,14 @@ def create_ui(wrap_gradio_gpu_call): html2 = gr.HTML() with gr.Row(): buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - + parameters_copypaste.bind_buttons(buttons, image, generation_info) + image.change( fn=wrap_gradio_call(modules.extras.run_pnginfo), inputs=[image], outputs=[html, generation_info, html2], ) - + with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): @@ -1569,7 +1568,7 @@ def create_ui(wrap_gradio_gpu_call): column.__exit__() - + interfaces = [ (txt2img_interface, "txt2img", "txt2img"), @@ -1582,7 +1581,7 @@ def create_ui(wrap_gradio_gpu_call): interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] + interfaces += [(settings_interface, "Settings", "settings")] css = "" @@ -1661,7 +1660,7 @@ def create_ui(wrap_gradio_gpu_call): component_dict['sd_model_checkpoint'], ] ) - + settings_map = { 'sd_hypernetwork': 'Hypernet', @@ -1669,7 +1668,7 @@ def create_ui(wrap_gradio_gpu_call): 'sd_model_checkpoint': 'Model hash', } - parameters_copypaste.run_bind() + parameters_copypaste.run_bind() ui_config_file = cmd_opts.ui_config_file ui_settings = {} @@ -1749,7 +1748,7 @@ def load_javascript(raw_response): javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") - + for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" -- cgit v1.2.1 From 3c207ca68483b3406faf519bde2743b578dac222 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 08:42:34 +0300 Subject: add needed imports fr new code in copypaste.py --- modules/generation_parameters_copypaste.py | 9 +++++++++ modules/ui.py | 7 ------- 2 files changed, 9 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 2b80737a..224a17ea 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,3 +1,5 @@ +import base64 +import io import os import re import gradio as gr @@ -14,6 +16,7 @@ type_of_gr_update = type(gr.update()) paste_fields = {} bind_list = [] + def quote(text): if ',' not in str(text): return text @@ -23,6 +26,7 @@ def quote(text): text = text.replace('"', '\\"') return f'"{text}"' + def image_from_url_text(filedata): if type(filedata) == dict and filedata["is_file"]: filename = filedata["name"] @@ -45,19 +49,23 @@ def image_from_url_text(filedata): image = Image.open(io.BytesIO(filedata)) return image + def add_paste_fields(tabname, init_img, fields): paste_fields[tabname] = {"init_img":init_img, "fields": fields} + def create_buttons(tabs_list): buttons = {} for tab in tabs_list: buttons[tab] = gr.Button(f"Send to {tab}") return buttons + #if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab def bind_buttons(buttons, send_image, send_generate_info): bind_list.append([buttons, send_image, send_generate_info]) + def run_bind(): for buttons, send_image, send_generate_info in bind_list: for tab in buttons: @@ -98,6 +106,7 @@ def run_bind(): outputs=None, ) + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` diff --git a/modules/ui.py b/modules/ui.py index 46657dd4..280910d0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,6 +1,4 @@ -import base64 import html -import io import json import math import mimetypes @@ -18,13 +16,8 @@ import gradio as gr import gradio.routes import gradio.utils import numpy as np -import piexif -import torch from PIL import Image, PngImagePlugin -import gradio as gr -import gradio.utils -import gradio.routes from modules import sd_hijack, sd_models, localization, script_callbacks from modules.paths import script_path -- cgit v1.2.1 From 2922d8144f677ea3c37189a2b2e3c3d3ff6d3916 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 09:01:04 +0300 Subject: make existing image browser extension not break --- modules/generation_parameters_copypaste.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 224a17ea..d590e9ee 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -51,7 +51,14 @@ def image_from_url_text(filedata): def add_paste_fields(tabname, init_img, fields): - paste_fields[tabname] = {"init_img":init_img, "fields": fields} + paste_fields[tabname] = {"init_img": init_img, "fields": fields} + + # backwards compatibility for existing extensions + import modules.ui + if tabname == 'txt2img': + modules.ui.txt2img_paste_fields = fields + elif tabname == 'img2img': + modules.ui.img2img_paste_fields = fields def create_buttons(tabs_list): @@ -61,7 +68,7 @@ def create_buttons(tabs_list): return buttons -#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab +#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab def bind_buttons(buttons, send_image, send_generate_info): bind_list.append([buttons, send_image, send_generate_info]) @@ -84,12 +91,12 @@ def run_bind(): inputs=[send_image], outputs=[paste_fields[tab]["init_img"]], ) - + if send_generate_info and paste_fields[tab]["fields"] is not None: paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] if shared.opts.send_seed: paste_field_names += ["Seed"] - if send_generate_info in paste_fields: + if send_generate_info in paste_fields: button.click( fn=lambda *x:x, inputs=[field for field,name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], @@ -154,7 +161,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model return res -def connect_paste(button, paste_fields, input_comp): + +def connect_paste(button, paste_fields, input_comp, jsfunc=None): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(script_path, "params.txt") @@ -192,6 +200,7 @@ def connect_paste(button, paste_fields, input_comp): button.click( fn=paste_func, + _js=jsfunc, inputs=[input_comp], outputs=[x[0] for x in paste_fields], ) -- cgit v1.2.1 From 28e6d4a54ea1fa1e34ad1ea0742ab2003ed7fa7f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 09:13:36 +0300 Subject: add element ids for save buttons for #3798 --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 66b743f5..3c34eca0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -620,7 +620,7 @@ Requested path was: {f} with gr.Column(): with gr.Row(): if tabname != "extras": - save = gr.Button('Save') + save = gr.Button('Save', elem_id=f'save_{tabname}') buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' -- cgit v1.2.1 From beb6fc29798d82f1b08a34cf5dd79e4ab29d4cd0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 09:57:22 +0300 Subject: move send seed option to UI section and make it false by default --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 5d1ceb85..fb84afd8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -280,7 +280,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), - "send_seed": OptionInfo(False, "Send seed when sending prompt or image to other interface"), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { @@ -306,6 +305,7 @@ options_templates.update(options_section(('ui', "User interface"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), + "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.1 From 2c4d20388425a5e40b93eef3722e42e8d375fbb4 Mon Sep 17 00:00:00 2001 From: timntorres Date: Sat, 29 Oct 2022 00:36:51 -0700 Subject: Revert "Explicitly state when Hypernet is none." --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 377c0978..04fdda7c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -395,7 +395,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": ("None" if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), + "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), "Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), -- cgit v1.2.1 From 35c45df28b303a05d56a13cb56d4046f08cf8c25 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 10:56:19 +0300 Subject: =?UTF-8?q?fix=20broken=20=E2=86=99=20button,=20fix=20field=20past?= =?UTF-8?q?e=20ignoring=20most=20of=20useful=20fields=20for=20for=20#3768?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modules/generation_parameters_copypaste.py | 36 ++++++++++++++++++------- modules/ui.py | 43 +++++++++++------------------- 2 files changed, 41 insertions(+), 38 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index d590e9ee..bbaad42e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import gradio as gr from modules.shared import script_path from modules import shared import tempfile -from PIL import Image, PngImagePlugin +from PIL import Image re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) @@ -61,6 +61,24 @@ def add_paste_fields(tabname, init_img, fields): modules.ui.img2img_paste_fields = fields +def integrate_settings_paste_fields(component_dict): + from modules import ui + + settings_map = { + 'sd_hypernetwork': 'Hypernet', + 'CLIP_stop_at_last_layers': 'Clip skip', + 'sd_model_checkpoint': 'Model hash', + } + settings_paste_fields = [ + (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None))) + for k, v in settings_map.items() + ] + + for tabname, info in paste_fields.items(): + if info["fields"] is not None: + info["fields"] += settings_paste_fields + + def create_buttons(tabs_list): buttons = {} for tab in tabs_list: @@ -87,24 +105,22 @@ def run_bind(): ) else: button.click( - fn=lambda x:x, + fn=lambda x: x, inputs=[send_image], outputs=[paste_fields[tab]["init_img"]], ) if send_generate_info and paste_fields[tab]["fields"] is not None: - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] - if shared.opts.send_seed: - paste_field_names += ["Seed"] if send_generate_info in paste_fields: + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else []) + button.click( - fn=lambda *x:x, - inputs=[field for field,name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], - outputs=[field for field,name in paste_fields[tab]["fields"] if name in paste_field_names], + fn=lambda *x: x, + inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], + outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names], ) - else: - connect_paste(button, [(field, name) for field, name in paste_fields[tab]["fields"] if name in paste_field_names], send_generate_info) + connect_paste(button, paste_fields[tab]["fields"], send_generate_info) button.click( fn=None, diff --git a/modules/ui.py b/modules/ui.py index 3c34eca0..5055ca64 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -589,6 +589,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele ) return refresh_button + def create_output_panel(tabname, outdir): def open_folder(f): if not os.path.exists(f): @@ -716,6 +717,7 @@ def create_ui(wrap_gradio_gpu_call): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -784,7 +786,7 @@ def create_ui(wrap_gradio_gpu_call): ] ) - parameters_copypaste.add_paste_fields("txt2img", None, [ + txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), (steps, "Steps"), @@ -805,7 +807,8 @@ def create_ui(wrap_gradio_gpu_call): (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), *modules.scripts.scripts_txt2img.infotext_fields - ]) + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) txt2img_preview_params = [ txt2img_prompt, @@ -893,6 +896,7 @@ def create_ui(wrap_gradio_gpu_call): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) @@ -1038,7 +1042,6 @@ def create_ui(wrap_gradio_gpu_call): parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): @@ -1050,12 +1053,8 @@ def create_ui(wrap_gradio_gpu_call): image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") with gr.TabItem('Batch from Directory'): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, - placeholder="A directory on the same machine where the server is running." - ) - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, - placeholder="Leave blank to save images to the default path." - ) + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") show_extras_results = gr.Checkbox(label='Show result images', value=True) with gr.Tabs(elem_id="extras_resize_mode"): @@ -1087,7 +1086,6 @@ def create_ui(wrap_gradio_gpu_call): submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) submit.click( @@ -1121,7 +1119,6 @@ def create_ui(wrap_gradio_gpu_call): ) parameters_copypaste.add_paste_fields("extras", extras_image, None) - extras_image.change( fn=modules.extras.clear_cache, inputs=[], outputs=[] @@ -1587,9 +1584,6 @@ def create_ui(wrap_gradio_gpu_call): if column is not None: column.__exit__() - - - interfaces = [ (txt2img_interface, "txt2img", "txt2img"), (img2img_interface, "img2img", "img2img"), @@ -1599,10 +1593,6 @@ def create_ui(wrap_gradio_gpu_call): (train_interface, "Train", "ti"), ] - interfaces += script_callbacks.ui_tabs_callback() - - interfaces += [(settings_interface, "Settings", "settings")] - css = "" for cssfile in modules.scripts.list_files_with_name("style.css"): @@ -1619,6 +1609,9 @@ def create_ui(wrap_gradio_gpu_call): if not cmd_opts.no_progressbar_hiding: css += css_hide_progressbar + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): for i, k, item in quicksettings_list: @@ -1627,6 +1620,9 @@ def create_ui(wrap_gradio_gpu_call): settings_interface.gradio_ref = demo + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + with gr.Tabs(elem_id="tabs") as tabs: for interface, label, ifid in interfaces: with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): @@ -1681,15 +1677,6 @@ def create_ui(wrap_gradio_gpu_call): ] ) - - settings_map = { - 'sd_hypernetwork': 'Hypernet', - 'CLIP_stop_at_last_layers': 'Clip skip', - 'sd_model_checkpoint': 'Model hash', - } - - parameters_copypaste.run_bind() - ui_config_file = cmd_opts.ui_config_file ui_settings = {} settings_count = len(ui_settings) @@ -1708,7 +1695,7 @@ def create_ui(wrap_gradio_gpu_call): def apply_field(obj, field, condition=None, init_field=None): key = path + "/" + field - if getattr(obj,'custom_script_source',None) is not None: + if getattr(obj, 'custom_script_source', None) is not None: key = 'customscript/' + obj.custom_script_source + '/' + key if getattr(obj, 'do_not_save_to_config', False): -- cgit v1.2.1 From a5f3adbdd7d9b8245f7782216ac48913660e6bb5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 15:37:24 +0700 Subject: Allow trailing comma in learning rate --- modules/textual_inversion/learn_schedule.py | 33 +++++++++++++++++------------ 1 file changed, 20 insertions(+), 13 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 3a736065..76e611b6 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -11,23 +11,30 @@ class LearnScheduleIterator: self.rates = [] self.it = 0 self.maxit = 0 - for i, pair in enumerate(pairs): - tmp = pair.split(':') - if len(tmp) == 2: - step = int(tmp[1]) - if step > cur_step: - self.rates.append((float(tmp[0]), min(step, max_steps))) - self.maxit += 1 - if step > max_steps: + try: + for i, pair in enumerate(pairs): + if not pair.strip(): + continue + tmp = pair.split(':') + if len(tmp) == 2: + step = int(tmp[1]) + if step > cur_step: + self.rates.append((float(tmp[0]), min(step, max_steps))) + self.maxit += 1 + if step > max_steps: + return + elif step == -1: + self.rates.append((float(tmp[0]), max_steps)) + self.maxit += 1 return - elif step == -1: + else: self.rates.append((float(tmp[0]), max_steps)) self.maxit += 1 return - else: - self.rates.append((float(tmp[0]), max_steps)) - self.maxit += 1 - return + assert self.rates + except (ValueError, AssertionError): + raise Exception("Invalid learning rate schedule") + def __iter__(self): return self -- cgit v1.2.1 From ef4c94e1cfe66299227aa95a28c2380d21cb1600 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 15:42:51 +0700 Subject: Improve lr schedule error message --- modules/textual_inversion/learn_schedule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 76e611b6..dd0c0ad1 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -4,7 +4,7 @@ import tqdm class LearnScheduleIterator: def __init__(self, learn_rate, max_steps, cur_step=0): """ - specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000 + specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000 """ pairs = learn_rate.split(',') @@ -33,7 +33,7 @@ class LearnScheduleIterator: return assert self.rates except (ValueError, AssertionError): - raise Exception("Invalid learning rate schedule") + raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') def __iter__(self): -- cgit v1.2.1 From ab27c111d06ec920791c73eea25ad9a61671852e Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 18:09:17 +0700 Subject: Add input validations before loading dataset for training --- modules/hypernetworks/hypernetwork.py | 38 +++++++++++--------- modules/textual_inversion/textual_inversion.py | 48 +++++++++++++++++++------- 2 files changed, 58 insertions(+), 28 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 2e84583b..38f35c58 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -332,7 +332,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images - assert hypernetwork_name, 'hypernetwork not selected' + save_hypernetwork_every = save_hypernetwork_every or 0 + create_image_every = create_image_every or 0 + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() @@ -358,39 +360,43 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log else: images_dir = None + hypernetwork = shared.loaded_hypernetwork + + ititial_step = hypernetwork.step or 0 + if ititial_step > steps: + shared.state.textinfo = f"Model has already been trained beyond specified max steps" + return hypernetwork, filename + + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - hypernetwork = shared.loaded_hypernetwork - weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True - size = len(ds.indexes) loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) previous_mean_losses = [0] previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) - - last_saved_file = "" - last_saved_image = "" - forced_filename = "" - - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: - return hypernetwork, filename - - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) steps_without_grad = 0 + last_saved_file = "" + last_saved_image = "" + forced_filename = "" + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 17dfb223..44f06443 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): + assert model_name, f"{name} not selected" + assert learn_rate, "Learning rate is empty or 0" + assert isinstance(batch_size, int), "Batch size must be integer" + assert batch_size > 0, "Batch size must be positive" + assert data_root, "Dataset directory is empty" + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" + assert template_file, "Prompt template file is empty" + assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert steps, "Max steps is empty or 0" + assert isinstance(steps, int), "Max steps must be integer" + assert steps > 0 , "Max steps must be positive" + assert isinstance(save_model_every, int), "Save {name} must be integer" + assert save_model_every >= 0 , "Save {name} must be positive or 0" + assert isinstance(create_image_every, int), "Create image must be integer" + assert create_image_every >= 0 , "Create image must be positive or 0" + if save_model_every or create_image_every: + assert log_directory, "Log directory is empty" def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - assert embedding_name, 'embedding not selected' + save_embedding_every = save_embedding_every or 0 + create_image_every = create_image_every or 0 + validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps @@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc os.makedirs(images_embeds_dir, exist_ok=True) else: images_embeds_dir = None - + cond_model = shared.sd_model.cond_stage_model + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + + ititial_step = embedding.step or 0 + if ititial_step > steps: + shared.state.textinfo = f"Model has already been trained beyond specified max steps" + return embedding, filename + + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) - hijack = sd_hijack.model_hijack - - embedding = hijack.embedding_db.word_embeddings[embedding_name] embedding.vec.requires_grad = True + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) losses = torch.zeros((32,)) @@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc forced_filename = "" embedding_yet_to_be_embedded = False - ititial_step = embedding.step or 0 - if ititial_step > steps: - return embedding, filename - - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, entries in pbar: embedding.step = i + ititial_step -- cgit v1.2.1 From 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 19:43:21 +0700 Subject: Add cleanup after training --- modules/hypernetworks/hypernetwork.py | 201 +++++++++++++------------ modules/textual_inversion/textual_inversion.py | 185 ++++++++++++----------- 2 files changed, 200 insertions(+), 186 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38f35c58..170d5ea4 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -398,110 +398,112 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log forced_filename = "" pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, entries in pbar: - hypernetwork.step = i + ititial_step - if len(loss_dict) > 0: - previous_mean_losses = [i[-1] for i in loss_dict.values()] - previous_mean_loss = mean(previous_mean_losses) - - scheduler.apply(optimizer, hypernetwork.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = stack_conds([entry.cond for entry in entries]).to(devices.device) - # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - del c - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - for entry in entries: - loss_dict[entry.filename].append(loss.item()) - - optimizer.zero_grad() - weights[0].grad = None - loss.backward() - if weights[0].grad is None: - steps_without_grad += 1 + try: + for i, entries in pbar: + hypernetwork.step = i + ititial_step + if len(loss_dict) > 0: + previous_mean_losses = [i[-1] for i in loss_dict.values()] + previous_mean_loss = mean(previous_mean_losses) + + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = stack_conds([entry.cond for entry in entries]).to(devices.device) + # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x + del c + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + for entry in entries: + loss_dict[entry.filename].append(loss.item()) + + optimizer.zero_grad() + weights[0].grad = None + loss.backward() + + if weights[0].grad is None: + steps_without_grad += 1 + else: + steps_without_grad = 0 + assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' + + optimizer.step() + + steps_done = hypernetwork.step + 1 + + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): + raise RuntimeError("Loss diverged.") + + if len(previous_mean_losses) > 1: + std = stdev(previous_mean_losses) else: - steps_without_grad = 0 - assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' - - optimizer.step() - - steps_done = hypernetwork.step + 1 - - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): - raise RuntimeError("Loss diverged.") - - if len(previous_mean_losses) > 1: - std = stdev(previous_mean_losses) - else: - std = 0 - dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" - pbar.set_description(dataset_loss_info) - - if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: - # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(last_saved_file) - - textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{previous_mean_loss:.7f}", - "learn_rate": scheduler.learn_rate - }) - - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{hypernetwork_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) + std = 0 + dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" + pbar.set_description(dataset_loss_info) + + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: + # Before saving, change name to match current checkpoint. + hypernetwork.name = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') + hypernetwork.save(last_saved_file) + + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { + "loss": f"{previous_mean_loss:.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 - preview_text = p.prompt + preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images)>0 else None - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) - if image is not None: - shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - shared.state.job_no = hypernetwork.step + shared.state.job_no = hypernetwork.step - shared.state.textinfo = f""" + shared.state.textinfo = f"""

Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -510,7 +512,14 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - + finally: + if weights: + for weight in weights: + weight.requires_grad = False + if unload: + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + report_statistics(loss_dict) checkpoint = sd_models.select_checkpoint() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44f06443..fd7f0897 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -283,111 +283,113 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, entries in pbar: - embedding.step = i + ititial_step - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - - losses[embedding.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - steps_done = embedding.step + 1 - - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) - - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) - embedding_yet_to_be_embedded = True - - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) - - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) - - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height + try: + for i, entries in pbar: + embedding.step = i + ititial_step + + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([entry.cond_text for entry in entries]) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + steps_done = embedding.step + 1 + + epoch_num = embedding.step // len(ds) + epoch_step = embedding.step % len(ds) + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") + + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding.name = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') + embedding.save(last_saved_file) + embedding_yet_to_be_embedded = True + + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + p.width = training_width + p.height = training_height - preview_text = p.prompt + preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] + processed = processing.process_images(p) + image = processed.images[0] - shared.state.current_image = image + shared.state.current_image = image - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) - title = "<{}>".format(data.get('name', '???')) + title = "<{}>".format(data.get('name', '???')) - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - shared.state.job_no = embedding.step + shared.state.job_no = embedding.step - shared.state.textinfo = f""" + shared.state.textinfo = f"""

Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" + finally: + if embedding and embedding.vec is not None: + embedding.vec.requires_grad = False checkpoint = sd_models.select_checkpoint() -- cgit v1.2.1 From a27d19de2eff633b6a39f9f4a5c0f2d6abb81bb5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 19:44:05 +0700 Subject: Additional assert on dataset --- modules/textual_inversion/dataset.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 8bb00d27..ad726577 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -42,6 +42,8 @@ class PersonalizedBase(Dataset): self.lines = lines assert data_root, 'dataset directory not specified' + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" cond_model = shared.sd_model.cond_stage_model -- cgit v1.2.1 From ab05a74ead9fabb45dd099990e34061c7eb02ca3 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:32:02 +0700 Subject: Revert "Add cleanup after training" This reverts commit 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1. --- modules/hypernetworks/hypernetwork.py | 201 ++++++++++++------------- modules/textual_inversion/textual_inversion.py | 185 +++++++++++------------ 2 files changed, 186 insertions(+), 200 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 170d5ea4..38f35c58 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -398,112 +398,110 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log forced_filename = "" pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - - try: - for i, entries in pbar: - hypernetwork.step = i + ititial_step - if len(loss_dict) > 0: - previous_mean_losses = [i[-1] for i in loss_dict.values()] - previous_mean_loss = mean(previous_mean_losses) - - scheduler.apply(optimizer, hypernetwork.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = stack_conds([entry.cond for entry in entries]).to(devices.device) - # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - del c - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - for entry in entries: - loss_dict[entry.filename].append(loss.item()) - - optimizer.zero_grad() - weights[0].grad = None - loss.backward() - - if weights[0].grad is None: - steps_without_grad += 1 - else: - steps_without_grad = 0 - assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' - - optimizer.step() - - steps_done = hypernetwork.step + 1 - - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): - raise RuntimeError("Loss diverged.") + for i, entries in pbar: + hypernetwork.step = i + ititial_step + if len(loss_dict) > 0: + previous_mean_losses = [i[-1] for i in loss_dict.values()] + previous_mean_loss = mean(previous_mean_losses) - if len(previous_mean_losses) > 1: - std = stdev(previous_mean_losses) + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = stack_conds([entry.cond for entry in entries]).to(devices.device) + # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x + del c + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + for entry in entries: + loss_dict[entry.filename].append(loss.item()) + + optimizer.zero_grad() + weights[0].grad = None + loss.backward() + + if weights[0].grad is None: + steps_without_grad += 1 else: - std = 0 - dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" - pbar.set_description(dataset_loss_info) - - if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: - # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(last_saved_file) - - textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{previous_mean_loss:.7f}", - "learn_rate": scheduler.learn_rate - }) - - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{hypernetwork_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) + steps_without_grad = 0 + assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 + optimizer.step() - preview_text = p.prompt + steps_done = hypernetwork.step + 1 - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): + raise RuntimeError("Loss diverged.") + + if len(previous_mean_losses) > 1: + std = stdev(previous_mean_losses) + else: + std = 0 + dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" + pbar.set_description(dataset_loss_info) + + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: + # Before saving, change name to match current checkpoint. + hypernetwork.name = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') + hypernetwork.save(last_saved_file) + + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { + "loss": f"{previous_mean_loss:.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) - if image is not None: - shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + + preview_text = p.prompt + + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images)>0 else None + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) - shared.state.job_no = hypernetwork.step + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - shared.state.textinfo = f""" + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f"""

Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -512,14 +510,7 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - finally: - if weights: - for weight in weights: - weight.requires_grad = False - if unload: - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - + report_statistics(loss_dict) checkpoint = sd_models.select_checkpoint() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fd7f0897..44f06443 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -283,113 +283,111 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, entries in pbar: + embedding.step = i + ititial_step - try: - for i, entries in pbar: - embedding.step = i + ititial_step - - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - - losses[embedding.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - steps_done = embedding.step + 1 - - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) - - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) - embedding_yet_to_be_embedded = True - - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) - - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) - - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([entry.cond_text for entry in entries]) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + steps_done = embedding.step + 1 + + epoch_num = embedding.step // len(ds) + epoch_step = embedding.step % len(ds) + + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") + + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding.name = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') + embedding.save(last_saved_file) + embedding_yet_to_be_embedded = True + + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) + + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) + + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + p.width = training_width + p.height = training_height - preview_text = p.prompt + preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] + processed = processing.process_images(p) + image = processed.images[0] - shared.state.current_image = image + shared.state.current_image = image - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) - title = "<{}>".format(data.get('name', '???')) + title = "<{}>".format(data.get('name', '???')) - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - shared.state.job_no = embedding.step + shared.state.job_no = embedding.step - shared.state.textinfo = f""" + shared.state.textinfo = f"""

Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -398,9 +396,6 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - finally: - if embedding and embedding.vec is not None: - embedding.vec.requires_grad = False checkpoint = sd_models.select_checkpoint() -- cgit v1.2.1 From a07f054c86f33360ff620d6a3fffdee366ab2d99 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:49:29 +0700 Subject: Add missing info on hypernetwork/embedding model log Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513 Also group the saving into one --- modules/hypernetworks/hypernetwork.py | 31 +++++++++++++------- modules/textual_inversion/textual_inversion.py | 39 +++++++++++++++++--------- 2 files changed, 47 insertions(+), 23 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38f35c58..86daf825 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -361,6 +361,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log images_dir = None hypernetwork = shared.loaded_hypernetwork + checkpoint = sd_models.select_checkpoint() ititial_step = hypernetwork.step or 0 if ititial_step > steps: @@ -449,9 +450,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(last_saved_file) + hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{previous_mean_loss:.7f}", @@ -512,13 +513,23 @@ Last saved image: {html.escape(last_saved_image)}
""" report_statistics(loss_dict) - checkpoint = sd_models.select_checkpoint() - hypernetwork.sd_checkpoint = checkpoint.hash - hypernetwork.sd_checkpoint_name = checkpoint.model_name - # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention). - hypernetwork.name = hypernetwork_name - filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(filename) + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) return hypernetwork, filename + +def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): + old_hypernetwork_name = hypernetwork.name + old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None + old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None + try: + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.name = hypernetwork_name + hypernetwork.save(filename) + except: + hypernetwork.sd_checkpoint = old_sd_checkpoint + hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name + hypernetwork.name = old_hypernetwork_name + raise diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44f06443..ee9917ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -119,7 +119,7 @@ class EmbeddingDatabase: vec = emb.detach().to(devices.device, dtype=torch.float32) embedding = Embedding(vec, name) embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('hash', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) self.register_embedding(embedding, shared.sd_model) @@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] + checkpoint = sd_models.select_checkpoint() ititial_step = embedding.step or 0 if ititial_step > steps: @@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) embedding_yet_to_be_embedded = True write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { @@ -397,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}

""" - checkpoint = sd_models.select_checkpoint() - - embedding.sd_checkpoint = checkpoint.hash - embedding.sd_checkpoint_name = checkpoint.model_name - embedding.cached_checksum = None - # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). - embedding.name = embedding_name - filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt') - embedding.save(filename) + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) return embedding, filename + +def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): + old_embedding_name = embedding.name + old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None + old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None + old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None + try: + embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint_name = checkpoint.model_name + if remove_cached_checksum: + embedding.cached_checksum = None + embedding.name = embedding_name + embedding.save(filename) + except: + embedding.sd_checkpoint = old_sd_checkpoint + embedding.sd_checkpoint_name = old_sd_checkpoint_name + embedding.name = old_embedding_name + embedding.cached_checksum = old_cached_checksum + raise -- cgit v1.2.1 From 3d58510f214c645ce5cdb261aa47df6573b239e9 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:54:59 +0700 Subject: Fix dataset still being loaded even when training will be skipped --- modules/hypernetworks/hypernetwork.py | 2 +- modules/textual_inversion/textual_inversion.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 86daf825..07acadc9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -364,7 +364,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log checkpoint = sd_models.select_checkpoint() ititial_step = hypernetwork.step or 0 - if ititial_step > steps: + if ititial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return hypernetwork, filename diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ee9917ce..e0babb46 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -262,7 +262,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc checkpoint = sd_models.select_checkpoint() ititial_step = embedding.step or 0 - if ititial_step > steps: + if ititial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return embedding, filename -- cgit v1.2.1 From 4609b83cd496013a05e77c42af031d89f07785a9 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 29 Oct 2022 16:09:19 -0300 Subject: Add PNG Info endpoint --- modules/api/api.py | 12 +++++++++--- modules/api/models.py | 9 ++++++++- 2 files changed, 17 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 49c213ea..8fcd068d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -5,7 +5,7 @@ import modules.shared as shared from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers -from modules.extras import run_extras +from modules.extras import run_extras, run_pnginfo def upscaler_to_index(name: str): try: @@ -32,6 +32,7 @@ class Api: self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) + self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -125,8 +126,13 @@ class Api: return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) - def pnginfoapi(self): - raise NotImplementedError + def pnginfoapi(self, req:PNGInfoRequest): + if(not req.image.strip()): + return PNGInfoResponse(info="") + + result = run_pnginfo(decode_base64_to_image(req.image.strip())) + + return PNGInfoResponse(info=result[1]) def launch(self, server_name, port): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index dd122321..58e8e58b 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,4 +1,5 @@ import inspect +from click import prompt from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal @@ -148,4 +149,10 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest): imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") class ExtrasBatchImagesResponse(ExtraBaseResponse): - images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file + images: list[str] = Field(title="Images", description="The generated images in base64 format.") + +class PNGInfoRequest(BaseModel): + image: str = Field(title="Image", description="The base64 encoded PNG image") + +class PNGInfoResponse(BaseModel): + info: str = Field(title="Image info", description="A string with all the info the image had") \ No newline at end of file -- cgit v1.2.1 From 83a1f44ae26cb89492064bb8be0321b14a75efe4 Mon Sep 17 00:00:00 2001 From: Bruno Seoane Date: Sat, 29 Oct 2022 16:10:00 -0300 Subject: Fix space --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 8fcd068d..d0f488ca 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -126,7 +126,7 @@ class Api: return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) - def pnginfoapi(self, req:PNGInfoRequest): + def pnginfoapi(self, req: PNGInfoRequest): if(not req.image.strip()): return PNGInfoResponse(info="") -- cgit v1.2.1 From 9bb6b6509aff8c1e6546d5a798ef9e9922758dc4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Oct 2022 22:20:02 +0300 Subject: add postprocess call for scripts --- modules/processing.py | 12 +++++++++--- modules/scripts.py | 24 +++++++++++++++++++++--- 2 files changed, 30 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 548eec29..50343846 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -478,7 +478,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: model_hijack.embedding_db.load_textual_inversion_embeddings() if p.scripts is not None: - p.scripts.run_alwayson_scripts(p) + p.scripts.process(p) infotexts = [] output_images = [] @@ -501,7 +501,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] - if (len(prompts) == 0): + if len(prompts) == 0: break with devices.autocast(): @@ -590,7 +590,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + + res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + + if p.scripts is not None: + p.scripts.postprocess(p, res) + + return res class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): diff --git a/modules/scripts.py b/modules/scripts.py index a7f36012..96e44bfd 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -64,7 +64,16 @@ class Script: def process(self, p, *args): """ This function is called before processing begins for AlwaysVisible scripts. - scripts. You can modify the processing object (p) here, inject hooks, etc. + You can modify the processing object (p) here, inject hooks, etc. + args contains all values returned by components from ui() + """ + + pass + + def postprocess(self, p, processed, *args): + """ + This function is called after processing ends for AlwaysVisible scripts. + args contains all values returned by components from ui() """ pass @@ -289,13 +298,22 @@ class ScriptRunner: return processed - def run_alwayson_scripts(self, p): + def process(self, p): for script in self.alwayson_scripts: try: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print(f"Error running alwayson script: {script.filename}", file=sys.stderr) + print(f"Error running process: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def postprocess(self, p, processed): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess(p, processed, *script_args) + except Exception: + print(f"Error running postprocess: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) def reload_sources(self, cache): -- cgit v1.2.1 From f62db4d5c753bc32d2ae166606ce41f4c5fa5c43 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 03:55:43 +0800 Subject: fix progress response model --- modules/api/api.py | 30 ------------------------------ modules/api/models.py | 8 ++++---- 2 files changed, 4 insertions(+), 34 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index e93cddcb..7e8522a2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,33 +1,3 @@ -# import time - -# from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI -# from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -# from modules.sd_samplers import all_samplers -# from modules.extras import run_pnginfo -# import modules.shared as shared -# from modules import devices -# import uvicorn -# from fastapi import Body, APIRouter, HTTPException -# from fastapi.responses import JSONResponse -# from pydantic import BaseModel, Field, Json -# from typing import List -# import json -# import io -# import base64 -# from PIL import Image - -# sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) - -# class TextToImageResponse(BaseModel): -# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") -# parameters: Json -# info: Json - -# class ImageToImageResponse(BaseModel): -# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") -# parameters: Json -# info: Json - import time import uvicorn from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image diff --git a/modules/api/models.py b/modules/api/models.py index 8d4abc39..e1762fb9 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,6 +1,6 @@ import inspect from click import prompt -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, Field, Json, create_model from typing import Any, Optional from typing_extensions import Literal from inflection import underscore @@ -158,6 +158,6 @@ class PNGInfoResponse(BaseModel): info: str = Field(title="Image info", description="A string with all the info the image had") class ProgressResponse(BaseModel): - progress: float - eta_relative: float - state: dict + progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") + eta_relative: float = Field(title="ETA in secs") + state: Json -- cgit v1.2.1 From e9c6c2a51f972fd7cd88ea740ade4ac3d8108b67 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 04:02:56 +0800 Subject: add description for state field --- modules/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/models.py b/modules/api/models.py index e1762fb9..709ab5a6 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel): class ProgressResponse(BaseModel): progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") eta_relative: float = Field(title="ETA in secs") - state: Json + state: Json = Field(title="State", description="The current state snapshot") -- cgit v1.2.1 From 88f46a5bec610cf03641f18becbe3deda541e982 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 05:04:29 +0800 Subject: update progress response model --- modules/api/api.py | 6 +++--- modules/api/models.py | 4 ++-- modules/shared.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 7e8522a2..5912d289 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -61,7 +61,7 @@ class Api: self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) - self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"]) + self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -171,7 +171,7 @@ class Api: # copy from check_progress_call of ui.py if shared.state.job_count == 0: - return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js()) + return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict()) # avoid dividing zero progress = 0.01 @@ -187,7 +187,7 @@ class Api: progress = min(progress, 1) - return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js()) + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict()) def launch(self, server_name, port): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index 709ab5a6..0ab85ec5 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,6 +1,6 @@ import inspect from click import prompt -from pydantic import BaseModel, Field, Json, create_model +from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal from inflection import underscore @@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel): class ProgressResponse(BaseModel): progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") eta_relative: float = Field(title="ETA in secs") - state: Json = Field(title="State", description="The current state snapshot") + state: dict = Field(title="State", description="The current state snapshot") diff --git a/modules/shared.py b/modules/shared.py index 0f4c035d..f7b0990c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -147,7 +147,7 @@ class State: def get_job_timestamp(self): return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? - def js(self): + def dict(self): obj = { "skipped": self.skipped, "interrupted": self.skipped, @@ -158,7 +158,7 @@ class State: "sampling_steps": self.sampling_steps, } - return json.dumps(obj) + return obj state = State() -- cgit v1.2.1 From 9f104b53c425e248595e5b6481336d2a339e015e Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 05:19:17 +0800 Subject: preview current image when opts.show_progress_every_n_steps is enabled --- modules/api/api.py | 8 ++++++-- modules/api/models.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5912d289..e960bb7b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,7 +1,7 @@ import time import uvicorn from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException import modules.shared as shared from modules import devices from modules.api.models import * @@ -187,7 +187,11 @@ class Api: progress = min(progress, 1) - return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict()) + current_image = None + if shared.state.current_image: + current_image = encode_pil_to_base64(shared.state.current_image) + + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) def launch(self, server_name, port): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index 0ab85ec5..c8bc719a 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -161,3 +161,4 @@ class ProgressResponse(BaseModel): progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") + current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") -- cgit v1.2.1 From 66d038f6a41507af2243ff1f6618a745a092c290 Mon Sep 17 00:00:00 2001 From: timntorres Date: Sat, 29 Oct 2022 15:00:08 -0700 Subject: Read hypernet strength from PNG info. --- modules/generation_parameters_copypaste.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index bbaad42e..59c6d7da 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -66,6 +66,7 @@ def integrate_settings_paste_fields(component_dict): settings_map = { 'sd_hypernetwork': 'Hypernet', + 'sd_hypernetwork_strength': 'Hypernetwork strength', 'CLIP_stop_at_last_layers': 'Clip skip', 'sd_model_checkpoint': 'Model hash', } -- cgit v1.2.1 From 9f4f894d74b57c3d02ebccaa59f9c22fca2b6c90 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 06:03:32 +0800 Subject: allow skip current image in progress api --- modules/api/api.py | 4 ++-- modules/api/models.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index e960bb7b..5c5b210f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -167,7 +167,7 @@ class Api: return PNGInfoResponse(info=result[1]) - def progressapi(self): + def progressapi(self, req: ProgressRequest = Depends()): # copy from check_progress_call of ui.py if shared.state.job_count == 0: @@ -188,7 +188,7 @@ class Api: progress = min(progress, 1) current_image = None - if shared.state.current_image: + if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) diff --git a/modules/api/models.py b/modules/api/models.py index c8bc719a..9ee42a17 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -157,6 +157,9 @@ class PNGInfoRequest(BaseModel): class PNGInfoResponse(BaseModel): info: str = Field(title="Image info", description="A string with all the info the image had") +class ProgressRequest(BaseModel): + skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") + class ProgressResponse(BaseModel): progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") eta_relative: float = Field(title="ETA in secs") -- cgit v1.2.1 From 05a657dd357eaca6940c4775daa946bd33f1167d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 30 Oct 2022 07:36:56 +0300 Subject: fix broken hires fix --- modules/processing.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 50343846..947ce6fa 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -686,15 +686,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + image_conditioning = self.txt2img_image_conditioning(x) + # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - image_conditioning = self.img2img_image_conditioning( - decoded_samples, - samples, - decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3]) - ) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning) return samples -- cgit v1.2.1 From 61836bd544fc8f4ef62f311c9d5964fbdaeb3f4c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 30 Oct 2022 08:48:53 +0300 Subject: shorten Hypernetwork strength in infotext and omit it when it's the default value. --- modules/generation_parameters_copypaste.py | 2 +- modules/processing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 59c6d7da..df70c728 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -66,7 +66,7 @@ def integrate_settings_paste_fields(component_dict): settings_map = { 'sd_hypernetwork': 'Hypernet', - 'sd_hypernetwork_strength': 'Hypernetwork strength', + 'sd_hypernetwork_strength': 'Hypernet strength', 'CLIP_stop_at_last_layers': 'Clip skip', 'sd_model_checkpoint': 'Model hash', } diff --git a/modules/processing.py b/modules/processing.py index ecaa78e2..b1df4918 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -396,7 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernetwork strength": (None if shared.loaded_hypernetwork is None else shared.opts.sd_hypernetwork_strength), + "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), -- cgit v1.2.1 From 149784202cca8612b43629c601ee27cfda64e623 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 30 Oct 2022 09:10:22 +0300 Subject: rework #3722 to not introduce duplicate code --- modules/api/api.py | 43 +++++++++++++------------------------------ modules/shared.py | 22 +++++++++++++++++++--- 2 files changed, 32 insertions(+), 33 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5c5b210f..6c06d449 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo -# copy from wrap_gradio_gpu_call of webui.py -# because queue lock will be acquired in api handlers -# and time start needs to be set -# the function has been modified into two parts - -def before_gpu_call(): - devices.torch_gc() - - shared.state.sampling_step = 0 - shared.state.job_count = -1 - shared.state.job_no = 0 - shared.state.job_timestamp = shared.state.get_job_timestamp() - shared.state.current_latent = None - shared.state.current_image = None - shared.state.current_image_sampling_step = 0 - shared.state.skipped = False - shared.state.interrupted = False - shared.state.textinfo = None - shared.state.time_start = time.time() - -def after_gpu_call(): - shared.state.job = "" - shared.state.job_count = 0 - - devices.torch_gc() def upscaler_to_index(name: str): try: @@ -41,8 +16,10 @@ def upscaler_to_index(name: str): except: raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") + sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) + def setUpscalers(req: dict): reqDict = vars(req) reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) @@ -51,6 +28,7 @@ def setUpscalers(req: dict): reqDict.pop('upscaler_2') return reqDict + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -78,10 +56,13 @@ class Api: ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param - before_gpu_call() + + shared.state.begin() + with self.queue_lock: processed = process_images(p) - after_gpu_call() + + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) @@ -119,11 +100,13 @@ class Api: imgs = [img] * p.batch_size p.init_images = imgs - # Override object param - before_gpu_call() + + shared.state.begin() + with self.queue_lock: processed = process_images(p) - after_gpu_call() + + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/shared.py b/modules/shared.py index f7b0990c..e4f163c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -144,9 +144,6 @@ class State: self.sampling_step = 0 self.current_image_sampling_step = 0 - def get_job_timestamp(self): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? - def dict(self): obj = { "skipped": self.skipped, @@ -160,6 +157,25 @@ class State: return obj + def begin(self): + self.sampling_step = 0 + self.job_count = -1 + self.job_no = 0 + self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.current_latent = None + self.current_image = None + self.current_image_sampling_step = 0 + self.skipped = False + self.interrupted = False + self.textinfo = None + + devices.torch_gc() + + def end(self): + self.job = "" + self.job_count = 0 + + devices.torch_gc() state = State() -- cgit v1.2.1 From 910a097ae2ed78a62101951f1b87137f9e1baaea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 17:36:45 +0300 Subject: add initial version of the extensions tab fix broken Restart Gradio button --- modules/extensions.py | 83 +++++++++++++++ modules/generation_parameters_copypaste.py | 5 + modules/scripts.py | 21 +--- modules/shared.py | 10 +- modules/ui.py | 16 ++- modules/ui_extensions.py | 162 +++++++++++++++++++++++++++++ 6 files changed, 274 insertions(+), 23 deletions(-) create mode 100644 modules/extensions.py create mode 100644 modules/ui_extensions.py (limited to 'modules') diff --git a/modules/extensions.py b/modules/extensions.py new file mode 100644 index 00000000..8d6ae848 --- /dev/null +++ b/modules/extensions.py @@ -0,0 +1,83 @@ +import os +import sys +import traceback + +import git + +from modules import paths, shared + + +extensions = [] +extensions_dir = os.path.join(paths.script_path, "extensions") + + +def active(): + return [x for x in extensions if x.enabled] + + +class Extension: + def __init__(self, name, path, enabled=True): + self.name = name + self.path = path + self.enabled = enabled + self.status = '' + self.can_update = False + + repo = None + try: + if os.path.exists(os.path.join(path, ".git")): + repo = git.Repo(path) + except Exception: + print(f"Error reading github repository info from {path}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + if repo is None or repo.bare: + self.remote = None + else: + self.remote = next(repo.remote().urls, None) + self.status = 'unknown' + + def list_files(self, subdir, extension): + from modules import scripts + + dirpath = os.path.join(self.path, subdir) + if not os.path.isdir(dirpath): + return [] + + res = [] + for filename in sorted(os.listdir(dirpath)): + res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename))) + + res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + + return res + + def check_updates(self): + repo = git.Repo(self.path) + for fetch in repo.remote().fetch("--dry-run"): + if fetch.flags != fetch.HEAD_UPTODATE: + self.can_update = True + self.status = "behind" + return + + self.can_update = False + self.status = "latest" + + def pull(self): + repo = git.Repo(self.path) + repo.remotes.origin.pull() + + +def list_extensions(): + extensions.clear() + + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + path = os.path.join(extensions_dir, dirname) + if not os.path.isdir(path): + continue + + extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) + extensions.append(extension) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index df70c728..985ec95e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -17,6 +17,11 @@ paste_fields = {} bind_list = [] +def reset(): + paste_fields.clear() + bind_list.clear() + + def quote(text): if ',' not in str(text): return text diff --git a/modules/scripts.py b/modules/scripts.py index 96e44bfd..533db45c 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks +from modules import shared, paths, script_callbacks, extensions AlwaysVisible = object() @@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension): for filename in sorted(os.listdir(basedir)): scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) - extdir = os.path.join(paths.script_path, "extensions") - if os.path.exists(extdir): - for dirname in sorted(os.listdir(extdir)): - dirpath = os.path.join(extdir, dirname) - scriptdirpath = os.path.join(dirpath, scriptdirname) - - if not os.path.isdir(scriptdirpath): - continue - - for filename in sorted(os.listdir(scriptdirpath)): - scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename))) + for ext in extensions.active(): + scripts_list += ext.list_files(scriptdirname, extension) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] @@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension): def list_files_with_name(filename): res = [] - dirs = [paths.script_path] - - extdir = os.path.join(paths.script_path, "extensions") - if os.path.exists(extdir): - dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))] + dirs = [paths.script_path] + [ext.path for ext in extensions.active()] for dirpath in dirs: if not os.path.isdir(dirpath): diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..cce87081 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -132,6 +132,7 @@ class State: current_image = None current_image_sampling_step = 0 textinfo = None + need_restart = False def skip(self): self.skipped = True @@ -354,6 +355,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section((None, "Hidden options"), { + "disabled_extensions": OptionInfo([], "Disable those extensions"), +})) + +options_templates.update() + class Options: data = None @@ -365,8 +372,9 @@ class Options: def __setattr__(self, key, value): if self.data is not None: - if key in self.data: + if key in self.data or key in self.data_labels: self.data[key] = value + return return super(Options, self).__setattr__(key, value) diff --git a/modules/ui.py b/modules/ui.py index 5055ca64..2c15abb7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -19,7 +19,7 @@ import numpy as np from PIL import Image, PngImagePlugin -from modules import sd_hijack, sd_models, localization, script_callbacks +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img + parameters_copypaste.reset() with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) @@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call): column = None with gr.Row(elem_id="settings").style(equal_height=False): for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None - if previous_section != item.section: + if previous_section != item.section and not section_must_be_skipped: if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): if column is not None: column.__exit__() @@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call): if k in quicksettings_names and not shared.cmd_opts.freeze_settings: quicksettings_list.append((i, k, item)) components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) else: component = create_setting_component(k) component_dict[k] = component @@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call): def request_restart(): shared.state.interrupt() - settings_interface.gradio_ref.do_restart = True + shared.state.need_restart = True restart_gradio.click( + fn=request_restart, inputs=[], outputs=[], @@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call): interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "Settings", "settings")] + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): for i, k, item in quicksettings_list: component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component - settings_interface.gradio_ref = demo - parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.run_bind() diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py new file mode 100644 index 00000000..b7d747dc --- /dev/null +++ b/modules/ui_extensions.py @@ -0,0 +1,162 @@ +import json +import os.path +import shutil +import sys +import time +import traceback + +import git + +import gradio as gr +import html + +from modules import extensions, shared, paths + + +def apply_and_restart(disable_list, update_list): + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + update = json.loads(update_list) + assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" + + update = set(update) + + for ext in extensions.extensions: + if ext.name not in update: + continue + + try: + ext.pull() + except Exception: + print(f"Error pulling updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + shared.opts.disabled_extensions = disabled + shared.opts.save(shared.config_filename) + + shared.state.interrupt() + shared.state.need_restart = True + + +def check_updates(): + for ext in extensions.extensions: + if ext.remote is None: + continue + + try: + ext.check_updates() + except Exception: + print(f"Error checking updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + return extension_table() + + +def extension_table(): + code = f""" + + + + + + + + + + """ + + for ext in extensions.extensions: + if ext.can_update: + ext_status = f"""""" + else: + ext_status = ext.status + + code += f""" + + + + {ext_status} + + """ + + code += """ + +
ExtensionURLUpdate
{html.escape(ext.remote or '')}
+ """ + + return code + + +def install_extension_from_url(dirname, url): + assert url, 'No URL specified' + + if dirname is None or dirname == "": + *parts, last_part = url.split('/') + last_part = last_part.replace(".git", "") + + dirname = last_part + + target_dir = os.path.join(extensions.extensions_dir, dirname) + assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' + + assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed' + + tmpdir = os.path.join(paths.script_path, "tmp", dirname) + + try: + shutil.rmtree(tmpdir, True) + + repo = git.Repo.clone_from(url, tmpdir) + repo.remote().fetch() + + os.rename(tmpdir, target_dir) + + extensions.list_extensions() + return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")] + finally: + shutil.rmtree(tmpdir, True) + + +def create_ui(): + import modules.ui + + with gr.Blocks(analytics_enabled=False) as ui: + with gr.Tabs(elem_id="tabs_extensions") as tabs: + with gr.TabItem("Installed"): + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False) + + with gr.Row(): + apply = gr.Button(value="Apply and restart UI", variant="primary") + check = gr.Button(value="Check for updates") + + extensions_table = gr.HTML(lambda: extension_table()) + + apply.click( + fn=apply_and_restart, + _js="extensions_apply", + inputs=[extensions_disabled_list, extensions_update_list], + outputs=[], + ) + + check.click( + fn=check_updates, + _js="extensions_check", + inputs=[], + outputs=[extensions_table], + ) + + with gr.TabItem("Install from URL"): + install_url = gr.Text(label="URL for extension's git repository") + install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") + intall_button = gr.Button(value="Install", variant="primary") + intall_result = gr.HTML(elem_id="extension_install_result") + + intall_button.click( + fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), + inputs=[install_dirname, install_url], + outputs=[extensions_table, intall_result], + ) + + return ui -- cgit v1.2.1 From dc7425a56e7a014cbfa3b3d44ad2321e519fe378 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 18:33:44 +0300 Subject: disable access to extension stuff for non-local servers --- modules/shared.py | 5 ++++- modules/ui_extensions.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index cce87081..a27c654e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -40,7 +40,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") -parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") +parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) @@ -97,6 +97,9 @@ restricted_opts = { "outdir_save", } +if cmd_opts.share or cmd_opts.listen: + cmd_opts.disable_extension_access = True + devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index b7d747dc..e74b7d68 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -13,7 +13,13 @@ import html from modules import extensions, shared, paths +def check_access(): + assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags" + + def apply_and_restart(disable_list, update_list): + check_access() + disabled = json.loads(disable_list) assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" @@ -40,6 +46,8 @@ def apply_and_restart(disable_list, update_list): def check_updates(): + check_access() + for ext in extensions.extensions: if ext.remote is None: continue @@ -89,6 +97,8 @@ def extension_table(): def install_extension_from_url(dirname, url): + check_access() + assert url, 'No URL specified' if dirname is None or dirname == "": -- cgit v1.2.1