Implementation for SD upscale.

This commit is contained in:
AUTOMATIC 2022-08-27 16:13:33 +03:00
parent 9597b265ec
commit 4e0fdca2f4
3 changed files with 165 additions and 24 deletions

View file

@ -194,3 +194,14 @@ Using `()` in prompt decreases model's attention to enclosed words, and `[]` inc
multiple modifiers:
![](images/attention-3.jpg)
### SD upscale
Upscale image using RealESRGAN and then go through tiles of the result, improving them with img2img.
Original idea by: https://github.com/jquesnelle/txt2imghd. This is an independent implementation.
To use this feature, tick a checkbox in the img2img interface. Original
image will be upscaled to twice the original width and height, while width and height sliders
will specify the size of individual tiles. At the moment this method does not support batch size.
![](images/sd-upscale.jpg)

BIN
images/sd-upscale.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 712 KiB

178
webui.py
View file

@ -85,11 +85,6 @@ try:
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
realesrgan_models = [
RealesrganModelInfo(
name="Real-ESRGAN 2x plus",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
),
RealesrganModelInfo(
name="Real-ESRGAN 4x plus",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
@ -100,6 +95,11 @@ try:
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
),
RealesrganModelInfo(
name="Real-ESRGAN 2x plus",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
),
]
have_realesrgan = True
except:
@ -124,6 +124,7 @@ class Options:
"verify_input": (True, "Check input, and produce warning if it's too long"),
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
"prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
"sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
}
def __init__(self):
@ -289,6 +290,73 @@ def image_grid(imgs, batch_size, force_n_rows=None):
return grid
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
w = image.width
h = image.height
now = tile_w - overlap # non-overlap width
noh = tile_h - overlap
cols = math.ceil((w - overlap) / now)
rows = math.ceil((h - overlap) / noh)
grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows):
row_images = []
y = row * noh
if y + tile_h >= h:
y = h - tile_h
for col in range(cols):
x = col * now
if x+tile_w >= w:
x = w - tile_w
tile = image.crop((x, y, x + tile_w, y + tile_h))
row_images.append([x, tile_w, tile])
grid.tiles.append([y, tile_h, row_images])
return grid
def combine_grid(grid):
def make_mask_image(r):
r = r * 255 / grid.overlap
r = r.astype(np.uint8)
return Image.fromarray(r, 'L')
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
for y, h, row in grid.tiles:
combined_row = Image.new("RGB", (grid.image_w, h))
for x, w, tile in row:
if x == 0:
combined_row.paste(tile, (0, 0))
continue
combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
if y == 0:
combined_image.paste(combined_row, (0, 0))
continue
combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
return combined_image
def draw_prompt_matrix(im, width, height, all_prompts):
def wrap(text, d, font, line_length):
lines = ['']
@ -491,6 +559,7 @@ class StableDiffuionModelHijack:
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, embeddings):
super().__init__()
@ -740,8 +809,6 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
grid_count += 1
torch_gc()
return output_images, seed, infotext()
@ -847,7 +914,7 @@ txt2img_interface = gr.Interface(
)
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples"
sampler = samplers_for_img2img[sampler_index].constructor(model)
@ -894,7 +961,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=0,
sampler_index=sampler_index,
batch_size=1,
n_iter=1,
steps=ddim_steps,
@ -923,6 +990,59 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
output_images = history
seed = initial_seed
elif sd_upscale:
initial_seed = None
initial_info = None
img = upscale_with_realesrgan(init_img, RealESRGAN_upscaling=2, RealESRGAN_model_index=0)
torch_gc()
grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.")
for y, h, row in grid.tiles:
for tiledata in row:
init_img = tiledata[2]
output_images, seed, info = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=1, # since process_images can't work with multiple different images we have to do this for now
n_iter=1,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
do_not_save_grid=True,
extra_generation_params={"Denoising Strength": denoising_strength},
)
if initial_seed is None:
initial_seed = seed
initial_info = info
seed += 1
tiledata[2] = output_images[0]
combined_image = combine_grid(grid)
grid_count = len(os.listdir(outpath)) - 1
save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)
output_images = [combined_image]
seed = initial_seed
info = initial_info
else:
output_images, seed, info = process_images(
outpath=outpath,
@ -930,7 +1050,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=0,
sampler_index=sampler_index,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
@ -960,6 +1080,7 @@ img2img_interface = gr.Interface(
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
gr.Checkbox(label='Stable Diffusion upscale', value=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
@ -978,7 +1099,26 @@ img2img_interface = gr.Interface(
)
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
upsampler = RealESRGANer(
scale=info.netscale,
model_path=info.location,
model=model,
half=True
)
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
image = Image.fromarray(upsampled)
return image
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
torch_gc()
image = image.convert("RGB")
outpath = opts.outdir or "outputs/extras-samples"
@ -993,19 +1133,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
image = res
if have_realesrgan and RealESRGAN_upscaling != 1.0:
info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
upsampler = RealESRGANer(
scale=info.netscale,
model_path=info.location,
model=model,
half=True
)
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
image = Image.fromarray(upsampled)
image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
os.makedirs(outpath, exist_ok=True)
base_count = len(os.listdir(outpath))
@ -1058,7 +1186,9 @@ def create_setting_component(key):
if t == str:
item = gr.Textbox(label=label, value=fun, lines=1)
elif t == int:
if len(labelinfo) == 4:
if len(labelinfo) == 5:
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
elif len(labelinfo) == 4:
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
else:
item = gr.Number(label=label, value=fun)