aboutsummaryrefslogtreecommitdiff
path: root/modules/api
diff options
context:
space:
mode:
Diffstat (limited to 'modules/api')
-rw-r--r--modules/api/api.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index efcedbba..89935a70 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -112,11 +112,13 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
- "sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
+ "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
}
)
+ if populate.sampler_name:
+ populate.sampler_index = None # prevent a warning later on
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
@@ -142,20 +144,20 @@ class Api:
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
- "sampler_name": validate_sampler_name(img2imgreq.sampler_index),
+ "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
"mask": mask
}
)
- p = StableDiffusionProcessingImg2Img(**vars(populate))
+ if populate.sampler_name:
+ populate.sampler_index = None # prevent a warning later on
- imgs = []
- for img in init_images:
- img = decode_base64_to_image(img)
- imgs = [img] * p.batch_size
+ args = vars(populate)
+ args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+ p = StableDiffusionProcessingImg2Img(**args)
- p.init_images = imgs
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin()
@@ -166,7 +168,7 @@ class Api:
b64images = list(map(encode_pil_to_base64, processed.images))
- if (not img2imgreq.include_init_images):
+ if not img2imgreq.include_init_images:
img2imgreq.init_images = None
img2imgreq.mask = None
@@ -310,7 +312,7 @@ class Api:
styleList = []
for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k]
- styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
+ styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
return styleList