From d8ed699839f4a9d3c232a3ca90c81545814dc45c Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Tue, 20 Sep 2022 20:09:13 +0300 Subject: [PATCH] Update swinir.py --- modules/swinir.py | 117 +++++++++++++++++++++++++++++----------------- 1 file changed, 74 insertions(+), 43 deletions(-) diff --git a/modules/swinir.py b/modules/swinir.py index 6c7f0a2d..7e8fd5e3 100644 --- a/modules/swinir.py +++ b/modules/swinir.py @@ -1,63 +1,87 @@ import sys import traceback import cv2 -from collections import OrderedDict import os -import requests -from collections import namedtuple +import contextlib import numpy as np from PIL import Image import torch import modules.images from modules.shared import cmd_opts, opts, device from modules.swinir_arch import SwinIR as net -precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext -def load_model(task = "realsr", large_model = True, model_path="C:/sd/ESRGANn/4x-large.pth", scale=4): - try: - modules.shared.sd_upscalers.append(UpscalerSwin("McSwinnySwin")) - except Exception: - print(f"Error loading ESRGAN model", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - if not large_model: - # use 'nearest+conv' to avoid block artifacts - model = net(upscale=scale, in_chans=3, img_size=64, window_size=8, - img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv') - else: - # larger model size; use '3conv' to save parameters and memory; use ema for GAN training - model = net(upscale=scale, in_chans=3, img_size=64, window_size=8, - img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv') - - pretrained_model = torch.load(model_path) +precision_scope = ( + torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext +) + + +def load_model(filename, scale=4): + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + + pretrained_model = torch.load(filename) model.load_state_dict(pretrained_model["params_ema"], strict=True) + if not cmd_opts.no_half: + model = model.half() + return model - return model.half().to(device) - -def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, window_size = 8, scale = 4): + +def load_models(dirname): + for file in os.listdir(dirname): + path = os.path.join(dirname, file) + model_name, extension = os.path.splitext(file) + + if extension != ".pt" and extension != ".pth": + continue + + try: + modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name)) + except Exception: + print(f"Error loading SwinIR model: {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + +def upscale( + img, + model, + tile=opts.GAN_tile, + tile_overlap=opts.GAN_tile_overlap, + window_size=8, + scale=4, +): img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(device) - model = load_model() with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w_old + w_pad] + img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] + img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] output = inference(img, model, tile, tile_overlap, window_size, scale) - output = output[..., :h_old * scale, :w_old * scale] + output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() if output.ndim == 3: - output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR + output = np.transpose( + output[[2, 1, 0], :, :], (1, 2, 0) + ) # CHW-RGB to HCW-BGR output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, 'RGB') - - + return Image.fromarray(output, "RGB") + + def inference(img, model, tile, tile_overlap, window_size, scale): # test the image tile by tile b, c, h, w = img.size() @@ -66,27 +90,34 @@ def inference(img, model, tile, tile_overlap, window_size, scale): sf = scale stride = tile - tile_overlap - h_idx_list = list(range(0, h-tile, stride)) + [h-tile] - w_idx_list = list(range(0, w-tile, stride)) + [w-tile] - E = torch.zeros(b, c, h*sf, w*sf, dtype=torch.half, device=device).type_as(img) + h_idx_list = list(range(0, h - tile, stride)) + [h - tile] + w_idx_list = list(range(0, w - tile, stride)) + [w - tile] + E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) W = torch.zeros_like(E, dtype=torch.half, device=device) for h_idx in h_idx_list: for w_idx in w_idx_list: - in_patch = img[..., h_idx:h_idx+tile, w_idx:w_idx+tile] + in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) - E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch) - W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask) + E[ + ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf + ].add_(out_patch) + W[ + ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf + ].add_(out_patch_mask) output = E.div_(W) return output - + + class UpscalerSwin(modules.images.Upscaler): - def __init__(self, title): + def __init__(self, filename, title): self.name = title + self.model = load_model(filename) def do_upscale(self, img): - img = upscale(img) + model = self.model.to(device) + img = upscale(img, model) return img