extras: Rework image cache

Bit of a refactor to the image cache to make it easier to extend.
Also takes into account the entire image instead of just a cropped portion.
This commit is contained in:
Chris OBryan 2022-10-28 14:30:04 -05:00
parent 26d0819384
commit bde4731f1d

View file

@ -7,7 +7,7 @@ from PIL import Image
import torch import torch
import tqdm import tqdm
from typing import Callable, List, Tuple from typing import Callable, Dict, List, Tuple
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
@ -21,7 +21,18 @@ import piexif.helper
import gradio as gr import gradio as gr
cached_images = {} @dataclass(frozen=True)
class CacheKey:
image_hash: int
info_hash: int
args_hash: int
@dataclass
class CacheEntry:
image: Image.Image
info: str
cached_images: Dict[CacheKey, CacheEntry] = {}
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ): def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
@ -84,22 +95,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) upscaler = shared.sd_upscalers[scaler_index]
pixels = tuple(np.array(small).flatten().tolist()) res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, if mode == 1 and crop:
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
c = cached_images.get(key) res = cropped
if c is None: return res
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
c = cropped
cached_images[key] = c
return c
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
@ -118,8 +120,18 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None blended_result: Image.Image = None
for upscaler in params: for upscaler in params:
res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()),
info_hash = hash(info),
args_hash = hash(upscale_args) )
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
cached_images[cache_key] = CacheEntry(image=res, info=info)
else:
res, info = cached_entry.image, cached_entry.info
if blended_result is None: if blended_result is None:
blended_result = res blended_result = res
else: else: