example API working with gradio
This commit is contained in:
parent
d42125baf6
commit
f80e914ac4
3 changed files with 60 additions and 27 deletions
|
@ -23,8 +23,13 @@ class Api:
|
||||||
app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
|
app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
||||||
p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq))
|
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||||
p.sd_model = shared.sd_model
|
"sd_model": shared.sd_model,
|
||||||
|
"sampler_index": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
p = StableDiffusionProcessingTxt2Img(**vars(populate))
|
||||||
|
# Override object param
|
||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
|
|
||||||
b64images = []
|
b64images = []
|
||||||
|
|
|
@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
API_NOT_ALLOWED = [
|
||||||
|
"self",
|
||||||
|
"kwargs",
|
||||||
|
"sd_model",
|
||||||
|
"outpath_samples",
|
||||||
|
"outpath_grids",
|
||||||
|
"sampler_index",
|
||||||
|
"do_not_save_samples",
|
||||||
|
"do_not_save_grid",
|
||||||
|
"extra_generation_params",
|
||||||
|
"overlay_images",
|
||||||
|
"do_not_reload_embeddings",
|
||||||
|
"seed_enable_extras",
|
||||||
|
"prompt_for_display",
|
||||||
|
"sampler_noise_scheduler_override",
|
||||||
|
"ddim_discretize"
|
||||||
|
]
|
||||||
|
|
||||||
class ModelDef(BaseModel):
|
class ModelDef(BaseModel):
|
||||||
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
"""Assistance Class for Pydantic Dynamic Model Generation"""
|
||||||
|
|
||||||
|
@ -14,7 +32,7 @@ class ModelDef(BaseModel):
|
||||||
field_value: Any
|
field_value: Any
|
||||||
|
|
||||||
|
|
||||||
class pydanticModelGenerator:
|
class PydanticModelGenerator:
|
||||||
"""
|
"""
|
||||||
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
|
||||||
source_data is a snapshot of the default values produced by the class
|
source_data is a snapshot of the default values produced by the class
|
||||||
|
@ -24,30 +42,33 @@ class pydanticModelGenerator:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str = None,
|
model_name: str = None,
|
||||||
source_data: {} = {},
|
class_instance = None
|
||||||
params: Dict = {},
|
|
||||||
overrides: Dict = {},
|
|
||||||
optionals: Dict = {},
|
|
||||||
):
|
):
|
||||||
def field_type_generator(k, v, overrides, optionals):
|
def field_type_generator(k, v):
|
||||||
field_type = str if not overrides.get(k) else overrides[k]["type"]
|
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||||
if v is None:
|
# print(k, v.annotation, v.default)
|
||||||
field_type = Any
|
field_type = v.annotation
|
||||||
else:
|
|
||||||
field_type = type(v)
|
|
||||||
|
|
||||||
return Optional[field_type]
|
return Optional[field_type]
|
||||||
|
|
||||||
|
def merge_class_params(class_):
|
||||||
|
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
|
||||||
|
parameters = {}
|
||||||
|
for classes in all_classes:
|
||||||
|
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||||
|
return parameters
|
||||||
|
|
||||||
|
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._json_data = source_data
|
self._class_data = merge_class_params(class_instance)
|
||||||
self._model_def = [
|
self._model_def = [
|
||||||
ModelDef(
|
ModelDef(
|
||||||
field=underscore(k),
|
field=underscore(k),
|
||||||
field_alias=k,
|
field_alias=k,
|
||||||
field_type=field_type_generator(k, v, overrides, optionals),
|
field_type=field_type_generator(k, v),
|
||||||
field_value=v
|
field_value=v.default
|
||||||
)
|
)
|
||||||
for (k,v) in source_data.items() if k in params
|
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||||
]
|
]
|
||||||
|
|
||||||
def generate_model(self):
|
def generate_model(self):
|
||||||
|
@ -60,8 +81,7 @@ class pydanticModelGenerator:
|
||||||
}
|
}
|
||||||
DynamicModel = create_model(self._model_name, **fields)
|
DynamicModel = create_model(self._model_name, **fields)
|
||||||
DynamicModel.__config__.allow_population_by_field_name = True
|
DynamicModel.__config__.allow_population_by_field_name = True
|
||||||
|
DynamicModel.__config__.allow_mutation = True
|
||||||
return DynamicModel
|
return DynamicModel
|
||||||
|
|
||||||
StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing",
|
StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()
|
||||||
StableDiffusionProcessing().__dict__,
|
|
||||||
inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model()
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps
|
||||||
import random
|
import random
|
||||||
import cv2
|
import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
||||||
|
@ -51,9 +52,15 @@ def get_correct_sampler(p):
|
||||||
return sd_samplers.samplers
|
return sd_samplers.samplers
|
||||||
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
||||||
return sd_samplers.samplers_for_img2img
|
return sd_samplers.samplers_for_img2img
|
||||||
|
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
|
||||||
|
return sd_samplers.samplers
|
||||||
|
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing():
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False):
|
"""
|
||||||
|
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
|
||||||
self.sd_model = sd_model
|
self.sd_model = sd_model
|
||||||
self.outpath_samples: str = outpath_samples
|
self.outpath_samples: str = outpath_samples
|
||||||
self.outpath_grids: str = outpath_grids
|
self.outpath_grids: str = outpath_grids
|
||||||
|
@ -86,10 +93,10 @@ class StableDiffusionProcessing:
|
||||||
self.denoising_strength: float = 0
|
self.denoising_strength: float = 0
|
||||||
self.sampler_noise_scheduler_override = None
|
self.sampler_noise_scheduler_override = None
|
||||||
self.ddim_discretize = opts.ddim_discretize
|
self.ddim_discretize = opts.ddim_discretize
|
||||||
self.s_churn = opts.s_churn
|
self.s_churn = s_churn or opts.s_churn
|
||||||
self.s_tmin = opts.s_tmin
|
self.s_tmin = s_tmin or opts.s_tmin
|
||||||
self.s_tmax = float('inf') # not representable as a standard ui option
|
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||||
self.s_noise = opts.s_noise
|
self.s_noise = s_noise or opts.s_noise
|
||||||
|
|
||||||
if not seed_enable_extras:
|
if not seed_enable_extras:
|
||||||
self.subseed = -1
|
self.subseed = -1
|
||||||
|
@ -97,6 +104,7 @@ class StableDiffusionProcessing:
|
||||||
self.seed_resize_from_h = 0
|
self.seed_resize_from_h = 0
|
||||||
self.seed_resize_from_w = 0
|
self.seed_resize_from_w = 0
|
||||||
|
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -497,7 +505,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
|
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.enable_hr = enable_hr
|
self.enable_hr = enable_hr
|
||||||
self.denoising_strength = denoising_strength
|
self.denoising_strength = denoising_strength
|
||||||
|
|
Loading…
Reference in a new issue