update progress response model
This commit is contained in:
parent
e9c6c2a51f
commit
88f46a5bec
3 changed files with 7 additions and 7 deletions
|
@ -61,7 +61,7 @@ class Api:
|
|||
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
|
||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||
|
@ -171,7 +171,7 @@ class Api:
|
|||
# copy from check_progress_call of ui.py
|
||||
|
||||
if shared.state.job_count == 0:
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js())
|
||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
|
||||
|
||||
# avoid dividing zero
|
||||
progress = 0.01
|
||||
|
@ -187,7 +187,7 @@ class Api:
|
|||
|
||||
progress = min(progress, 1)
|
||||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict())
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
from click import prompt
|
||||
from pydantic import BaseModel, Field, Json, create_model
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from inflection import underscore
|
||||
|
@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel):
|
|||
class ProgressResponse(BaseModel):
|
||||
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
|
||||
eta_relative: float = Field(title="ETA in secs")
|
||||
state: Json = Field(title="State", description="The current state snapshot")
|
||||
state: dict = Field(title="State", description="The current state snapshot")
|
||||
|
|
|
@ -147,7 +147,7 @@ class State:
|
|||
def get_job_timestamp(self):
|
||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||
|
||||
def js(self):
|
||||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.skipped,
|
||||
|
@ -158,7 +158,7 @@ class State:
|
|||
"sampling_steps": self.sampling_steps,
|
||||
}
|
||||
|
||||
return json.dumps(obj)
|
||||
return obj
|
||||
|
||||
|
||||
state = State()
|
||||
|
|
Loading…
Reference in a new issue