aboutsummaryrefslogtreecommitdiff
path: root/modules/api/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/api/processing.py')
-rw-r--r--modules/api/processing.py56
1 files changed, 38 insertions, 18 deletions
diff --git a/modules/api/processing.py b/modules/api/processing.py
index e4df93c5..b6798241 100644
--- a/modules/api/processing.py
+++ b/modules/api/processing.py
@@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu
import inspect
+API_NOT_ALLOWED = [
+ "self",
+ "kwargs",
+ "sd_model",
+ "outpath_samples",
+ "outpath_grids",
+ "sampler_index",
+ "do_not_save_samples",
+ "do_not_save_grid",
+ "extra_generation_params",
+ "overlay_images",
+ "do_not_reload_embeddings",
+ "seed_enable_extras",
+ "prompt_for_display",
+ "sampler_noise_scheduler_override",
+ "ddim_discretize"
+]
+
class ModelDef(BaseModel):
"""Assistance Class for Pydantic Dynamic Model Generation"""
@@ -14,7 +32,7 @@ class ModelDef(BaseModel):
field_value: Any
-class pydanticModelGenerator:
+class PydanticModelGenerator:
"""
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
source_data is a snapshot of the default values produced by the class
@@ -24,30 +42,33 @@ class pydanticModelGenerator:
def __init__(
self,
model_name: str = None,
- source_data: {} = {},
- params: Dict = {},
- overrides: Dict = {},
- optionals: Dict = {},
+ class_instance = None
):
- def field_type_generator(k, v, overrides, optionals):
- field_type = str if not overrides.get(k) else overrides[k]["type"]
- if v is None:
- field_type = Any
- else:
- field_type = type(v)
+ def field_type_generator(k, v):
+ # field_type = str if not overrides.get(k) else overrides[k]["type"]
+ # print(k, v.annotation, v.default)
+ field_type = v.annotation
return Optional[field_type]
+ def merge_class_params(class_):
+ all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+ parameters = {}
+ for classes in all_classes:
+ parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+ return parameters
+
+
self._model_name = model_name
- self._json_data = source_data
+ self._class_data = merge_class_params(class_instance)
self._model_def = [
ModelDef(
field=underscore(k),
field_alias=k,
- field_type=field_type_generator(k, v, overrides, optionals),
- field_value=v
+ field_type=field_type_generator(k, v),
+ field_value=v.default
)
- for (k,v) in source_data.items() if k in params
+ for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
def generate_model(self):
@@ -60,8 +81,7 @@ class pydanticModelGenerator:
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
+ DynamicModel.__config__.allow_mutation = True
return DynamicModel
-StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing",
- StableDiffusionProcessing().__dict__,
- inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model()
+StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()