From 7350c712598b748c3cedc2a224887bd839a27d76 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 4 Sep 2022 01:29:43 +0300 Subject: [PATCH] added poor man's inpainting script --- modules/images.py | 15 +++-- modules/img2img.py | 2 +- modules/processing.py | 9 ++- modules/scripts.py | 100 +++++++++++++++++----------- modules/txt2img.py | 2 +- modules/ui.py | 6 +- scripts/poor_mans_outpainting.py | 110 +++++++++++++++++++++++++++++++ 7 files changed, 193 insertions(+), 51 deletions(-) create mode 100644 scripts/poor_mans_outpainting.py diff --git a/modules/images.py b/modules/images.py index b05276c3..4b9667d2 100644 --- a/modules/images.py +++ b/modules/images.py @@ -39,23 +39,26 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): w = image.width h = image.height - now = tile_w - overlap # non-overlap width - noh = tile_h - overlap + non_overlap_width = tile_w - overlap + non_overlap_height = tile_h - overlap - cols = math.ceil((w - overlap) / now) - rows = math.ceil((h - overlap) / noh) + cols = math.ceil((w - overlap) / non_overlap_width) + rows = math.ceil((h - overlap) / non_overlap_height) + + dx = (w - tile_w) // (cols-1) if cols > 1 else 0 + dy = (h - tile_h) // (rows-1) if rows > 1 else 0 grid = Grid([], tile_w, tile_h, w, h, overlap) for row in range(rows): row_images = [] - y = row * noh + y = row * dy if y + tile_h >= h: y = h - tile_h for col in range(cols): - x = col * now + x = col * dx if x+tile_w >= w: x = w - tile_w diff --git a/modules/img2img.py b/modules/img2img.py index 06de2db3..d5787dd3 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -130,7 +130,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index else: - processed = modules.scripts.run(p, *args) + processed = modules.scripts.scripts_img2img.run(p, *args) if processed is None: processed = process_images(p) diff --git a/modules/processing.py b/modules/processing.py index 2830209e..adc5d851 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -271,7 +271,7 @@ def fill(image, mask): image_masked = image_masked.convert('RGBa') - for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]: + for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]: blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA') for _ in range(repeats): image_mod.alpha_composite(blurred) @@ -290,6 +290,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.denoising_strength: float = denoising_strength self.init_latent = None self.image_mask = mask + #self.image_unblurred_mask = None + self.latent_mask = None self.mask_for_overlay = None self.mask_blur = mask_blur self.inpainting_fill = inpainting_fill @@ -308,6 +310,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_mask_invert: self.image_mask = ImageOps.invert(self.image_mask) + #self.image_unblurred_mask = self.image_mask + if self.mask_blur > 0: self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) @@ -368,7 +372,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) if self.image_mask is not None: - latmask = self.image_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) + init_mask = self.latent_mask if self.latent_mask is not None else self.image_mask + latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255 latmask = latmask[0] latmask = np.tile(latmask[None], (4, 1, 1)) diff --git a/modules/scripts.py b/modules/scripts.py index 99502857..89a0618d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -18,6 +18,9 @@ class Script: def ui(self, is_img2img): pass + def show(self, is_img2img): + return True + def run(self, *args): raise NotImplementedError() @@ -25,7 +28,7 @@ class Script: return "" -scripts = [] +scripts_data = [] def load_scripts(basedir): @@ -49,10 +52,8 @@ def load_scripts(basedir): for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - obj = script_class() - obj.filename = path + scripts_data.append((script_class, path)) - scripts.append(obj) except Exception: print(f"Error loading script: {filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) @@ -69,52 +70,75 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): return default -def setup_ui(is_img2img): - titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in scripts] +class ScriptRunner: + def __init__(self): + self.scripts = [] - dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index") + def setup_ui(self, is_img2img): + for script_class, path in scripts_data: + script = script_class() + script.filename = path - inputs = [dropdown] + if not script.show(is_img2img): + continue - for script in scripts: - script.args_from = len(inputs) - controls = script.ui(is_img2img) + self.scripts.append(script) - for control in controls: - control.visible = False + titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] - inputs += controls - script.args_to = len(inputs) + dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index") + inputs = [dropdown] - def select_script(index): - if index > 0: - script = scripts[index-1] - args_from = script.args_from - args_to = script.args_to - else: - args_from = 0 - args_to = 0 + for script in self.scripts: + script.args_from = len(inputs) - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + controls = wrap_call(script.ui, script.filename, "ui", is_img2img) - dropdown.change( - fn=select_script, - inputs=[dropdown], - outputs=inputs - ) + if controls is None: + continue - return inputs + for control in controls: + control.visible = False + + inputs += controls + script.args_to = len(inputs) + + def select_script(script_index): + if 0 < script_index <= len(self.scripts): + script = self.scripts[script_index-1] + args_from = script.args_from + args_to = script.args_to + else: + args_from = 0 + args_to = 0 + + return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + + dropdown.change( + fn=select_script, + inputs=[dropdown], + outputs=inputs + ) + + return inputs -def run(p: StableDiffusionProcessing, *args): - script_index = args[0] - 1 + def run(self, p: StableDiffusionProcessing, *args): + script_index = args[0] - if script_index < 0 or script_index >= len(scripts): - return None + if script_index == 0: + return None - script = scripts[script_index] + script = self.scripts[script_index-1] - script_args = args[script.args_from:script.args_to] - processed = script.run(p, *script_args) + if script is None: + return None - return processed + script_args = args[script.args_from:script.args_to] + processed = script.run(p, *script_args) + + return processed + + +scripts_txt2img = ScriptRunner() +scripts_img2img = ScriptRunner() diff --git a/modules/txt2img.py b/modules/txt2img.py index f5ac0540..fb65a7f6 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -24,7 +24,7 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u use_GFPGAN=use_GFPGAN ) - processed = modules.scripts.run(p, *args) + processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is not None: pass diff --git a/modules/ui.py b/modules/ui.py index ccca871a..65d53bcd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -162,7 +162,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): seed = gr.Number(label='Seed', value=-1) with gr.Group(): - custom_inputs = modules.scripts.setup_ui(is_img2img=False) + custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) with gr.Column(variant='panel'): with gr.Group(): @@ -244,7 +244,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False) with gr.Row(): - inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=True, visible=False) + inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, visible=False) inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False) with gr.Row(): @@ -269,7 +269,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): seed = gr.Number(label='Seed', value=-1) with gr.Group(): - custom_inputs = modules.scripts.setup_ui(is_img2img=True) + custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) with gr.Column(variant='panel'): diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py new file mode 100644 index 00000000..98e1def0 --- /dev/null +++ b/scripts/poor_mans_outpainting.py @@ -0,0 +1,110 @@ +import math + +import modules.scripts as scripts +import gradio as gr +from PIL import Image, ImageDraw + +from modules import images, processing +from modules.processing import Processed, process_images +from modules.shared import opts, cmd_opts, state + + + +class Script(scripts.Script): + def title(self): + return "Poor man's outpainting" + + def show(self, is_img2img): + return is_img2img + + def ui(self, is_img2img): + if not is_img2img: + return None + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=128, step=8) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False) + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False) + + return [pixels, mask_blur, inpainting_fill] + + def run(self, p, pixels, mask_blur, inpainting_fill): + initial_seed = None + initial_info = None + + p.mask_blur = mask_blur + p.inpainting_fill = inpainting_fill + p.inpaint_full_res = False + + init_img = p.init_images[0] + target_w = math.ceil((init_img.width + pixels * 2) / 64) * 64 + target_h = math.ceil((init_img.height + pixels * 2) / 64) * 64 + + border_x = (target_w - init_img.width)//2 + border_y = (target_h - init_img.height)//2 + + img = Image.new("RGB", (target_w, target_h)) + img.paste(init_img, (border_x, border_y)) + + mask = Image.new("L", (img.width, img.height), "white") + draw = ImageDraw.Draw(mask) + draw.rectangle((border_x + mask_blur * 2, border_y + mask_blur * 2, mask.width - border_x - mask_blur * 2, mask.height - border_y - mask_blur * 2), fill="black") + + latent_mask = Image.new("L", (img.width, img.height), "white") + latent_draw = ImageDraw.Draw(latent_mask) + latent_draw.rectangle((border_x + 1, border_y + 1, mask.width - border_x - 1, mask.height - border_y - 1), fill="black") + + processing.torch_gc() + + grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) + grid_latent_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) + + p.n_iter = 1 + p.batch_size = 1 + p.do_not_save_grid = True + p.do_not_save_samples = True + + work = [] + work_mask = [] + work_latent_mask = [] + work_results = [] + + for (_, _, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles): + for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask): + work.append(tiledata[2]) + work_mask.append(tiledata_mask[2]) + work_latent_mask.append(tiledata_latent_mask[2]) + + batch_count = len(work) + print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.") + + for i in range(batch_count): + p.init_images = [work[i]] + p.image_mask = work_mask[i] + p.latent_mask = work_latent_mask[i] + + state.job = f"Batch {i + 1} out of {batch_count}" + processed = process_images(p) + + if initial_seed is None: + initial_seed = processed.seed + initial_info = processed.info + + p.seed = processed.seed + 1 + work_results += processed.images + + image_index = 0 + for y, h, row in grid.tiles: + for tiledata in row: + tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height)) + image_index += 1 + + combined_image = images.combine_grid(grid) + + if opts.samples_save: + images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info) + + processed = Processed(p, [combined_image], initial_seed, initial_info) + + return processed +