From ad4de819c43997f2666b5bad95301f5c37f9018e Mon Sep 17 00:00:00 2001 From: victorca25 Date: Sun, 9 Oct 2022 13:02:12 +0200 Subject: [PATCH] update ESRGAN architecture and model to support all ESRGAN models in the DB, BSRGAN and real-ESRGAN models --- modules/bsrgan_model.py | 76 ------ modules/bsrgan_model_arch.py | 102 -------- modules/esrgam_model_arch.py | 80 ------ modules/esrgan_model.py | 178 +++++++++----- modules/esrgan_model_arch.py | 463 +++++++++++++++++++++++++++++++++++ 5 files changed, 585 insertions(+), 314 deletions(-) delete mode 100644 modules/bsrgan_model.py delete mode 100644 modules/bsrgan_model_arch.py delete mode 100644 modules/esrgam_model_arch.py create mode 100644 modules/esrgan_model_arch.py diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py deleted file mode 100644 index 737e1a76..00000000 --- a/modules/bsrgan_model.py +++ /dev/null @@ -1,76 +0,0 @@ -import os.path -import sys -import traceback - -import PIL.Image -import numpy as np -import torch -from basicsr.utils.download_util import load_file_from_url - -import modules.upscaler -from modules import devices, modelloader -from modules.bsrgan_model_arch import RRDBNet - - -class UpscalerBSRGAN(modules.upscaler.Upscaler): - def __init__(self, dirname): - self.name = "BSRGAN" - self.model_name = "BSRGAN 4x" - self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" - self.user_path = dirname - super().__init__() - model_paths = self.find_models(ext_filter=[".pt", ".pth"]) - scalers = [] - if len(model_paths) == 0: - scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4) - scalers.append(scaler_data) - for file in model_paths: - if "http" in file: - name = self.model_name - else: - name = modelloader.friendly_name(file) - try: - scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) - scalers.append(scaler_data) - except Exception: - print(f"Error loading BSRGAN model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - self.scalers = scalers - - def do_upscale(self, img: PIL.Image, selected_file): - torch.cuda.empty_cache() - model = self.load_model(selected_file) - if model is None: - return img - model.to(devices.device_bsrgan) - torch.cuda.empty_cache() - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_bsrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - torch.cuda.empty_cache() - return PIL.Image.fromarray(output, 'RGB') - - def load_model(self, path: str): - if "http" in path: - filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, - progress=True) - else: - filename = path - if not os.path.exists(filename) or filename is None: - print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr) - return None - model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for k, v in model.named_parameters(): - v.requires_grad = False - return model - diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py deleted file mode 100644 index cb4d1c13..00000000 --- a/modules/bsrgan_model_arch.py +++ /dev/null @@ -1,102 +0,0 @@ -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.init as init - - -def initialize_weights(net_l, scale=1): - if not isinstance(net_l, list): - net_l = [net_l] - for net in net_l: - for m in net.modules(): - if isinstance(m, nn.Conv2d): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') - m.weight.data *= scale # for residual block - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - init.kaiming_normal_(m.weight, a=0, mode='fan_in') - m.weight.data *= scale - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - init.constant_(m.weight, 1) - init.constant_(m.bias.data, 0.0) - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - self.sf = sf - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - if self.sf==4: - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - if self.sf==4: - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out \ No newline at end of file diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py deleted file mode 100644 index e413d36e..00000000 --- a/modules/esrgam_model_arch.py +++ /dev/null @@ -1,80 +0,0 @@ -# this file is taken from https://github.com/xinntao/ESRGAN - -import functools -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def make_layer(block, n_layers): - layers = [] - for _ in range(n_layers): - layers.append(block()) - return nn.Sequential(*layers) - - -class ResidualDenseBlock_5C(nn.Module): - def __init__(self, nf=64, gc=32, bias=True): - super(ResidualDenseBlock_5C, self).__init__() - # gc: growth channel, i.e. intermediate channels - self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) - self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) - self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) - self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) - self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - # initialization - # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) - - def forward(self, x): - x1 = self.lrelu(self.conv1(x)) - x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) - x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) - x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * 0.2 + x - - -class RRDB(nn.Module): - '''Residual in Residual Dense Block''' - - def __init__(self, nf, gc=32): - super(RRDB, self).__init__() - self.RDB1 = ResidualDenseBlock_5C(nf, gc) - self.RDB2 = ResidualDenseBlock_5C(nf, gc) - self.RDB3 = ResidualDenseBlock_5C(nf, gc) - - def forward(self, x): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - return out * 0.2 + x - - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, gc=32): - super(RRDBNet, self).__init__() - RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) - - self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, nb) - self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - #### upsampling - self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) - self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) - - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - - def forward(self, x): - fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) - fea = fea + trunk - - fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) - fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) - out = self.conv_last(self.lrelu(self.HRconv(fea))) - - return out diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 3970e6e4..a49e2258 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -5,68 +5,115 @@ import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url -import modules.esrgam_model_arch as arch +import modules.esrgan_model_arch as arch from modules import shared, modelloader, images, devices from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts -def fix_model_layers(crt_model, pretrained_net): - # this code is adapted from https://github.com/xinntao/ESRGAN - if 'conv_first.weight' in pretrained_net: - return pretrained_net - if 'model.0.weight' not in pretrained_net: - is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] - if is_realesrgan: - raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") - else: - raise Exception("The file is not a ESRGAN model.") +def mod2normal(state_dict): + # this code is copied from https://github.com/victorca25/iNNfer + if 'conv_first.weight' in state_dict: + crt_net = {} + items = [] + for k, v in state_dict.items(): + items.append(k) - crt_net = crt_model.state_dict() - load_net_clean = {} - for k, v in pretrained_net.items(): - if k.startswith('module.'): - load_net_clean[k[7:]] = v - else: - load_net_clean[k] = v - pretrained_net = load_net_clean + crt_net['model.0.weight'] = state_dict['conv_first.weight'] + crt_net['model.0.bias'] = state_dict['conv_first.bias'] - tbd = [] - for k, v in crt_net.items(): - tbd.append(k) + for k in items.copy(): + if 'RDB' in k: + ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[ori_k] = state_dict[k] + items.remove(k) - # directly copy - for k, v in crt_net.items(): - if k in pretrained_net and pretrained_net[k].size() == v.size(): - crt_net[k] = pretrained_net[k] - tbd.remove(k) + crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] + crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] + crt_net['model.3.weight'] = state_dict['upconv1.weight'] + crt_net['model.3.bias'] = state_dict['upconv1.bias'] + crt_net['model.6.weight'] = state_dict['upconv2.weight'] + crt_net['model.6.bias'] = state_dict['upconv2.bias'] + crt_net['model.8.weight'] = state_dict['HRconv.weight'] + crt_net['model.8.bias'] = state_dict['HRconv.bias'] + crt_net['model.10.weight'] = state_dict['conv_last.weight'] + crt_net['model.10.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net + return state_dict - crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] - crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] - for k in tbd.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[k] = pretrained_net[ori_k] - tbd.remove(k) +def resrgan2normal(state_dict, nb=23): + # this code is copied from https://github.com/victorca25/iNNfer + if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: + crt_net = {} + items = [] + for k, v in state_dict.items(): + items.append(k) - crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] - crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] - crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] - crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] - crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] - crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] - crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] - crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] - crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] - crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] + crt_net['model.0.weight'] = state_dict['conv_first.weight'] + crt_net['model.0.bias'] = state_dict['conv_first.bias'] + + for k in items.copy(): + if "rdb" in k: + ori_k = k.replace('body.', 'model.1.sub.') + ori_k = ori_k.replace('.rdb', '.RDB') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[ori_k] = state_dict[k] + items.remove(k) + + crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] + crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] + crt_net['model.3.weight'] = state_dict['conv_up1.weight'] + crt_net['model.3.bias'] = state_dict['conv_up1.bias'] + crt_net['model.6.weight'] = state_dict['conv_up2.weight'] + crt_net['model.6.bias'] = state_dict['conv_up2.bias'] + crt_net['model.8.weight'] = state_dict['conv_hr.weight'] + crt_net['model.8.bias'] = state_dict['conv_hr.bias'] + crt_net['model.10.weight'] = state_dict['conv_last.weight'] + crt_net['model.10.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net + return state_dict + + +def infer_params(state_dict): + # this code is copied from https://github.com/victorca25/iNNfer + scale2x = 0 + scalemin = 6 + n_uplayer = 0 + plus = False + + for block in list(state_dict): + parts = block.split(".") + n_parts = len(parts) + if n_parts == 5 and parts[2] == "sub": + nb = int(parts[3]) + elif n_parts == 3: + part_num = int(parts[1]) + if (part_num > scalemin + and parts[0] == "model" + and parts[2] == "weight"): + scale2x += 1 + if part_num > n_uplayer: + n_uplayer = part_num + out_nc = state_dict[block].shape[0] + if not plus and "conv1x1" in block: + plus = True + + nf = state_dict["model.0.weight"].shape[0] + in_nc = state_dict["model.0.weight"].shape[1] + out_nc = out_nc + scale = 2 ** scale2x + + return in_nc, out_nc, nf, nb, plus, scale - return crt_net class UpscalerESRGAN(Upscaler): def __init__(self, dirname): @@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler): print("Unable to load %s from %s" % (self.model_path, filename)) return None - pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) + state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - pretrained_net = fix_model_layers(crt_model, pretrained_net) - crt_model.load_state_dict(pretrained_net) - crt_model.eval() + if "params_ema" in state_dict: + state_dict = state_dict["params_ema"] + elif "params" in state_dict: + state_dict = state_dict["params"] + num_conv = 16 if "realesr-animevideov3" in filename else 32 + model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') + model.load_state_dict(state_dict) + model.eval() + return model - return crt_model + if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: + nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 + state_dict = resrgan2normal(state_dict, nb) + elif "conv_first.weight" in state_dict: + state_dict = mod2normal(state_dict) + elif "model.0.weight" not in state_dict: + raise Exception("The file is not a recognized ESRGAN model.") + + in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict) + + model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) + model.load_state_dict(state_dict) + model.eval() + + return model def upscale_without_tiling(model, img): img = np.array(img) img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py new file mode 100644 index 00000000..bc9ceb2a --- /dev/null +++ b/modules/esrgan_model_arch.py @@ -0,0 +1,463 @@ +# this file is adapted from https://github.com/victorca25/iNNfer + +import math +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F + + +#################### +# RRDBNet Generator +#################### + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, + act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', + finalact=None, gaussian_noise=False, plus=False): + super(RRDBNet, self).__init__() + n_upscale = int(math.log(upscale, 2)) + if upscale == 3: + n_upscale = 1 + + self.resrgan_scale = 0 + if in_nc % 16 == 0: + self.resrgan_scale = 1 + elif in_nc != 4 and in_nc % 4 == 0: + self.resrgan_scale = 2 + + fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) + rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, + gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] + LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype) + + if upsample_mode == 'upconv': + upsample_block = upconv_block + elif upsample_mode == 'pixelshuffle': + upsample_block = pixelshuffle_block + else: + raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) + if upscale == 3: + upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) + else: + upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)] + HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype) + HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) + + outact = act(finalact) if finalact else None + + self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)), + *upsampler, HR_conv0, HR_conv1, outact) + + def forward(self, x, outm=None): + if self.resrgan_scale == 1: + feat = pixel_unshuffle(x, scale=4) + elif self.resrgan_scale == 2: + feat = pixel_unshuffle(x, scale=2) + else: + feat = x + + return self.model(feat) + + +class RRDB(nn.Module): + """ + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + """ + + def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', + spectral_norm=False, gaussian_noise=False, plus=False): + super(RRDB, self).__init__() + # This is for backwards compatibility with existing models + if nr == 3: + self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) + else: + RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, + norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, + gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)] + self.RDBs = nn.Sequential(*RDB_list) + + def forward(self, x): + if hasattr(self, 'RDB1'): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + else: + out = self.RDBs(x) + return out * 0.2 + x + + +class ResidualDenseBlock_5C(nn.Module): + """ + Residual Dense Block + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + Modified options that can be used: + - "Partial Convolution based Padding" arXiv:1811.11718 + - "Spectral normalization" arXiv:1802.05957 + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. + {Rakotonirina} and A. {Rasoanaivo} + """ + + def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', + norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', + spectral_norm=False, gaussian_noise=False, plus=False): + super(ResidualDenseBlock_5C, self).__init__() + + self.noise = GaussianNoise() if gaussian_noise else None + self.conv1x1 = conv1x1(nf, gc) if plus else None + + self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + if mode == 'CNA': + last_act = None + else: + last_act = act_type + self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, + norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, + spectral_norm=spectral_norm) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + if self.conv1x1: + x2 = x2 + self.conv1x1(x) + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + if self.conv1x1: + x4 = x4 + x2 + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + if self.noise: + return self.noise(x5.mul(0.2) + x) + else: + return x5 * 0.2 + x + + +#################### +# ESRGANplus +#################### + +class GaussianNoise(nn.Module): + def __init__(self, sigma=0.1, is_relative_detach=False): + super().__init__() + self.sigma = sigma + self.is_relative_detach = is_relative_detach + self.noise = torch.tensor(0, dtype=torch.float) + + def forward(self, x): + if self.training and self.sigma != 0: + self.noise = self.noise.to(x.device) + scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x + sampled_noise = self.noise.repeat(*x.size()).normal_() * scale + x = x + sampled_noise + return x + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +#################### +# SRVGGNetCompact +#################### + +class SRVGGNetCompact(nn.Module): + """A compact VGG-style network structure for super-resolution. + This class is copied from https://github.com/xinntao/Real-ESRGAN + """ + + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): + super(SRVGGNetCompact, self).__init__() + self.num_in_ch = num_in_ch + self.num_out_ch = num_out_ch + self.num_feat = num_feat + self.num_conv = num_conv + self.upscale = upscale + self.act_type = act_type + + self.body = nn.ModuleList() + # the first conv + self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) + # the first activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the body structure + for _ in range(num_conv): + self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) + # activation + if act_type == 'relu': + activation = nn.ReLU(inplace=True) + elif act_type == 'prelu': + activation = nn.PReLU(num_parameters=num_feat) + elif act_type == 'leakyrelu': + activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.body.append(activation) + + # the last conv + self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) + # upsample + self.upsampler = nn.PixelShuffle(upscale) + + def forward(self, x): + out = x + for i in range(0, len(self.body)): + out = self.body[i](out) + + out = self.upsampler(out) + # add the nearest upsampled image, so that the network learns the residual + base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') + out += base + return out + + +#################### +# Upsampler +#################### + +class Upsample(nn.Module): + r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. + The input data is assumed to be of the form + `minibatch x channels x [optional depth] x [optional height] x width`. + """ + + def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): + super(Upsample, self).__init__() + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.size = size + self.align_corners = align_corners + + def forward(self, x): + return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) + + def extra_repr(self): + if self.scale_factor is not None: + info = 'scale_factor=' + str(self.scale_factor) + else: + info = 'size=' + str(self.size) + info += ', mode=' + self.mode + return info + + +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): + """ + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + """ + conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, + pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + + +def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): + """ Upconv layer """ + upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor + upsample = Upsample(scale_factor=upscale_factor, mode=mode) + conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, + pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) + return sequential(upsample, conv) + + + + + + + + +#################### +# Basic blocks +#################### + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + Args: + basic_block (nn.module): nn.module class for basic block. (block) + num_basic_block (int): number of blocks. (n_layers) + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): + """ activation helper """ + act_type = act_type.lower() + if act_type == 'relu': + layer = nn.ReLU(inplace) + elif act_type in ('leakyrelu', 'lrelu'): + layer = nn.LeakyReLU(neg_slope, inplace) + elif act_type == 'prelu': + layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) + elif act_type == 'tanh': # [-1, 1] range output + layer = nn.Tanh() + elif act_type == 'sigmoid': # [0, 1] range output + layer = nn.Sigmoid() + else: + raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) + return layer + + +class Identity(nn.Module): + def __init__(self, *kwargs): + super(Identity, self).__init__() + + def forward(self, x, *kwargs): + return x + + +def norm(norm_type, nc): + """ Return a normalization layer """ + norm_type = norm_type.lower() + if norm_type == 'batch': + layer = nn.BatchNorm2d(nc, affine=True) + elif norm_type == 'instance': + layer = nn.InstanceNorm2d(nc, affine=False) + elif norm_type == 'none': + def norm_layer(x): return Identity() + else: + raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) + return layer + + +def pad(pad_type, padding): + """ padding layer helper """ + pad_type = pad_type.lower() + if padding == 0: + return None + if pad_type == 'reflect': + layer = nn.ReflectionPad2d(padding) + elif pad_type == 'replicate': + layer = nn.ReplicationPad2d(padding) + elif pad_type == 'zero': + layer = nn.ZeroPad2d(padding) + else: + raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type)) + return layer + + +def get_valid_padding(kernel_size, dilation): + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + padding = (kernel_size - 1) // 2 + return padding + + +class ShortcutBlock(nn.Module): + """ Elementwise sum the output of a submodule to its input """ + def __init__(self, submodule): + super(ShortcutBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = x + self.sub(x) + return output + + def __repr__(self): + return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|') + + +def sequential(*args): + """ Flatten Sequential. It unwraps nn.Sequential. """ + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError('sequential does not support OrderedDict input.') + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, + pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', + spectral_norm=False): + """ Conv layer with padding, normalization, activation """ + assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode) + padding = get_valid_padding(kernel_size, dilation) + p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None + padding = padding if pad_type == 'zero' else 0 + + if convtype=='PartialConv2D': + c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + elif convtype=='DeformConv2D': + c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + elif convtype=='Conv3D': + c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + else: + c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, bias=bias, groups=groups) + + if spectral_norm: + c = nn.utils.spectral_norm(c) + + a = act(act_type) if act_type else None + if 'CNA' in mode: + n = norm(norm_type, out_nc) if norm_type else None + return sequential(p, c, n, a) + elif mode == 'NAC': + if norm_type is None and act_type is not None: + a = act(act_type, inplace=False) + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c)