Merge pull request #3810 from royshil/roy.add_simple_interrogate_api

Add a barebones CLIP interrogate API endpoint
This commit is contained in:
AUTOMATIC1111 2022-11-06 11:28:00 +03:00 committed by GitHub
commit 5302e2cdd4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 0 deletions

View file

@ -63,6 +63,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
@ -214,6 +215,19 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def interrogateapi(self, interrogatereq: InterrogateRequest):
image_b64 = interrogatereq.image
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
img = self.__base64_to_image(image_b64)
# Override object param
with self.queue_lock:
processed = shared.interrogator.interrogate(img)
return InterrogateResponse(caption=processed)
def interruptapi(self): def interruptapi(self):
shared.state.interrupt() shared.state.interrupt()

View file

@ -65,6 +65,7 @@ class PydanticModelGenerator:
self._model_name = model_name self._model_name = model_name
self._class_data = merge_class_params(class_instance) self._class_data = merge_class_params(class_instance)
self._model_def = [ self._model_def = [
ModelDef( ModelDef(
field=underscore(k), field=underscore(k),
@ -167,6 +168,12 @@ class ProgressResponse(BaseModel):
state: dict = Field(title="State", description="The current state snapshot") state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
fields = {} fields = {}
for key, value in opts.data.items(): for key, value in opts.data.items():
metadata = opts.data_labels.get(key) metadata = opts.data_labels.get(key)
@ -231,3 +238,4 @@ class ArtistItem(BaseModel):
name: str = Field(title="Name") name: str = Field(title="Name")
score: float = Field(title="Score") score: float = Field(title="Score")
category: str = Field(title="Category") category: str = Field(title="Category")