From f1858744189ad54bb464c97b9c735275795a6f53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 11 Sep 2022 11:31:16 +0300 Subject: [PATCH] [Feature Request] Save defaults for extras & keep image parameters after using extras #251 --- modules/extras.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++ modules/images.py | 4 +-- modules/ui.py | 1 + webui.py | 85 +++------------------------------------------ 4 files changed, 94 insertions(+), 83 deletions(-) create mode 100644 modules/extras.py diff --git a/modules/extras.py b/modules/extras.py new file mode 100644 index 00000000..6aeae6cb --- /dev/null +++ b/modules/extras.py @@ -0,0 +1,87 @@ +import numpy as np +from PIL import Image + +from modules import processing, shared, images +from modules.shared import opts +import modules.gfpgan_model +from modules.ui import plaintext_to_html +import modules.codeformer_model + +cached_images = {} + + +def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): + processing.torch_gc() + + image = image.convert("RGB") + info = "" + + outpath = opts.outdir_samples or opts.outdir_extras_samples + + if gfpgan_visibility > 0: + restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(image, res, gfpgan_visibility) + + info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" + image = res + + if codeformer_visibility > 0: + restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(image, res, codeformer_visibility) + + info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility)}\n" + image = res + + if upscaling_resize != 1.0: + def upscale(image, scaler_index, resize): + small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) + pixels = tuple(np.array(small).flatten().tolist()) + key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels + + c = cached_images.get(key) + if c is None: + upscaler = shared.sd_upscalers[scaler_index] + c = upscaler.upscale(image, image.width * resize, image.height * resize) + cached_images[key] = c + + return c + + info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" + res = upscale(image, extras_upscaler_1, upscaling_resize) + + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + res2 = upscale(image, extras_upscaler_2, upscaling_resize) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" + res = Image.blend(res, res2, extras_upscaler_2_visibility) + + image = res + + while len(cached_images) > 2: + del cached_images[next(iter(cached_images.keys()))] + + images.save_image(image, outpath, "", None, info=info, extension=opts.samples_format, short_filename=True, no_prompt=True, pnginfo_section_name="extras") + + return image, plaintext_to_html(info), '' + + +def run_pnginfo(image): + info = '' + for key, text in image.info.items(): + info += f""" +
+

{plaintext_to_html(str(key))}

+

{plaintext_to_html(str(text))}

+
+""".strip()+"\n" + + if len(info) == 0: + message = "Nothing found in the image." + info = f"

{message}

" + + return '', '', info diff --git a/modules/images.py b/modules/images.py index c1a58013..26c399b6 100644 --- a/modules/images.py +++ b/modules/images.py @@ -243,7 +243,7 @@ def sanitize_filename_part(text): return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128] -def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False): +def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, pnginfo_section_name='parameters'): # would be better to add this as an argument in future, but will do for now is_a_grid = basename != "" @@ -256,7 +256,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i if extension == 'png' and opts.enable_pnginfo and info is not None: pnginfo = PngImagePlugin.PngInfo() - pnginfo.add_text("parameters", info) + pnginfo.add_text(pnginfo_section_name, info) else: pnginfo = None diff --git a/modules/ui.py b/modules/ui.py index d8c7e465..032c20ff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -797,6 +797,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): with open(ui_config_file, "w", encoding="utf8") as file: diff --git a/webui.py b/webui.py index 6976848a..70c68338 100644 --- a/webui.py +++ b/webui.py @@ -4,9 +4,7 @@ import threading from modules.paths import script_path import torch -import numpy as np from omegaconf import OmegaConf -from PIL import Image import signal @@ -15,16 +13,14 @@ from ldm.util import instantiate_from_config from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.ui -from modules.ui import plaintext_to_html import modules.scripts -import modules.processing as processing import modules.sd_hijack import modules.codeformer_model import modules.gfpgan_model import modules.face_restoration import modules.realesrgan_model as realesrgan import modules.esrgan_model as esrgan -import modules.images as images +import modules.extras import modules.lowvram import modules.txt2img import modules.img2img @@ -56,80 +52,6 @@ def load_model_from_config(config, ckpt, verbose=False): model.eval() return model -cached_images = {} - - -def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): - processing.torch_gc() - - image = image.convert("RGB") - - outpath = opts.outdir_samples or opts.outdir_extras_samples - - if gfpgan_visibility > 0: - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - image = res - - if codeformer_visibility > 0: - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - image = res - - if upscaling_resize != 1.0: - def upscale(image, scaler_index, resize): - small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) - pixels = tuple(np.array(small).flatten().tolist()) - key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels - - c = cached_images.get(key) - if c is None: - upscaler = shared.sd_upscalers[scaler_index] - c = upscaler.upscale(image, image.width * resize, image.height * resize) - cached_images[key] = c - - return c - - res = upscale(image, extras_upscaler_1, upscaling_resize) - - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility>0: - res2 = upscale(image, extras_upscaler_2, upscaling_resize) - res = Image.blend(res, res2, extras_upscaler_2_visibility) - - image = res - - while len(cached_images) > 2: - del cached_images[next(iter(cached_images.keys()))] - - images.save_image(image, outpath, "", None, '', opts.samples_format, short_filename=True, no_prompt=True) - - return image, '', '' - - -def run_pnginfo(image): - info = '' - for key, text in image.info.items(): - info += f""" -
-

{plaintext_to_html(str(key))}

-

{plaintext_to_html(str(text))}

-
-""".strip()+"\n" - - if len(info) == 0: - message = "Nothing found in the image." - info = f"

{message}

" - - return '', '', info - queue_lock = threading.Lock() @@ -153,6 +75,7 @@ def wrap_gradio_gpu_call(func): return modules.ui.wrap_gradio_call(f) + modules.scripts.load_scripts(os.path.join(script_path, "scripts")) try: @@ -187,8 +110,8 @@ def webui(): demo = modules.ui.create_ui( txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img), img2img=wrap_gradio_gpu_call(modules.img2img.img2img), - run_extras=wrap_gradio_gpu_call(run_extras), - run_pnginfo=run_pnginfo + run_extras=wrap_gradio_gpu_call(modules.extras.run_extras), + run_pnginfo=modules.extras.run_pnginfo ) demo.launch(share=cmd_opts.share, server_name="0.0.0.0" if cmd_opts.listen else None, server_port=cmd_opts.port)