provide sampler by name
This commit is contained in:
parent
8d5d863a9d
commit
e7f4808505
2 changed files with 24 additions and 4 deletions
|
@ -1,14 +1,17 @@
|
||||||
from modules.api.processing import StableDiffusionProcessingAPI
|
from modules.api.processing import StableDiffusionProcessingAPI
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
|
||||||
|
from modules.sd_samplers import samplers_k_diffusion
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import Body, APIRouter
|
from fastapi import Body, APIRouter, HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from pydantic import BaseModel, Field, Json
|
from pydantic import BaseModel, Field, Json
|
||||||
import json
|
import json
|
||||||
import io
|
import io
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None)
|
||||||
|
|
||||||
class TextToImageResponse(BaseModel):
|
class TextToImageResponse(BaseModel):
|
||||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||||
parameters: Json
|
parameters: Json
|
||||||
|
@ -23,9 +26,14 @@ class Api:
|
||||||
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
|
||||||
|
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||||
|
|
||||||
|
if sampler_index is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Sampler not found")
|
||||||
|
|
||||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||||
"sd_model": shared.sd_model,
|
"sd_model": shared.sd_model,
|
||||||
"sampler_index": 0,
|
"sampler_index": sampler_index[0],
|
||||||
"do_not_save_samples": True,
|
"do_not_save_samples": True,
|
||||||
"do_not_save_grid": True
|
"do_not_save_grid": True
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,8 @@ class PydanticModelGenerator:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str = None,
|
model_name: str = None,
|
||||||
class_instance = None
|
class_instance = None,
|
||||||
|
additional_fields = None,
|
||||||
):
|
):
|
||||||
def field_type_generator(k, v):
|
def field_type_generator(k, v):
|
||||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||||
|
@ -71,6 +72,13 @@ class PydanticModelGenerator:
|
||||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||||
]
|
]
|
||||||
|
|
||||||
|
for fields in additional_fields:
|
||||||
|
self._model_def.append(ModelDef(
|
||||||
|
field=underscore(fields["key"]),
|
||||||
|
field_alias=fields["key"],
|
||||||
|
field_type=fields["type"],
|
||||||
|
field_value=fields["default"]))
|
||||||
|
|
||||||
def generate_model(self):
|
def generate_model(self):
|
||||||
"""
|
"""
|
||||||
Creates a pydantic BaseModel
|
Creates a pydantic BaseModel
|
||||||
|
@ -84,4 +92,8 @@ class PydanticModelGenerator:
|
||||||
DynamicModel.__config__.allow_mutation = True
|
DynamicModel.__config__.allow_mutation = True
|
||||||
return DynamicModel
|
return DynamicModel
|
||||||
|
|
||||||
StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()
|
StableDiffusionProcessingAPI = PydanticModelGenerator(
|
||||||
|
"StableDiffusionProcessingTxt2Img",
|
||||||
|
StableDiffusionProcessingTxt2Img,
|
||||||
|
[{"key": "sampler_index", "type": str, "default": "k_euler_a"}]
|
||||||
|
).generate_model()
|
||||||
|
|
Loading…
Reference in a new issue