big improvements to inpainting and outpainting
This commit is contained in:
parent
2cbda50cdd
commit
700c47a674
3 changed files with 35 additions and 14 deletions
|
@ -52,7 +52,7 @@ class StableDiffusionProcessing:
|
||||||
self.overlay_images = overlay_images
|
self.overlay_images = overlay_images
|
||||||
self.paste_to = None
|
self.paste_to = None
|
||||||
|
|
||||||
def init(self):
|
def init(self, seed):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sample(self, x, conditioning, unconditional_conditioning):
|
def sample(self, x, conditioning, unconditional_conditioning):
|
||||||
|
@ -155,7 +155,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||||
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
|
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
|
||||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
||||||
p.init()
|
p.init(seed=all_seeds[0])
|
||||||
|
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
@ -240,7 +240,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
def init(self):
|
def init(self, seed):
|
||||||
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
|
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
|
||||||
|
|
||||||
def sample(self, x, conditioning, unconditional_conditioning):
|
def sample(self, x, conditioning, unconditional_conditioning):
|
||||||
|
@ -320,7 +320,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
|
|
||||||
def init(self):
|
def init(self, seed):
|
||||||
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
|
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
|
||||||
crop_region = None
|
crop_region = None
|
||||||
|
|
||||||
|
@ -347,11 +347,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
else:
|
else:
|
||||||
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
|
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
|
||||||
np_mask = np.array(self.image_mask)
|
np_mask = np.array(self.image_mask)
|
||||||
np_mask = 255 - np.clip((255 - np_mask.astype(np.float)) * 2, 0, 255).astype(np.uint8)
|
np_mask = np.clip((np_mask.astype(np.float)) * 2, 0, 255).astype(np.uint8)
|
||||||
self.mask_for_overlay = Image.fromarray(np_mask)
|
self.mask_for_overlay = Image.fromarray(np_mask)
|
||||||
|
|
||||||
self.overlay_images = []
|
self.overlay_images = []
|
||||||
|
|
||||||
|
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
|
||||||
|
|
||||||
imgs = []
|
imgs = []
|
||||||
for img in self.init_images:
|
for img in self.init_images:
|
||||||
image = img.convert("RGB")
|
image = img.convert("RGB")
|
||||||
|
@ -361,7 +363,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
|
|
||||||
if self.image_mask is not None:
|
if self.image_mask is not None:
|
||||||
if self.inpainting_fill != 1:
|
if self.inpainting_fill != 1:
|
||||||
image = fill(image, self.mask_for_overlay)
|
image = fill(image, latent_mask)
|
||||||
|
|
||||||
image_masked = Image.new('RGBa', (image.width, image.height))
|
image_masked = Image.new('RGBa', (image.width, image.height))
|
||||||
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
||||||
|
@ -394,17 +396,18 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||||
|
|
||||||
if self.image_mask is not None:
|
if self.image_mask is not None:
|
||||||
init_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
|
init_mask = latent_mask
|
||||||
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
||||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
|
latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
|
||||||
latmask = latmask[0]
|
latmask = latmask[0]
|
||||||
|
latmask = np.around(latmask)
|
||||||
latmask = np.tile(latmask[None], (4, 1, 1))
|
latmask = np.tile(latmask[None], (4, 1, 1))
|
||||||
|
|
||||||
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
|
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
|
||||||
self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
|
self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
|
||||||
|
|
||||||
if self.inpainting_fill == 2:
|
if self.inpainting_fill == 2:
|
||||||
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [self.seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
|
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
|
||||||
elif self.inpainting_fill == 3:
|
elif self.inpainting_fill == 3:
|
||||||
self.init_latent = self.init_latent * self.mask
|
self.init_latent = self.init_latent * self.mask
|
||||||
|
|
||||||
|
|
|
@ -58,6 +58,9 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
|
||||||
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
|
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
|
||||||
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
|
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
|
||||||
|
|
||||||
|
store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec)
|
||||||
|
|
||||||
|
else:
|
||||||
store_latent(x_dec)
|
store_latent(x_dec)
|
||||||
|
|
||||||
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
|
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
|
||||||
|
|
|
@ -21,7 +21,7 @@ class Script(scripts.Script):
|
||||||
if not is_img2img:
|
if not is_img2img:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=128, step=8)
|
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
|
||||||
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
|
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
|
||||||
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
|
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
|
||||||
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
|
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
|
||||||
|
@ -32,7 +32,7 @@ class Script(scripts.Script):
|
||||||
initial_seed = None
|
initial_seed = None
|
||||||
initial_info = None
|
initial_info = None
|
||||||
|
|
||||||
p.mask_blur = mask_blur
|
p.mask_blur = mask_blur * 2
|
||||||
p.inpainting_fill = inpainting_fill
|
p.inpainting_fill = inpainting_fill
|
||||||
p.inpaint_full_res = False
|
p.inpaint_full_res = False
|
||||||
|
|
||||||
|
@ -67,13 +67,18 @@ class Script(scripts.Script):
|
||||||
|
|
||||||
latent_mask = Image.new("L", (img.width, img.height), "white")
|
latent_mask = Image.new("L", (img.width, img.height), "white")
|
||||||
latent_draw = ImageDraw.Draw(latent_mask)
|
latent_draw = ImageDraw.Draw(latent_mask)
|
||||||
latent_draw.rectangle((left + left//2, up + up//2, mask.width - right - right//2, mask.height - down - down//2), fill="black")
|
latent_draw.rectangle((
|
||||||
|
left + (mask_blur//2 if left > 0 else 0),
|
||||||
|
up + (mask_blur//2 if up > 0 else 0),
|
||||||
|
mask.width - right - (mask_blur//2 if right > 0 else 0),
|
||||||
|
mask.height - down - (mask_blur//2 if down > 0 else 0)
|
||||||
|
), fill="black")
|
||||||
|
|
||||||
processing.torch_gc()
|
processing.torch_gc()
|
||||||
|
|
||||||
grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
||||||
grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
||||||
grid_latent_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
||||||
|
|
||||||
p.n_iter = 1
|
p.n_iter = 1
|
||||||
p.batch_size = 1
|
p.batch_size = 1
|
||||||
|
@ -85,8 +90,13 @@ class Script(scripts.Script):
|
||||||
work_latent_mask = []
|
work_latent_mask = []
|
||||||
work_results = []
|
work_results = []
|
||||||
|
|
||||||
for (_, _, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
|
for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
|
||||||
for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
|
for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
|
||||||
|
x, w = tiledata[0:2]
|
||||||
|
|
||||||
|
if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
|
||||||
|
continue
|
||||||
|
|
||||||
work.append(tiledata[2])
|
work.append(tiledata[2])
|
||||||
work_mask.append(tiledata_mask[2])
|
work_mask.append(tiledata_mask[2])
|
||||||
work_latent_mask.append(tiledata_latent_mask[2])
|
work_latent_mask.append(tiledata_latent_mask[2])
|
||||||
|
@ -115,6 +125,11 @@ class Script(scripts.Script):
|
||||||
image_index = 0
|
image_index = 0
|
||||||
for y, h, row in grid.tiles:
|
for y, h, row in grid.tiles:
|
||||||
for tiledata in row:
|
for tiledata in row:
|
||||||
|
x, w = tiledata[0:2]
|
||||||
|
|
||||||
|
if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:
|
||||||
|
continue
|
||||||
|
|
||||||
tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
|
tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
|
||||||
image_index += 1
|
image_index += 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue