diff --git a/webui.py b/webui.py index 6f935339..8901706d 100644 --- a/webui.py +++ b/webui.py @@ -40,6 +40,9 @@ import signal import tqdm import re import threading +import time +import base64 +import io import k_diffusion.sampling from ldm.util import instantiate_from_config @@ -1285,6 +1288,22 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u return processed.images, processed.js(), plaintext_to_html(processed.info) +def image_from_url_text(filedata): + if filedata.startswith("data:image/png;base64,"): + filedata = filedata[len("data:image/png;base64,"):] + + filedata = base64.decodebytes(filedata.encode('utf-8')) + image = Image.open(io.BytesIO(filedata)) + return image + + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + + return image_from_url_text(x[0]) + + def save_files(js_data, images): import csv @@ -1295,9 +1314,6 @@ def save_files(js_data, images): data = json.loads(js_data) with open("log/log.csv", "a", encoding="utf8", newline='') as file: - import time - import base64 - at_start = file.tell() == 0 writer = csv.writer(file) if at_start: @@ -1352,12 +1368,15 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Column(variant='panel'): with gr.Group(): - gallery = gr.Gallery(label='Output') + txt2img_gallery = gr.Gallery(label='Output') with gr.Group(): with gr.Row(): - interrupt = gr.Button('Interrupt') save = gr.Button('Save') + send_to_img2img = gr.Button('Send to img2img') + send_to_inpaint = gr.Button('Send to inpaint') + send_to_extras = gr.Button('Send to extras') + interrupt = gr.Button('Interrupt') with gr.Group(): html_info = gr.HTML() @@ -1381,7 +1400,7 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface: code ], outputs=[ - gallery, + txt2img_gallery, generation_info, html_info ] @@ -1400,7 +1419,7 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface: fn=wrap_gradio_call(save_files), inputs=[ generation_info, - gallery, + txt2img_gallery, ], outputs=[ html_info, @@ -1742,12 +1761,13 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Column(variant='panel'): with gr.Group(): - gallery = gr.Gallery(label='Output') + img2img_gallery = gr.Gallery(label='Output') with gr.Group(): with gr.Row(): interrupt = gr.Button('Interrupt') save = gr.Button('Save') + img2img_send_to_extras = gr.Button('Send to extras') with gr.Group(): html_info = gr.HTML() @@ -1815,7 +1835,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface: inpaint_full_res, ], outputs=[ - gallery, + img2img_gallery, generation_info, html_info ] @@ -1834,7 +1854,7 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface: fn=wrap_gradio_call(save_files), inputs=[ generation_info, - gallery, + img2img_gallery, ], outputs=[ html_info, @@ -1843,6 +1863,19 @@ with gr.Blocks(analytics_enabled=False) as img2img_interface: ] ) + send_to_img2img.click( + fn=send_gradio_gallery_to_image, + inputs=[txt2img_gallery], + outputs=[init_img], + ) + + send_to_inpaint.click( + fn=send_gradio_gallery_to_image, + inputs=[txt2img_gallery], + outputs=[init_img_with_mask], + ) + + def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index): info = realesrgan_models[RealESRGAN_model_index] @@ -1887,22 +1920,52 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in return image, '', '' -extras_interface = gr.Interface( - wrap_gradio_gpu_call(run_extras), - inputs=[ - gr.Image(label="Source", source="upload", interactive=True, type="pil"), - gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan), - gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan), - gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan), - ], - outputs=[ - gr.Image(label="Result"), - gr.HTML(), - gr.HTML(), - ], - allow_flagging="never", - analytics_enabled=False, -) +with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Group(): + image = gr.Image(label="Source", source="upload", interactive=True, type="pil") + gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan) + realesrgan_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan) + realesrgan_model = gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan) + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Column(variant='panel'): + result_image = gr.Image(label="Result") + html_info_x = gr.HTML() + html_info = gr.HTML() + + extras_args = dict( + fn=wrap_gradio_gpu_call(run_extras), + inputs=[ + image, + gfpgan_strength, + realesrgan_resize, + realesrgan_model, + ], + outputs=[ + result_image, + html_info_x, + html_info, + ] + ) + + submit.click(**extras_args) + + send_to_extras.click( + fn=send_gradio_gallery_to_image, + inputs=[txt2img_gallery], + outputs=[image], + ) + + img2img_send_to_extras.click( + fn=send_gradio_gallery_to_image, + inputs=[img2img_gallery], + outputs=[image], + ) + + def run_pnginfo(image):