preview current image when opts.show_progress_every_n_steps is enabled

This commit is contained in:
evshiron 2022-10-30 05:19:17 +08:00
parent 88f46a5bec
commit 9f104b53c4
2 changed files with 7 additions and 2 deletions

View file

@ -1,7 +1,7 @@
import time import time
import uvicorn import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared import modules.shared as shared
from modules import devices from modules import devices
from modules.api.models import * from modules.api.models import *
@ -187,7 +187,11 @@ class Api:
progress = min(progress, 1) progress = min(progress, 1)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict()) current_image = None
if shared.state.current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)

View file

@ -161,3 +161,4 @@ class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs") eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot") state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")