Merge pull request #135 from rewbs/img2img2-color-correction
Add color correction to img2img loopback to avoid a progressive skew to magenta. Based on codedealer's PR to hlky's repo here: https://github.com/sd-webui/stable-diffusion-webui/pull/698/files.
This commit is contained in:
commit
9ddaf8269e
3 changed files with 28 additions and 1 deletions
|
@ -1,4 +1,6 @@
|
||||||
import math
|
import math
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageChops
|
from PIL import Image, ImageOps, ImageChops
|
||||||
|
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
|
@ -59,8 +61,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||||
|
|
||||||
state.job_count = n_iter
|
state.job_count = n_iter
|
||||||
|
|
||||||
|
do_color_correction = False
|
||||||
|
try:
|
||||||
|
from skimage import exposure
|
||||||
|
do_color_correction = True
|
||||||
|
except:
|
||||||
|
print("Install scikit-image to perform color correction on loopback")
|
||||||
|
|
||||||
|
|
||||||
for i in range(n_iter):
|
for i in range(n_iter):
|
||||||
|
|
||||||
|
if do_color_correction and i == 0:
|
||||||
|
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
|
||||||
|
|
||||||
p.n_iter = 1
|
p.n_iter = 1
|
||||||
p.batch_size = 1
|
p.batch_size = 1
|
||||||
p.do_not_save_grid = True
|
p.do_not_save_grid = True
|
||||||
|
@ -71,8 +84,20 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||||
if initial_seed is None:
|
if initial_seed is None:
|
||||||
initial_seed = processed.seed
|
initial_seed = processed.seed
|
||||||
initial_info = processed.info
|
initial_info = processed.info
|
||||||
|
|
||||||
|
init_img = processed.images[0]
|
||||||
|
|
||||||
p.init_images = [processed.images[0]]
|
if do_color_correction and correction_target is not None:
|
||||||
|
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
||||||
|
cv2.cvtColor(
|
||||||
|
np.asarray(init_img),
|
||||||
|
cv2.COLOR_RGB2LAB
|
||||||
|
),
|
||||||
|
correction_target,
|
||||||
|
channel_axis=2
|
||||||
|
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
||||||
|
|
||||||
|
p.init_images = [init_img]
|
||||||
p.seed = processed.seed + 1
|
p.seed = processed.seed + 1
|
||||||
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
|
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
|
||||||
history.append(processed.images[0])
|
history.append(processed.images[0])
|
||||||
|
|
|
@ -10,5 +10,6 @@ omegaconf
|
||||||
pytorch_lightning
|
pytorch_lightning
|
||||||
diffusers
|
diffusers
|
||||||
invisible-watermark
|
invisible-watermark
|
||||||
|
scikit-image
|
||||||
git+https://github.com/crowsonkb/k-diffusion.git
|
git+https://github.com/crowsonkb/k-diffusion.git
|
||||||
git+https://github.com/TencentARC/GFPGAN.git
|
git+https://github.com/TencentARC/GFPGAN.git
|
||||||
|
|
|
@ -8,3 +8,4 @@ torch
|
||||||
transformers==4.19.2
|
transformers==4.19.2
|
||||||
omegaconf==2.1.1
|
omegaconf==2.1.1
|
||||||
pytorch_lightning==1.7.2
|
pytorch_lightning==1.7.2
|
||||||
|
scikit-image==0.19.2
|
||||||
|
|
Loading…
Reference in a new issue