Add config and lists endpoints
This commit is contained in:
parent
d98eacea40
commit
7a2e36b583
2 changed files with 159 additions and 8 deletions
|
@ -2,14 +2,17 @@ import base64
|
||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
|
from threading import Lock
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
|
||||||
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.api.models import *
|
from modules.api.models import *
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
|
from modules.sd_samplers import all_samplers
|
||||||
from modules.extras import run_extras, run_pnginfo
|
from modules.extras import run_extras, run_pnginfo
|
||||||
|
from modules.sd_models import checkpoints_list
|
||||||
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
|
from typing import List
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
def upscaler_to_index(name: str):
|
||||||
try:
|
try:
|
||||||
|
@ -37,7 +40,7 @@ def encode_pil_to_base64(image):
|
||||||
|
|
||||||
|
|
||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app, queue_lock):
|
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||||
self.router = APIRouter()
|
self.router = APIRouter()
|
||||||
self.app = app
|
self.app = app
|
||||||
self.queue_lock = queue_lock
|
self.queue_lock = queue_lock
|
||||||
|
@ -48,6 +51,19 @@ class Api:
|
||||||
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
||||||
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
||||||
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||||
|
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
||||||
|
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
|
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
||||||
|
self.app.add_api_route("/sdapi/v1/info", self.get_info, methods=["GET"])
|
||||||
|
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
||||||
|
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
||||||
|
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
||||||
|
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
||||||
|
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||||
|
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||||
|
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
||||||
|
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
|
||||||
|
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||||
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
|
||||||
|
@ -190,6 +206,77 @@ class Api:
|
||||||
shared.state.interrupt()
|
shared.state.interrupt()
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
options = {}
|
||||||
|
for key in shared.opts.data.keys():
|
||||||
|
metadata = shared.opts.data_labels.get(key)
|
||||||
|
if(metadata is not None):
|
||||||
|
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
|
||||||
|
else:
|
||||||
|
options.update({key: shared.opts.data.get(key, None)})
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
def set_config(self, req: OptionsModel):
|
||||||
|
reqDict = vars(req)
|
||||||
|
for o in reqDict:
|
||||||
|
setattr(shared.opts, o, reqDict[o])
|
||||||
|
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_cmd_flags(self):
|
||||||
|
return vars(shared.cmd_opts)
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hypernetworks": [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks],
|
||||||
|
"face_restorers": [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers],
|
||||||
|
"realesrgan_models":[{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)],
|
||||||
|
"promp_styles":[shared.prompt_styles.styles[k] for k in shared.prompt_styles.styles],
|
||||||
|
"artists_categories": shared.artist_db.cats,
|
||||||
|
# "artists": [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_samplers(self):
|
||||||
|
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
|
||||||
|
|
||||||
|
def get_upscalers(self):
|
||||||
|
upscalers = []
|
||||||
|
|
||||||
|
for upscaler in shared.sd_upscalers:
|
||||||
|
u = upscaler.scaler
|
||||||
|
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
|
||||||
|
|
||||||
|
return upscalers
|
||||||
|
|
||||||
|
def get_sd_models(self):
|
||||||
|
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
|
||||||
|
|
||||||
|
def get_hypernetworks(self):
|
||||||
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
|
|
||||||
|
def get_face_restorers(self):
|
||||||
|
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
|
||||||
|
|
||||||
|
def get_realesrgan_models(self):
|
||||||
|
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
|
||||||
|
|
||||||
|
def get_promp_styles(self):
|
||||||
|
styleList = []
|
||||||
|
for k in shared.prompt_styles.styles:
|
||||||
|
style = shared.prompt_styles.styles[k]
|
||||||
|
styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
|
||||||
|
|
||||||
|
return styleList
|
||||||
|
|
||||||
|
def get_artists_categories(self):
|
||||||
|
return shared.artist_db.cats
|
||||||
|
|
||||||
|
def get_artists(self):
|
||||||
|
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
|
||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
import inspect
|
import inspect
|
||||||
from click import prompt
|
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, Field, create_model
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, Union
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
from inflection import underscore
|
from inflection import underscore
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||||
from modules.shared import sd_upscalers
|
from modules.shared import sd_upscalers, opts, parser
|
||||||
|
|
||||||
API_NOT_ALLOWED = [
|
API_NOT_ALLOWED = [
|
||||||
"self",
|
"self",
|
||||||
|
@ -165,3 +164,68 @@ class ProgressResponse(BaseModel):
|
||||||
eta_relative: float = Field(title="ETA in secs")
|
eta_relative: float = Field(title="ETA in secs")
|
||||||
state: dict = Field(title="State", description="The current state snapshot")
|
state: dict = Field(title="State", description="The current state snapshot")
|
||||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||||
|
|
||||||
|
fields = {}
|
||||||
|
for key, value in opts.data.items():
|
||||||
|
metadata = opts.data_labels.get(key)
|
||||||
|
optType = opts.typemap.get(type(value), type(value))
|
||||||
|
|
||||||
|
if (metadata is not None):
|
||||||
|
fields.update({key: (Optional[optType], Field(
|
||||||
|
default=metadata.default ,description=metadata.label))})
|
||||||
|
else:
|
||||||
|
fields.update({key: (Optional[optType], Field())})
|
||||||
|
|
||||||
|
OptionsModel = create_model("Options", **fields)
|
||||||
|
|
||||||
|
flags = {}
|
||||||
|
_options = vars(parser)['_option_string_actions']
|
||||||
|
for key in _options:
|
||||||
|
if(_options[key].dest != 'help'):
|
||||||
|
flag = _options[key]
|
||||||
|
_type = str
|
||||||
|
if(_options[key].default != None): _type = type(_options[key].default)
|
||||||
|
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
||||||
|
|
||||||
|
FlagsModel = create_model("Flags", **flags)
|
||||||
|
|
||||||
|
class SamplerItem(BaseModel):
|
||||||
|
name: str = Field(title="Name")
|
||||||
|
aliases: list[str] = Field(title="Aliases")
|
||||||
|
options: dict[str, str] = Field(title="Options")
|
||||||
|
|
||||||
|
class UpscalerItem(BaseModel):
|
||||||
|
name: str = Field(title="Name")
|
||||||
|
model_name: str | None = Field(title="Model Name")
|
||||||
|
model_path: str | None = Field(title="Path")
|
||||||
|
model_url: str | None = Field(title="URL")
|
||||||
|
|
||||||
|
class SDModelItem(BaseModel):
|
||||||
|
title: str = Field(title="Title")
|
||||||
|
model_name: str = Field(title="Model Name")
|
||||||
|
hash: str = Field(title="Hash")
|
||||||
|
filename: str = Field(title="Filename")
|
||||||
|
config: str = Field(title="Config file")
|
||||||
|
|
||||||
|
class HypernetworkItem(BaseModel):
|
||||||
|
name: str = Field(title="Name")
|
||||||
|
path: str | None = Field(title="Path")
|
||||||
|
|
||||||
|
class FaceRestorerItem(BaseModel):
|
||||||
|
name: str = Field(title="Name")
|
||||||
|
cmd_dir: str | None = Field(title="Path")
|
||||||
|
|
||||||
|
class RealesrganItem(BaseModel):
|
||||||
|
name: str = Field(title="Name")
|
||||||
|
path: str | None = Field(title="Path")
|
||||||
|
scale: int | None = Field(title="Scale")
|
||||||
|
|
||||||
|
class PromptStyleItem(BaseModel):
|
||||||
|
name: str = Field(title="Name")
|
||||||
|
prompt: str | None = Field(title="Prompt")
|
||||||
|
negative_prompt: str | None = Field(title="Negative Prompt")
|
||||||
|
|
||||||
|
class ArtistItem(BaseModel):
|
||||||
|
name: str = Field(title="Name")
|
||||||
|
score: float = Field(title="Score")
|
||||||
|
category: str = Field(title="Category")
|
Loading…
Reference in a new issue