aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py88
-rw-r--r--modules/api/models.py38
-rw-r--r--modules/generation_parameters_copypaste.py1
-rw-r--r--modules/hypernetworks/hypernetwork.py7
-rw-r--r--modules/processing.py20
-rw-r--r--modules/scripts.py24
-rw-r--r--modules/shared.py33
-rw-r--r--modules/textual_inversion/learn_schedule.py35
8 files changed, 186 insertions, 60 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 49c213ea..6c06d449 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,11 +1,14 @@
+import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
+from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
-from modules.extras import run_extras
+from modules.extras import run_extras, run_pnginfo
+
def upscaler_to_index(name: str):
try:
@@ -13,8 +16,10 @@ def upscaler_to_index(name: str):
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
@@ -23,6 +28,7 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2')
return reqDict
+
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
@@ -32,15 +38,17 @@ class Api:
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
+ self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
+ self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-
+
if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
-
+ raise HTTPException(status_code=404, detail="Sampler not found")
+
populate = txt2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
+ "sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
@@ -48,34 +56,39 @@ class Api:
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
-
+
+ shared.state.end()
+
b64images = list(map(encode_pil_to_base64, processed.images))
-
+
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
-
+
if sampler_index is None:
- raise HTTPException(status_code=404, detail="Sampler not found")
+ raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images
if init_images is None:
- raise HTTPException(status_code=404, detail="Init image not found")
+ raise HTTPException(status_code=404, detail="Init image not found")
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
-
+
populate = img2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
+ "sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
- "do_not_save_grid": True,
+ "do_not_save_grid": True,
"mask": mask
}
)
@@ -87,16 +100,20 @@ class Api:
imgs = [img] * p.batch_size
p.init_images = imgs
- # Override object param
+
+ shared.state.begin()
+
with self.queue_lock:
processed = process_images(p)
-
+
+ shared.state.end()
+
b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
-
+
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
@@ -124,9 +141,40 @@ class Api:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
-
- def pnginfoapi(self):
- raise NotImplementedError
+
+ def pnginfoapi(self, req: PNGInfoRequest):
+ if(not req.image.strip()):
+ return PNGInfoResponse(info="")
+
+ result = run_pnginfo(decode_base64_to_image(req.image.strip()))
+
+ return PNGInfoResponse(info=result[1])
+
+ def progressapi(self, req: ProgressRequest = Depends()):
+ # copy from check_progress_call of ui.py
+
+ if shared.state.job_count == 0:
+ return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
+
+ # avoid dividing zero
+ progress = 0.01
+
+ if shared.state.job_count > 0:
+ progress += shared.state.job_no / shared.state.job_count
+ if shared.state.sampling_steps > 0:
+ progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+ time_since_start = time.time() - shared.state.time_start
+ eta = (time_since_start/progress)
+ eta_relative = eta-time_since_start
+
+ progress = min(progress, 1)
+
+ current_image = None
+ if shared.state.current_image and not req.skip_current_image:
+ current_image = encode_pil_to_base64(shared.state.current_image)
+
+ return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def launch(self, server_name, port):
self.app.include_router(self.router)
diff --git a/modules/api/models.py b/modules/api/models.py
index dd122321..9ee42a17 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,4 +1,5 @@
import inspect
+from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
@@ -51,17 +52,17 @@ class PydanticModelGenerator:
# field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default)
field_type = v.annotation
-
+
return Optional[field_type]
-
+
def merge_class_params(class_):
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
parameters = {}
for classes in all_classes:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
-
-
+
+
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
self._model_def = [
@@ -73,11 +74,11 @@ class PydanticModelGenerator:
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
-
+
for fields in additional_fields:
self._model_def.append(ModelDef(
- field=underscore(fields["key"]),
- field_alias=fields["key"],
+ field=underscore(fields["key"]),
+ field_alias=fields["key"],
field_type=fields["type"],
field_value=fields["default"],
field_exclude=fields["exclude"] if "exclude" in fields else False))
@@ -94,15 +95,15 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
return DynamicModel
-
+
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingTxt2Img",
+ "StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
- "StableDiffusionProcessingImg2Img",
+ "StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
@@ -148,4 +149,19 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
- images: list[str] = Field(title="Images", description="The generated images in base64 format.") \ No newline at end of file
+ images: list[str] = Field(title="Images", description="The generated images in base64 format.")
+
+class PNGInfoRequest(BaseModel):
+ image: str = Field(title="Image", description="The base64 encoded PNG image")
+
+class PNGInfoResponse(BaseModel):
+ info: str = Field(title="Image info", description="A string with all the info the image had")
+
+class ProgressRequest(BaseModel):
+ skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
+
+class ProgressResponse(BaseModel):
+ progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
+ eta_relative: float = Field(title="ETA in secs")
+ state: dict = Field(title="State", description="The current state snapshot")
+ current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index bbaad42e..df70c728 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -66,6 +66,7 @@ def integrate_settings_paste_fields(component_dict):
settings_map = {
'sd_hypernetwork': 'Hypernet',
+ 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 07acadc9..a11e01d6 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -209,13 +209,16 @@ def list_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0]
- res[name] = filename
+ # Prevent a hypothetical "None.pt" from being listed.
+ if name != "None":
+ res[name] = filename
return res
def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None)
- if path is not None:
+ # Prevent any file named "None.pt" from being loaded.
+ if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
try:
shared.loaded_hypernetwork = Hypernetwork()
diff --git a/modules/processing.py b/modules/processing.py
index 548eec29..b1df4918 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -396,6 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -478,7 +479,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
- p.scripts.run_alwayson_scripts(p)
+ p.scripts.process(p)
infotexts = []
output_images = []
@@ -501,7 +502,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- if (len(prompts) == 0):
+ if len(prompts) == 0:
break
with devices.autocast():
@@ -590,7 +591,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
- return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
+
+ if p.scripts is not None:
+ p.scripts.postprocess(p, res)
+
+ return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@@ -680,15 +687,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ image_conditioning = self.txt2img_image_conditioning(x)
+
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- image_conditioning = self.img2img_image_conditioning(
- decoded_samples,
- samples,
- decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
- )
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
diff --git a/modules/scripts.py b/modules/scripts.py
index a7f36012..96e44bfd 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -64,7 +64,16 @@ class Script:
def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
- scripts. You can modify the processing object (p) here, inject hooks, etc.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
+ def postprocess(self, p, processed, *args):
+ """
+ This function is called after processing ends for AlwaysVisible scripts.
+ args contains all values returned by components from ui()
"""
pass
@@ -289,13 +298,22 @@ class ScriptRunner:
return processed
- def run_alwayson_scripts(self, p):
+ def process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
+ print(f"Error running process: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ def postprocess(self, p, processed):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess(p, processed, *script_args)
+ except Exception:
+ print(f"Error running postprocess: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
diff --git a/modules/shared.py b/modules/shared.py
index fb84afd8..e4f163c1 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -144,9 +144,38 @@ class State:
self.sampling_step = 0
self.current_image_sampling_step = 0
- def get_job_timestamp(self):
- return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
+ def dict(self):
+ obj = {
+ "skipped": self.skipped,
+ "interrupted": self.skipped,
+ "job": self.job,
+ "job_count": self.job_count,
+ "job_no": self.job_no,
+ "sampling_step": self.sampling_step,
+ "sampling_steps": self.sampling_steps,
+ }
+
+ return obj
+
+ def begin(self):
+ self.sampling_step = 0
+ self.job_count = -1
+ self.job_no = 0
+ self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
+ self.current_latent = None
+ self.current_image = None
+ self.current_image_sampling_step = 0
+ self.skipped = False
+ self.interrupted = False
+ self.textinfo = None
+
+ devices.torch_gc()
+
+ def end(self):
+ self.job = ""
+ self.job_count = 0
+ devices.torch_gc()
state = State()
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 3a736065..dd0c0ad1 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -4,30 +4,37 @@ import tqdm
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
"""
- specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
"""
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
return
- elif step == -1:
+ else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+
def __iter__(self):
return self