working mask

This commit is contained in:
Stephen 2022-10-22 17:10:28 -04:00 committed by AUTOMATIC1111
parent 9e1a8b7734
commit 5dc0739ecd

View file

@ -33,6 +33,14 @@ class Api:
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix
if "," in base64_string:
base64_string = base64_string.split(",")[1]
imgdata = base64.b64decode(base64_string)
# convert base64 to PIL image
return Image.open(io.BytesIO(imgdata))
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index) sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@ -74,26 +82,22 @@ class Api:
mask = img2imgreq.mask mask = img2imgreq.mask
if mask: if mask:
raise HTTPException(status_code=400, detail="Mask not supported yet") mask = self.__base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_index": sampler_index[0],
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True,
"mask": mask
} }
) )
p = StableDiffusionProcessingImg2Img(**vars(populate)) p = StableDiffusionProcessingImg2Img(**vars(populate))
imgs = [] imgs = []
for img in init_images: for img in init_images:
# if has a comma, deal with prefix img = self.__base64_to_image(img)
if "," in img:
img = img.split(",")[1]
# convert base64 to PIL image
img = base64.b64decode(img)
img = Image.open(io.BytesIO(img))
imgs = [img] * p.batch_size imgs = [img] * p.batch_size
p.init_images = imgs p.init_images = imgs