update progress response model

This commit is contained in:
evshiron 2022-10-30 05:04:29 +08:00
parent e9c6c2a51f
commit 88f46a5bec
3 changed files with 7 additions and 7 deletions

View file

@ -61,7 +61,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@ -171,7 +171,7 @@ class Api:
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
# avoid dividing zero
progress = 0.01
@ -187,7 +187,7 @@ class Api:
progress = min(progress, 1)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict())
def launch(self, server_name, port):
self.app.include_router(self.router)

View file

@ -1,6 +1,6 @@
import inspect
from click import prompt
from pydantic import BaseModel, Field, Json, create_model
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel):
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
state: Json = Field(title="State", description="The current state snapshot")
state: dict = Field(title="State", description="The current state snapshot")

View file

@ -147,7 +147,7 @@ class State:
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
def js(self):
def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.skipped,
@ -158,7 +158,7 @@ class State:
"sampling_steps": self.sampling_steps,
}
return json.dumps(obj)
return obj
state = State()