Merge pull request #135 from rewbs/img2img2-color-correction

Add color correction to img2img loopback to avoid a progressive skew to magenta. Based on codedealer's PR to hlky's repo here: https://github.com/sd-webui/stable-diffusion-webui/pull/698/files.
This commit is contained in:
AUTOMATIC1111 2022-09-08 09:45:55 +03:00 committed by GitHub
commit 9ddaf8269e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 1 deletions

View file

@ -1,4 +1,6 @@
import math import math
import cv2
import numpy as np
from PIL import Image, ImageOps, ImageChops from PIL import Image, ImageOps, ImageChops
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
@ -59,8 +61,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
state.job_count = n_iter state.job_count = n_iter
do_color_correction = False
try:
from skimage import exposure
do_color_correction = True
except:
print("Install scikit-image to perform color correction on loopback")
for i in range(n_iter): for i in range(n_iter):
if do_color_correction and i == 0:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
p.n_iter = 1 p.n_iter = 1
p.batch_size = 1 p.batch_size = 1
p.do_not_save_grid = True p.do_not_save_grid = True
@ -71,8 +84,20 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
if initial_seed is None: if initial_seed is None:
initial_seed = processed.seed initial_seed = processed.seed
initial_info = processed.info initial_info = processed.info
init_img = processed.images[0]
p.init_images = [processed.images[0]] if do_color_correction and correction_target is not None:
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(init_img),
cv2.COLOR_RGB2LAB
),
correction_target,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
p.init_images = [init_img]
p.seed = processed.seed + 1 p.seed = processed.seed + 1
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1) p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
history.append(processed.images[0]) history.append(processed.images[0])

View file

@ -10,5 +10,6 @@ omegaconf
pytorch_lightning pytorch_lightning
diffusers diffusers
invisible-watermark invisible-watermark
scikit-image
git+https://github.com/crowsonkb/k-diffusion.git git+https://github.com/crowsonkb/k-diffusion.git
git+https://github.com/TencentARC/GFPGAN.git git+https://github.com/TencentARC/GFPGAN.git

View file

@ -8,3 +8,4 @@ torch
transformers==4.19.2 transformers==4.19.2
omegaconf==2.1.1 omegaconf==2.1.1
pytorch_lightning==1.7.2 pytorch_lightning==1.7.2
scikit-image==0.19.2