Merge branch 'master' into feat/progress-api
This commit is contained in:
commit
7f5212fb5f
6 changed files with 59 additions and 10 deletions
12
CODEOWNERS
12
CODEOWNERS
|
@ -1 +1,13 @@
|
|||
* @AUTOMATIC1111
|
||||
/localizations/ar_AR.json @xmodar @blackneoo
|
||||
/localizations/de_DE.json @LunixWasTaken
|
||||
/localizations/es_ES.json @innovaciones
|
||||
/localizations/fr_FR.json @tumbly
|
||||
/localizations/it_IT.json @EugenioBuffo
|
||||
/localizations/ja_JP.json @yuuki76
|
||||
/localizations/ko_KR.json @36DB
|
||||
/localizations/pt_BR.json @M-art-ucci
|
||||
/localizations/ru_RU.json @kabachuha
|
||||
/localizations/tr_TR.json @camenduru
|
||||
/localizations/zh_CN.json @dtlnor @bgluminous
|
||||
/localizations/zh_TW.json @benlisquare
|
||||
|
|
|
@ -37,7 +37,7 @@ from modules import devices
|
|||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.sd_samplers import all_samplers
|
||||
from modules.extras import run_extras
|
||||
from modules.extras import run_extras, run_pnginfo
|
||||
|
||||
# copy from wrap_gradio_gpu_call of webui.py
|
||||
# because queue lock will be acquired in api handlers
|
||||
|
@ -90,6 +90,7 @@ class Api:
|
|||
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
||||
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
|
@ -188,6 +189,14 @@ class Api:
|
|||
|
||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: PNGInfoRequest):
|
||||
if(not req.image.strip()):
|
||||
return PNGInfoResponse(info="")
|
||||
|
||||
result = run_pnginfo(decode_base64_to_image(req.image.strip()))
|
||||
|
||||
return PNGInfoResponse(info=result[1])
|
||||
|
||||
def progressapi(self):
|
||||
# copy from check_progress_call of ui.py
|
||||
|
||||
|
@ -210,9 +219,6 @@ class Api:
|
|||
|
||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
|
||||
|
||||
def pnginfoapi(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def launch(self, server_name, port):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import inspect
|
||||
from click import prompt
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
|
@ -150,6 +151,12 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
|||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with all the info the image had")
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
progress: float
|
||||
eta_relative: float
|
||||
|
|
|
@ -478,7 +478,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.run_alwayson_scripts(p)
|
||||
p.scripts.process(p)
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
@ -501,7 +501,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
|
||||
if (len(prompts) == 0):
|
||||
if len(prompts) == 0:
|
||||
break
|
||||
|
||||
with devices.autocast():
|
||||
|
@ -590,7 +590,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||
|
||||
devices.torch_gc()
|
||||
return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess(p, res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
|
|
@ -64,7 +64,16 @@ class Script:
|
|||
def process(self, p, *args):
|
||||
"""
|
||||
This function is called before processing begins for AlwaysVisible scripts.
|
||||
scripts. You can modify the processing object (p) here, inject hooks, etc.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
pass
|
||||
|
@ -289,13 +298,22 @@ class ScriptRunner:
|
|||
|
||||
return processed
|
||||
|
||||
def run_alwayson_scripts(self, p):
|
||||
def process(self, p):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.process(p, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
|
||||
print(f"Error running process: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess(self, p, processed):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess(p, processed, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def reload_sources(self, cache):
|
||||
|
|
Loading…
Reference in a new issue