about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- .github/ISSUE_TEMPLATE/feature_request.yml 2
-rw-r--r-- modules/api/api.py 48
-rw-r--r-- modules/api/models.py 4
-rw-r--r-- modules/processing.py 4
-rw-r--r-- modules/sd_hijack_clip.py 4
-rw-r--r-- scripts/sd_upscale.py 2
-rw-r--r-- test/basic_features/img2img_test.py 6
7 files changed, 57 insertions, 13 deletions
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
index 8ca6e21f..35a88740 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.yml
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -1,7 +1,7 @@
name: Feature request
description: Suggest an idea for this project
title: "[Feature Request]: "
-labels: ["suggestion"]
+labels: ["enhancement"]
body:
- type: checkboxes
diff --git a/modules/api/api.py b/modules/api/api.py
index 2103709b..5b6125f8 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras
@@ -28,8 +28,13 @@ def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+def script_name_to_index(name, scripts):
+ try:
+ return [script.title().lower() for script in scripts].index(name.lower())
+ except:
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
@@ -143,7 +148,21 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
+ def get_script(self, script_name, script_runner):
+ if script_name is None:
+ return None, None
+
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(False)
+ ui.create_ui()
+
+ script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
+ script = script_runner.selectable_scripts[script_idx]
+ return script, script_idx
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
+
populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
@@ -153,14 +172,22 @@ class Api:
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
+ args = vars(populate)
+ args.pop('script_name', None)
+
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
shared.state.begin()
- processed = process_images(p)
+ if script is not None:
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+ p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args)
+ else:
+ processed = process_images(p)
shared.state.end()
-
b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@@ -170,6 +197,8 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
+ script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
+
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
@@ -186,13 +215,20 @@ class Api:
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+ args.pop('script_name', None)
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin()
- processed = process_images(p)
+ if script is not None:
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+ p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args)
+ else:
+ processed = process_images(p)
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
diff --git a/modules/api/models.py b/modules/api/models.py
index 5fa63774..ce43c858 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -100,13 +100,13 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()
class TextToImageResponse(BaseModel):
diff --git a/modules/processing.py b/modules/processing.py
index 82157bc9..1d23b15f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -98,7 +98,7 @@ class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -149,7 +149,7 @@ class StableDiffusionProcessing():
self.seed_resize_from_w = 0
self.scripts = None
- self.script_args = None
+ self.script_args = script_args
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 5520c9b2..852afc66 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z *= original_mean / new_mean
+ z = z * (original_mean / new_mean)
return z
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index 9b8ffd85..332d76d9 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -25,6 +25,8 @@ class Script(scripts.Script):
return [info, overlap, upscaler_index, scale_factor]
def run(self, p, _, overlap, upscaler_index, scale_factor):
+ if isinstance(upscaler_index, str):
+ upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
processing.fix_seed(p)
upscaler = shared.sd_upscalers[upscaler_index]
diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py
index 0a9c1e8a..bd520b13 100644
--- a/test/basic_features/img2img_test.py
+++ b/test/basic_features/img2img_test.py
@@ -50,6 +50,12 @@ class TestImg2ImgWorking(unittest.TestCase):
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+ def test_img2img_sd_upscale_performed(self):
+ self.simple_img2img["script_name"] = "sd upscale"
+ self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0]
+
+ self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+
if __name__ == "__main__":
unittest.main()