From 37638370032892c03734f511eb5935be370ba56f Mon Sep 17 00:00:00 2001
From: ArrowM
Date: Thu, 15 Sep 2022 22:23:37 -0500
Subject: [PATCH] Add batch processing to Extras tab

---
 modules/extras.py | 99 ++++++++++++++++++++++++++++-------------------
 modules/ui.py     | 13 +++++--
 2 files changed, 68 insertions(+), 44 deletions(-)

diff --git a/modules/extras.py b/modules/extras.py
index e3c7d3e5..ffae7d67 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -13,66 +13,85 @@ import piexif.helper
 
 cached_images = {}
 
-def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
     devices.torch_gc()
 
-    existing_pnginfo = image.info or {}
+    imageArr = []
 
-    image = image.convert("RGB")
-    info = ""
+    if image_folder != None:
+        if image != None:
+            print("Batch detected and single image detected, please only use one of the two. Aborting.")
+            return None
+        #convert file to pillow image
+        for img in image_folder:
+            image = Image.fromarray(np.array(Image.open(img)))
+            imageArr.append(image)
+
+    elif image != None:
+        if image_folder != None:
+            print("Batch detected and single image detected, please only use one of the two. Aborting.")
+            return None
+        else:
+            imageArr.append(image)
 
     outpath = opts.outdir_samples or opts.outdir_extras_samples
 
-    if gfpgan_visibility > 0:
-        restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
-        res = Image.fromarray(restored_img)
+    for image in imageArr:
+        existing_pnginfo = image.info or {}
 
-        if gfpgan_visibility < 1.0:
-            res = Image.blend(image, res, gfpgan_visibility)
+        image = image.convert("RGB")
+        info = ""
 
-        info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
-        image = res
+        if gfpgan_visibility > 0:
+            restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
+            res = Image.fromarray(restored_img)
 
-    if codeformer_visibility > 0:
-        restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
-        res = Image.fromarray(restored_img)
+            if gfpgan_visibility < 1.0:
+                res = Image.blend(image, res, gfpgan_visibility)
 
-        if codeformer_visibility < 1.0:
-            res = Image.blend(image, res, codeformer_visibility)
+            info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
+            image = res
 
-        info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility)}\n"
-        image = res
+        if codeformer_visibility > 0:
+            restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
+            res = Image.fromarray(restored_img)
 
-    if upscaling_resize != 1.0:
-        def upscale(image, scaler_index, resize):
-            small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
-            pixels = tuple(np.array(small).flatten().tolist())
-            key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
+            if codeformer_visibility < 1.0:
+                res = Image.blend(image, res, codeformer_visibility)
 
-            c = cached_images.get(key)
-            if c is None:
-                upscaler = shared.sd_upscalers[scaler_index]
-                c = upscaler.upscale(image, image.width * resize, image.height * resize)
-                cached_images[key] = c
+            info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility)}\n"
+            image = res
 
-            return c
+        if upscaling_resize != 1.0:
+            def upscale(image, scaler_index, resize):
+                small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+                pixels = tuple(np.array(small).flatten().tolist())
+                key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
 
-        info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
-        res = upscale(image, extras_upscaler_1, upscaling_resize)
+                c = cached_images.get(key)
+                if c is None:
+                    upscaler = shared.sd_upscalers[scaler_index]
+                    c = upscaler.upscale(image, image.width * resize, image.height * resize)
+                    cached_images[key] = c
 
-        if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
-            res2 = upscale(image, extras_upscaler_2, upscaling_resize)
-            info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
-            res = Image.blend(res, res2, extras_upscaler_2_visibility)
+                return c
 
-        image = res
+            info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
+            res = upscale(image, extras_upscaler_1, upscaling_resize)
 
-    while len(cached_images) > 2:
-        del cached_images[next(iter(cached_images.keys()))]
+            if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+                res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
+                res = Image.blend(res, res2, extras_upscaler_2_visibility)
 
-    images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo)
+            image = res
 
-    return image, plaintext_to_html(info), ''
+        while len(cached_images) > 2:
+            del cached_images[next(iter(cached_images.keys()))]
+
+        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo)
+
+    return imageArr, plaintext_to_html(info), ''
 
 
 def run_pnginfo(image):
diff --git a/modules/ui.py b/modules/ui.py
index efd57b2e..b6d5dcd8 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -644,8 +644,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
     with gr.Blocks(analytics_enabled=False) as extras_interface:
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='panel'):
-                with gr.Group():
-                    image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
+                with gr.Tabs():
+                    with gr.TabItem('Single Image'):
+                        image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
+
+                    with gr.TabItem('Batch Process'):
+                        image_batch = gr.File(label="Batch Process", file_count="multiple", source="upload", interactive=True, type="file")
 
                 upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
 
@@ -666,7 +670,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                 submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
 
             with gr.Column(variant='panel'):
-                result_image = gr.Image(label="Result")
+                result_images = gr.Gallery(label="Result")
                 html_info_x = gr.HTML()
                 html_info = gr.HTML()
 
@@ -674,6 +678,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
             fn=run_extras,
             inputs=[
                 image,
+                image_batch,
                 gfpgan_visibility,
                 codeformer_visibility,
                 codeformer_weight,
@@ -683,7 +688,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                 extras_upscaler_2_visibility,
             ],
             outputs=[
-                result_image,
+                result_images,
                 html_info_x,
                 html_info,
             ]
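
Note from review (not part of the commit): the heart of the change is the new input dispatch at the top of run_extras, which turns either a single gr.Image upload or the list of files from the new Batch Process tab into one list that the rest of the function loops over. A self-contained sketch of that branching follows for reference; the helper name collect_inputs is invented here for illustration, and it uses `is not None` and a direct Image.open() where the patch writes `!= None` and a numpy round-trip, but the control flow mirrors the hunk above.

# Illustrative sketch only; mirrors the single-vs-batch dispatch added to run_extras.
from PIL import Image


def collect_inputs(image, image_folder):
    """Return a list of PIL images from either a single upload or a batch of files."""
    images = []

    if image_folder is not None:
        if image is not None:
            # Supplying both inputs at once is ambiguous, so the patch aborts instead of guessing.
            print("Batch detected and single image detected, please only use one of the two. Aborting.")
            return None
        for img in image_folder:
            # Each entry from gr.File is a file-like object (or path) that PIL can open directly.
            images.append(Image.open(img))
    elif image is not None:
        images.append(image)

    return images

Called with (None, list_of_files) it yields the list that the per-image loop, and ultimately the gr.Gallery output, consumes; called with (single_image, None) it reduces to the old single-image behaviour.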