From 700c47a67492b1502265e5077c5be9ed70f8eb2a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 7 Sep 2022 17:00:51 +0300 Subject: [PATCH] big improvements to inpainting and outpainting --- modules/processing.py | 19 +++++++++++-------- modules/sd_samplers.py | 5 ++++- scripts/poor_mans_outpainting.py | 25 ++++++++++++++++++++----- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 49474b73..73b060f4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -52,7 +52,7 @@ class StableDiffusionProcessing: self.overlay_images = overlay_images self.paste_to = None - def init(self): + def init(self, seed): pass def sample(self, x, conditioning, unconditional_conditioning): @@ -155,7 +155,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) with torch.no_grad(), precision_scope("cuda"), ema_scope(): - p.init() + p.init(seed=all_seeds[0]) if state.job_count == -1: state.job_count = p.n_iter @@ -240,7 +240,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def init(self): + def init(self, seed): self.sampler = samplers[self.sampler_index].constructor(self.sd_model) def sample(self, x, conditioning, unconditional_conditioning): @@ -320,7 +320,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.mask = None self.nmask = None - def init(self): + def init(self, seed): self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model) crop_region = None @@ -347,11 +347,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): else: self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height) np_mask = np.array(self.image_mask) - np_mask = 255 - np.clip((255 - np_mask.astype(np.float)) * 2, 0, 255).astype(np.uint8) + np_mask = np.clip((np_mask.astype(np.float)) * 2, 0, 255).astype(np.uint8) self.mask_for_overlay = Image.fromarray(np_mask) self.overlay_images = [] + latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask + imgs = [] for img in self.init_images: image = img.convert("RGB") @@ -361,7 +363,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.image_mask is not None: if self.inpainting_fill != 1: - image = fill(image, self.mask_for_overlay) + image = fill(image, latent_mask) image_masked = Image.new('RGBa', (image.width, image.height)) image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) @@ -394,17 +396,18 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) if self.image_mask is not None: - init_mask = self.latent_mask if self.latent_mask is not None else self.image_mask + init_mask = latent_mask latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255 latmask = latmask[0] + latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype) if self.inpainting_fill == 2: - self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [self.seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index e8bc5be2..140b5dea 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -58,7 +58,10 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs): img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts) x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec - store_latent(x_dec) + store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec) + + else: + store_latent(x_dec) return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs) diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 08171877..a549fde3 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -21,7 +21,7 @@ class Script(scripts.Script): if not is_img2img: return None - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=128, step=8) + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) @@ -32,7 +32,7 @@ class Script(scripts.Script): initial_seed = None initial_info = None - p.mask_blur = mask_blur + p.mask_blur = mask_blur * 2 p.inpainting_fill = inpainting_fill p.inpaint_full_res = False @@ -67,13 +67,18 @@ class Script(scripts.Script): latent_mask = Image.new("L", (img.width, img.height), "white") latent_draw = ImageDraw.Draw(latent_mask) - latent_draw.rectangle((left + left//2, up + up//2, mask.width - right - right//2, mask.height - down - down//2), fill="black") + latent_draw.rectangle(( + left + (mask_blur//2 if left > 0 else 0), + up + (mask_blur//2 if up > 0 else 0), + mask.width - right - (mask_blur//2 if right > 0 else 0), + mask.height - down - (mask_blur//2 if down > 0 else 0) + ), fill="black") processing.torch_gc() grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) - grid_latent_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels) p.n_iter = 1 p.batch_size = 1 @@ -85,8 +90,13 @@ class Script(scripts.Script): work_latent_mask = [] work_results = [] - for (_, _, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): + for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): + x, w = tiledata[0:2] + + if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: + continue + work.append(tiledata[2]) work_mask.append(tiledata_mask[2]) work_latent_mask.append(tiledata_latent_mask[2]) @@ -115,6 +125,11 @@ class Script(scripts.Script): image_index = 0 for y, h, row in grid.tiles: for tiledata in row: + x, w = tiledata[0:2] + + if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down: + continue + tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) image_index += 1