example API working with gradio
parent d42125baf6 · commit f80e914ac4
3 changed files with 60 additions and 27 deletions

@@ -23,8 +23,13 @@ class Api:
         app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
 
     def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
-        p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq))
-        p.sd_model = shared.sd_model
+        populate = txt2imgreq.copy(update={ # Override __init__ params
+            "sd_model": shared.sd_model,
+            "sampler_index": 0,
+            }
+        )
+        p = StableDiffusionProcessingTxt2Img(**vars(populate))
+        # Override object param
         processed = process_images(p)
 
         b64images = []
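
The handler above leans on pydantic v1 behavior: copy(update=...) merges extra keys straight into the copied model's __dict__ without validation, and vars() then hands that dict to the processing class as keyword arguments. A minimal sketch of that pattern with a stand-in model (FakeRequest is hypothetical, not part of this commit):

from pydantic import BaseModel

class FakeRequest(BaseModel):
    prompt: str = ""
    sampler_index: int = 3

req = FakeRequest(prompt="a lighthouse at dusk")
# update= keys are merged into the copy's __dict__ even when they are not declared fields
populated = req.copy(update={"sd_model": object(), "sampler_index": 0})
print(vars(populated))  # prompt, the overridden sampler_index=0 and the injected sd_model all show up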
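
Once the app is running, the new route can be exercised with a plain HTTP POST. This is a sketch only: the host/port and the exact field set are assumptions, with field names mirroring the StableDiffusionProcessing __init__ parameters exposed by the generated model:

import requests

url = "http://127.0.0.1:7860/v1/txt2img"  # assumed local address; adjust to wherever the app is served
payload = {
    "prompt": "a photo of a corgi wearing a top hat",
    "steps": 20,
    "width": 512,
    "height": 512,
    "cfg_scale": 7.0,
}

resp = requests.post(url, json=payload)
print(resp.status_code)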

@@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu
 import inspect
 
+API_NOT_ALLOWED = [
+    "self",
+    "kwargs",
+    "sd_model",
+    "outpath_samples",
+    "outpath_grids",
+    "sampler_index",
+    "do_not_save_samples",
+    "do_not_save_grid",
+    "extra_generation_params",
+    "overlay_images",
+    "do_not_reload_embeddings",
+    "seed_enable_extras",
+    "prompt_for_display",
+    "sampler_noise_scheduler_override",
+    "ddim_discretize"
+]
+
 class ModelDef(BaseModel):
     """Assistance Class for Pydantic Dynamic Model Generation"""

@@ -14,7 +32,7 @@ class ModelDef(BaseModel):
     field_value: Any
 
-class pydanticModelGenerator:
+class PydanticModelGenerator:
     """
     Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
     source_data is a snapshot of the default values produced by the class

@@ -24,30 +42,33 @@ class pydanticModelGenerator:
     def __init__(
         self,
         model_name: str = None,
-        source_data: {} = {},
-        params: Dict = {},
-        overrides: Dict = {},
-        optionals: Dict = {},
+        class_instance = None
     ):
-        def field_type_generator(k, v, overrides, optionals):
-            field_type = str if not overrides.get(k) else overrides[k]["type"]
-            if v is None:
-                field_type = Any
-            else:
-                field_type = type(v)
+        def field_type_generator(k, v):
+            # field_type = str if not overrides.get(k) else overrides[k]["type"]
+            # print(k, v.annotation, v.default)
+            field_type = v.annotation
 
             return Optional[field_type]
 
+        def merge_class_params(class_):
+            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
+            parameters = {}
+            for classes in all_classes:
+                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
+            return parameters
+
+
         self._model_name = model_name
-        self._json_data = source_data
+        self._class_data = merge_class_params(class_instance)
         self._model_def = [
             ModelDef(
                 field=underscore(k),
                 field_alias=k,
-                field_type=field_type_generator(k, v, overrides, optionals),
-                field_value=v
+                field_type=field_type_generator(k, v),
+                field_value=v.default
             )
-            for (k,v) in source_data.items() if k in params
+            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
         ]
 
     def generate_model(self):
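
merge_class_params walks the MRO with inspect.getmro and overlays each class's __init__ signature, so the generated model sees parameters from StableDiffusionProcessing and its subclasses alike. A small self-contained illustration (Base and Child are toy stand-ins for the real processing classes):

import inspect

class Base:
    def __init__(self, prompt: str = "", steps: int = 50):
        pass

class Child(Base):
    def __init__(self, enable_hr: bool = False, **kwargs):
        super().__init__(**kwargs)

def merge_class_params(class_):
    all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
    parameters = {}
    for cls in all_classes:
        # classes later in the MRO (the parents) overwrite duplicate names such as "self"
        parameters = {**parameters, **inspect.signature(cls.__init__).parameters}
    return parameters

for name, p in merge_class_params(Child).items():
    print(name, p.annotation, p.default)
# "self" and "kwargs" show up too, which is what the API_NOT_ALLOWED filter strips out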

@@ -60,8 +81,7 @@ class pydanticModelGenerator:
         }
         DynamicModel = create_model(self._model_name, **fields)
         DynamicModel.__config__.allow_population_by_field_name = True
+        DynamicModel.__config__.allow_mutation = True
         return DynamicModel
 
-StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing",
-    StableDiffusionProcessing().__dict__,
-    inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model()
+StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()
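
generate_model() ultimately feeds those ModelDef entries to pydantic's create_model. A condensed sketch of the same idea, assuming pydantic v1 (where __config__, allow_population_by_field_name and allow_mutation live); the three fields are illustrative, not the full generated set:

from typing import Optional
from pydantic import create_model

fields = {
    "prompt": (Optional[str], ""),
    "steps": (Optional[int], 50),
    "cfg_scale": (Optional[float], 7.0),
}

DynamicModel = create_model("StableDiffusionProcessingAPI", **fields)
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True

req = DynamicModel(prompt="a red bicycle", steps=28)
print(req.dict())  # {'prompt': 'a red bicycle', 'steps': 28, 'cfg_scale': 7.0}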

@@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps
 import random
 import cv2
 from skimage import exposure
+from typing import Any, Dict, List, Optional
 
 import modules.sd_hijack
 from modules import devices, prompt_parser, masking, sd_samplers, lowvram

@@ -51,9 +52,15 @@ def get_correct_sampler(p):
         return sd_samplers.samplers
     elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
         return sd_samplers.samplers_for_img2img
+    elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
+        return sd_samplers.samplers
 
-class StableDiffusionProcessing:
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False):
+class StableDiffusionProcessing():
+    """
+    The first set of parameters: sd_model -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
+
+    """
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps: int=50, cfg_scale: float=7.0, width: int=512, height: int=512, restore_faces: bool=False, tiling: bool=False, do_not_save_samples: bool=False, do_not_save_grid: bool=False, extra_generation_params: Dict[Any, Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float=None, do_not_reload_embeddings: bool=False, denoising_strength: float=0, ddim_discretize: str="uniform", s_churn: float=0.0, s_tmax: float=None, s_tmin: float=0.0, s_noise: float=1.0):
         self.sd_model = sd_model
         self.outpath_samples: str = outpath_samples
         self.outpath_grids: str = outpath_grids

@@ -86,10 +93,10 @@ class StableDiffusionProcessing:
         self.denoising_strength: float = 0
         self.sampler_noise_scheduler_override = None
         self.ddim_discretize = opts.ddim_discretize
-        self.s_churn = opts.s_churn
-        self.s_tmin = opts.s_tmin
-        self.s_tmax = float('inf')  # not representable as a standard ui option
-        self.s_noise = opts.s_noise
+        self.s_churn = s_churn or opts.s_churn
+        self.s_tmin = s_tmin or opts.s_tmin
+        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
+        self.s_noise = s_noise or opts.s_noise
 
         if not seed_enable_extras:
             self.subseed = -1
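
One wrinkle in the `value or opts.value` fallback used above: `or` treats every falsy value as unset, so a caller who explicitly passes 0.0 still gets the opts default. Shown in isolation (default_churn is a stand-in for opts.s_churn):

s_churn = 0.0        # explicitly requested by the caller
default_churn = 0.5  # stand-in for opts.s_churn

print(s_churn or default_churn)                           # 0.5 -- the explicit 0.0 is silently replaced
print(s_churn if s_churn is not None else default_churn)  # 0.0 -- only None triggers the fallback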

@@ -97,6 +104,7 @@ class StableDiffusionProcessing:
             self.seed_resize_from_h = 0
             self.seed_resize_from_w = 0
 
+
     def init(self, all_prompts, all_seeds, all_subseeds):
         pass

@@ -497,7 +505,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     sampler = None
 
-    def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
+    def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength