added first version of inpainting
fixed flag option
This commit is contained in:
parent
587db9c420
commit
54f74d4472
1 changed files with 72 additions and 10 deletions
82
webui.py
82
webui.py
|
@ -9,7 +9,7 @@ import torch.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin, ImageFilter, ImageOps
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import random
|
import random
|
||||||
|
@ -158,6 +158,7 @@ class Options:
|
||||||
"samples_save": OptionInfo(True, "Save indiviual samples"),
|
"samples_save": OptionInfo(True, "Save indiviual samples"),
|
||||||
"samples_format": OptionInfo('png', 'File format for indiviual samples'),
|
"samples_format": OptionInfo('png', 'File format for indiviual samples'),
|
||||||
"grid_save": OptionInfo(True, "Save image grids"),
|
"grid_save": OptionInfo(True, "Save image grids"),
|
||||||
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||||
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
||||||
|
@ -957,6 +958,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
|
||||||
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
|
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
|
||||||
if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
|
if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
|
||||||
|
return_grid = opts.return_grid
|
||||||
|
|
||||||
if p.prompt_matrix:
|
if p.prompt_matrix:
|
||||||
grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))
|
grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))
|
||||||
|
|
||||||
|
@ -967,10 +970,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
print("Error creating prompt_matrix text:", file=sys.stderr)
|
print("Error creating prompt_matrix text:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
output_images.insert(0, grid)
|
return_grid = True
|
||||||
else:
|
else:
|
||||||
grid = image_grid(output_images, p.batch_size)
|
grid = image_grid(output_images, p.batch_size)
|
||||||
|
|
||||||
|
if return_grid:
|
||||||
|
output_images.insert(0, grid)
|
||||||
|
|
||||||
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
|
@ -1042,7 +1048,7 @@ class Flagging(gr.FlaggingCallback):
|
||||||
os.makedirs("log/images", exist_ok=True)
|
os.makedirs("log/images", exist_ok=True)
|
||||||
|
|
||||||
# those must match the "txt2img" function
|
# those must match the "txt2img" function
|
||||||
prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, code, images, seed, comment = flag_data
|
prompt, steps, sampler_index, use_gfpgan, prompt_matrix, n_iter, batch_size, cfg_scale, seed, height, width, code, images, seed, comment = flag_data
|
||||||
|
|
||||||
filenames = []
|
filenames = []
|
||||||
|
|
||||||
|
@ -1067,7 +1073,7 @@ class Flagging(gr.FlaggingCallback):
|
||||||
|
|
||||||
filenames.append(filename)
|
filenames.append(filename)
|
||||||
|
|
||||||
writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]])
|
writer.writerow([prompt, seed, width, height, cfg_scale, steps, filenames[0]])
|
||||||
|
|
||||||
print("Logged:", filenames[0])
|
print("Logged:", filenames[0])
|
||||||
|
|
||||||
|
@ -1097,27 +1103,64 @@ txt2img_interface = gr.Interface(
|
||||||
flagging_callback=Flagging()
|
flagging_callback=Flagging()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def fill(image, mask):
|
||||||
|
image_mod = Image.new('RGBA', (image.width, image.height))
|
||||||
|
|
||||||
|
image_masked = Image.new('RGBa', (image.width, image.height))
|
||||||
|
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
|
||||||
|
|
||||||
|
image_masked = image_masked.convert('RGBa')
|
||||||
|
|
||||||
|
for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
|
||||||
|
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
|
||||||
|
for _ in range(repeats):
|
||||||
|
image_mod.alpha_composite(blurred)
|
||||||
|
|
||||||
|
return image_mod.convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs):
|
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.init_images = init_images
|
self.init_images = init_images
|
||||||
self.resize_mode: int = resize_mode
|
self.resize_mode: int = resize_mode
|
||||||
self.denoising_strength: float = denoising_strength
|
self.denoising_strength: float = denoising_strength
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
self.original_mask = mask
|
||||||
|
self.mask_blur = mask_blur
|
||||||
|
self.mask = None
|
||||||
|
self.nmask = None
|
||||||
|
|
||||||
def init(self):
|
def init(self):
|
||||||
self.sampler = samplers_for_img2img[self.sampler_index].constructor()
|
self.sampler = samplers_for_img2img[self.sampler_index].constructor()
|
||||||
|
|
||||||
|
if self.original_mask is not None:
|
||||||
|
if self.mask_blur > 0:
|
||||||
|
self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
|
||||||
|
|
||||||
|
latmask = self.original_mask.convert('RGB').resize((64, 64))
|
||||||
|
latmask = np.moveaxis(np.array(latmask, dtype=np.float), 2, 0) / 255
|
||||||
|
latmask = latmask[0]
|
||||||
|
latmask = np.tile(latmask[None], (4, 1, 1))
|
||||||
|
|
||||||
|
self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
|
||||||
|
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
|
||||||
|
|
||||||
|
|
||||||
imgs = []
|
imgs = []
|
||||||
for img in self.init_images:
|
for img in self.init_images:
|
||||||
image = img.convert("RGB")
|
image = img.convert("RGB")
|
||||||
image = resize_image(self.resize_mode, image, self.width, self.height)
|
image = resize_image(self.resize_mode, image, self.width, self.height)
|
||||||
|
|
||||||
|
if self.original_mask is not None
|
||||||
|
image = fill(image, self.original_mask)
|
||||||
|
|
||||||
image = np.array(image).astype(np.float32) / 255.0
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
image = np.moveaxis(image, 2, 0)
|
image = np.moveaxis(image, 2, 0)
|
||||||
|
|
||||||
imgs.append(image)
|
imgs.append(image)
|
||||||
|
|
||||||
if len(imgs) == 1:
|
if len(imgs) == 1:
|
||||||
|
@ -1139,16 +1182,33 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
|
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
|
||||||
noise = x * sigmas[self.steps - t_enc - 1]
|
noise = x * sigmas[self.steps - t_enc - 1]
|
||||||
|
|
||||||
xi = self.init_latent + noise
|
xi = self.init_latent + noise
|
||||||
sigma_sched = sigmas[self.steps - t_enc - 1:]
|
sigma_sched = sigmas[self.steps - t_enc - 1:]
|
||||||
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
|
|
||||||
|
#if self.mask is not None:
|
||||||
|
# xi = xi * self.mask + noise * self.nmask
|
||||||
|
|
||||||
|
def mask_cb(v):
|
||||||
|
v["denoised"][:] = v["denoised"][:] * self.nmask + self.init_latent * self.mask
|
||||||
|
|
||||||
|
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False, callback=mask_cb if self.mask is not None else None)
|
||||||
|
|
||||||
|
if self.mask is not None:
|
||||||
|
samples_ddim = samples_ddim * self.nmask + self.init_latent * self.mask
|
||||||
|
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
|
|
||||||
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||||
outpath = opts.outdir or "outputs/img2img-samples"
|
outpath = opts.outdir or "outputs/img2img-samples"
|
||||||
|
|
||||||
|
if init_img_with_mask is not None:
|
||||||
|
image = init_img_with_mask['image']
|
||||||
|
mask = init_img_with_mask['mask']
|
||||||
|
else:
|
||||||
|
image = init_img
|
||||||
|
mask = None
|
||||||
|
|
||||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
|
||||||
p = StableDiffusionProcessingImg2Img(
|
p = StableDiffusionProcessingImg2Img(
|
||||||
|
@ -1164,7 +1224,8 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
|
||||||
height=height,
|
height=height,
|
||||||
prompt_matrix=prompt_matrix,
|
prompt_matrix=prompt_matrix,
|
||||||
use_GFPGAN=use_GFPGAN,
|
use_GFPGAN=use_GFPGAN,
|
||||||
init_images=[init_img],
|
init_images=[image],
|
||||||
|
mask=mask,
|
||||||
resize_mode=resize_mode,
|
resize_mode=resize_mode,
|
||||||
denoising_strength=denoising_strength,
|
denoising_strength=denoising_strength,
|
||||||
extra_generation_params={"Denoising Strength": denoising_strength}
|
extra_generation_params={"Denoising Strength": denoising_strength}
|
||||||
|
@ -1262,7 +1323,8 @@ img2img_interface = gr.Interface(
|
||||||
wrap_gradio_call(img2img),
|
wrap_gradio_call(img2img),
|
||||||
inputs=[
|
inputs=[
|
||||||
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
|
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
|
||||||
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
|
gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil"),
|
||||||
|
gr.Image(label="Image for inpainting with mask", source="upload", interactive=True, type="pil", tool="sketch"),
|
||||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20),
|
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20),
|
||||||
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
|
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
|
||||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=have_gfpgan),
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=have_gfpgan),
|
||||||
|
|
Loading…
Reference in a new issue