Merge pull request #2067 from victorca25/esrgan_mod
update ESRGAN architecture and model to support all ESRGAN models
This commit is contained in:
commit
6bd6154a92
4 changed files with 556 additions and 285 deletions
|
@ -1,76 +0,0 @@
|
||||||
import os.path
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import PIL.Image
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
|
|
||||||
import modules.upscaler
|
|
||||||
from modules import devices, modelloader
|
|
||||||
from modules.bsrgan_model_arch import RRDBNet
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
|
||||||
def __init__(self, dirname):
|
|
||||||
self.name = "BSRGAN"
|
|
||||||
self.model_name = "BSRGAN 4x"
|
|
||||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
|
||||||
self.user_path = dirname
|
|
||||||
super().__init__()
|
|
||||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
|
||||||
scalers = []
|
|
||||||
if len(model_paths) == 0:
|
|
||||||
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
|
||||||
scalers.append(scaler_data)
|
|
||||||
for file in model_paths:
|
|
||||||
if "http" in file:
|
|
||||||
name = self.model_name
|
|
||||||
else:
|
|
||||||
name = modelloader.friendly_name(file)
|
|
||||||
try:
|
|
||||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
|
||||||
scalers.append(scaler_data)
|
|
||||||
except Exception:
|
|
||||||
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
self.scalers = scalers
|
|
||||||
|
|
||||||
def do_upscale(self, img: PIL.Image, selected_file):
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
model = self.load_model(selected_file)
|
|
||||||
if model is None:
|
|
||||||
return img
|
|
||||||
model.to(devices.device_bsrgan)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
img = np.array(img)
|
|
||||||
img = img[:, :, ::-1]
|
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
|
||||||
img = torch.from_numpy(img).float()
|
|
||||||
img = img.unsqueeze(0).to(devices.device_bsrgan)
|
|
||||||
with torch.no_grad():
|
|
||||||
output = model(img)
|
|
||||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output = 255. * np.moveaxis(output, 0, 2)
|
|
||||||
output = output.astype(np.uint8)
|
|
||||||
output = output[:, :, ::-1]
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return PIL.Image.fromarray(output, 'RGB')
|
|
||||||
|
|
||||||
def load_model(self, path: str):
|
|
||||||
if "http" in path:
|
|
||||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
|
||||||
progress=True)
|
|
||||||
else:
|
|
||||||
filename = path
|
|
||||||
if not os.path.exists(filename) or filename is None:
|
|
||||||
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
|
||||||
return None
|
|
||||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
|
||||||
model.load_state_dict(torch.load(filename), strict=True)
|
|
||||||
model.eval()
|
|
||||||
for k, v in model.named_parameters():
|
|
||||||
v.requires_grad = False
|
|
||||||
return model
|
|
||||||
|
|
|
@ -1,102 +0,0 @@
|
||||||
import functools
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.nn.init as init
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_weights(net_l, scale=1):
|
|
||||||
if not isinstance(net_l, list):
|
|
||||||
net_l = [net_l]
|
|
||||||
for net in net_l:
|
|
||||||
for m in net.modules():
|
|
||||||
if isinstance(m, nn.Conv2d):
|
|
||||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
|
||||||
m.weight.data *= scale # for residual block
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
elif isinstance(m, nn.Linear):
|
|
||||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
|
||||||
m.weight.data *= scale
|
|
||||||
if m.bias is not None:
|
|
||||||
m.bias.data.zero_()
|
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
|
||||||
init.constant_(m.weight, 1)
|
|
||||||
init.constant_(m.bias.data, 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
def make_layer(block, n_layers):
|
|
||||||
layers = []
|
|
||||||
for _ in range(n_layers):
|
|
||||||
layers.append(block())
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock_5C(nn.Module):
|
|
||||||
def __init__(self, nf=64, gc=32, bias=True):
|
|
||||||
super(ResidualDenseBlock_5C, self).__init__()
|
|
||||||
# gc: growth channel, i.e. intermediate channels
|
|
||||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
|
||||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
|
||||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
|
||||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
|
||||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
# initialization
|
|
||||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x1 = self.lrelu(self.conv1(x))
|
|
||||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
|
||||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
|
||||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
|
||||||
return x5 * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
|
||||||
'''Residual in Residual Dense Block'''
|
|
||||||
|
|
||||||
def __init__(self, nf, gc=32):
|
|
||||||
super(RRDB, self).__init__()
|
|
||||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
|
||||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
|
||||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = self.RDB1(x)
|
|
||||||
out = self.RDB2(out)
|
|
||||||
out = self.RDB3(out)
|
|
||||||
return out * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
class RRDBNet(nn.Module):
|
|
||||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
|
||||||
super(RRDBNet, self).__init__()
|
|
||||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
|
||||||
self.sf = sf
|
|
||||||
|
|
||||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
|
||||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
|
||||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
|
||||||
#### upsampling
|
|
||||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
|
||||||
if self.sf==4:
|
|
||||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
|
||||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
|
||||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
fea = self.conv_first(x)
|
|
||||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
|
||||||
fea = fea + trunk
|
|
||||||
|
|
||||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
|
||||||
if self.sf==4:
|
|
||||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
|
||||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
|
||||||
|
|
||||||
return out
|
|
|
@ -11,62 +11,109 @@ from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
def fix_model_layers(crt_model, pretrained_net):
|
|
||||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
|
||||||
if 'conv_first.weight' in pretrained_net:
|
|
||||||
return pretrained_net
|
|
||||||
|
|
||||||
if 'model.0.weight' not in pretrained_net:
|
def mod2normal(state_dict):
|
||||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
|
# this code is copied from https://github.com/victorca25/iNNfer
|
||||||
if is_realesrgan:
|
if 'conv_first.weight' in state_dict:
|
||||||
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
crt_net = {}
|
||||||
else:
|
items = []
|
||||||
raise Exception("The file is not a ESRGAN model.")
|
for k, v in state_dict.items():
|
||||||
|
items.append(k)
|
||||||
|
|
||||||
crt_net = crt_model.state_dict()
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
load_net_clean = {}
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
for k, v in pretrained_net.items():
|
|
||||||
if k.startswith('module.'):
|
|
||||||
load_net_clean[k[7:]] = v
|
|
||||||
else:
|
|
||||||
load_net_clean[k] = v
|
|
||||||
pretrained_net = load_net_clean
|
|
||||||
|
|
||||||
tbd = []
|
for k in items.copy():
|
||||||
for k, v in crt_net.items():
|
|
||||||
tbd.append(k)
|
|
||||||
|
|
||||||
# directly copy
|
|
||||||
for k, v in crt_net.items():
|
|
||||||
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
|
||||||
crt_net[k] = pretrained_net[k]
|
|
||||||
tbd.remove(k)
|
|
||||||
|
|
||||||
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
|
||||||
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
|
||||||
|
|
||||||
for k in tbd.copy():
|
|
||||||
if 'RDB' in k:
|
if 'RDB' in k:
|
||||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
||||||
if '.weight' in k:
|
if '.weight' in k:
|
||||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||||
elif '.bias' in k:
|
elif '.bias' in k:
|
||||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||||
crt_net[k] = pretrained_net[ori_k]
|
crt_net[ori_k] = state_dict[k]
|
||||||
tbd.remove(k)
|
items.remove(k)
|
||||||
|
|
||||||
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
|
||||||
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
|
||||||
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
crt_net['model.3.weight'] = state_dict['upconv1.weight']
|
||||||
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
crt_net['model.3.bias'] = state_dict['upconv1.bias']
|
||||||
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
crt_net['model.6.weight'] = state_dict['upconv2.weight']
|
||||||
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
crt_net['model.6.bias'] = state_dict['upconv2.bias']
|
||||||
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
crt_net['model.8.weight'] = state_dict['HRconv.weight']
|
||||||
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
crt_net['model.8.bias'] = state_dict['HRconv.bias']
|
||||||
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
crt_net['model.10.weight'] = state_dict['conv_last.weight']
|
||||||
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
crt_net['model.10.bias'] = state_dict['conv_last.bias']
|
||||||
|
state_dict = crt_net
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def resrgan2normal(state_dict, nb=23):
|
||||||
|
# this code is copied from https://github.com/victorca25/iNNfer
|
||||||
|
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||||
|
crt_net = {}
|
||||||
|
items = []
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
items.append(k)
|
||||||
|
|
||||||
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
|
|
||||||
|
for k in items.copy():
|
||||||
|
if "rdb" in k:
|
||||||
|
ori_k = k.replace('body.', 'model.1.sub.')
|
||||||
|
ori_k = ori_k.replace('.rdb', '.RDB')
|
||||||
|
if '.weight' in k:
|
||||||
|
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||||
|
elif '.bias' in k:
|
||||||
|
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||||
|
crt_net[ori_k] = state_dict[k]
|
||||||
|
items.remove(k)
|
||||||
|
|
||||||
|
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
|
||||||
|
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
|
||||||
|
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
|
||||||
|
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
|
||||||
|
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
|
||||||
|
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
|
||||||
|
crt_net['model.8.weight'] = state_dict['conv_hr.weight']
|
||||||
|
crt_net['model.8.bias'] = state_dict['conv_hr.bias']
|
||||||
|
crt_net['model.10.weight'] = state_dict['conv_last.weight']
|
||||||
|
crt_net['model.10.bias'] = state_dict['conv_last.bias']
|
||||||
|
state_dict = crt_net
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def infer_params(state_dict):
|
||||||
|
# this code is copied from https://github.com/victorca25/iNNfer
|
||||||
|
scale2x = 0
|
||||||
|
scalemin = 6
|
||||||
|
n_uplayer = 0
|
||||||
|
plus = False
|
||||||
|
|
||||||
|
for block in list(state_dict):
|
||||||
|
parts = block.split(".")
|
||||||
|
n_parts = len(parts)
|
||||||
|
if n_parts == 5 and parts[2] == "sub":
|
||||||
|
nb = int(parts[3])
|
||||||
|
elif n_parts == 3:
|
||||||
|
part_num = int(parts[1])
|
||||||
|
if (part_num > scalemin
|
||||||
|
and parts[0] == "model"
|
||||||
|
and parts[2] == "weight"):
|
||||||
|
scale2x += 1
|
||||||
|
if part_num > n_uplayer:
|
||||||
|
n_uplayer = part_num
|
||||||
|
out_nc = state_dict[block].shape[0]
|
||||||
|
if not plus and "conv1x1" in block:
|
||||||
|
plus = True
|
||||||
|
|
||||||
|
nf = state_dict["model.0.weight"].shape[0]
|
||||||
|
in_nc = state_dict["model.0.weight"].shape[1]
|
||||||
|
out_nc = out_nc
|
||||||
|
scale = 2 ** scale2x
|
||||||
|
|
||||||
|
return in_nc, out_nc, nf, nb, plus, scale
|
||||||
|
|
||||||
return crt_net
|
|
||||||
|
|
||||||
class UpscalerESRGAN(Upscaler):
|
class UpscalerESRGAN(Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
|
@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler):
|
||||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
|
||||||
|
|
||||||
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
if "params_ema" in state_dict:
|
||||||
crt_model.load_state_dict(pretrained_net)
|
state_dict = state_dict["params_ema"]
|
||||||
crt_model.eval()
|
elif "params" in state_dict:
|
||||||
|
state_dict = state_dict["params"]
|
||||||
|
num_conv = 16 if "realesr-animevideov3" in filename else 32
|
||||||
|
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
return crt_model
|
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
|
||||||
|
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
|
||||||
|
state_dict = resrgan2normal(state_dict, nb)
|
||||||
|
elif "conv_first.weight" in state_dict:
|
||||||
|
state_dict = mod2normal(state_dict)
|
||||||
|
elif "model.0.weight" not in state_dict:
|
||||||
|
raise Exception("The file is not a recognized ESRGAN model.")
|
||||||
|
|
||||||
|
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
|
||||||
|
|
||||||
|
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def upscale_without_tiling(model, img):
|
def upscale_without_tiling(model, img):
|
||||||
img = np.array(img)
|
img = np.array(img)
|
||||||
img = img[:, :, ::-1]
|
img = img[:, :, ::-1]
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||||
img = torch.from_numpy(img).float()
|
img = torch.from_numpy(img).float()
|
||||||
img = img.unsqueeze(0).to(devices.device_esrgan)
|
img = img.unsqueeze(0).to(devices.device_esrgan)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
|
@ -1,80 +1,463 @@
|
||||||
# this file is taken from https://github.com/xinntao/ESRGAN
|
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||||
|
|
||||||
|
import math
|
||||||
import functools
|
import functools
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
def make_layer(block, n_layers):
|
####################
|
||||||
layers = []
|
# RRDBNet Generator
|
||||||
for _ in range(n_layers):
|
####################
|
||||||
layers.append(block())
|
|
||||||
return nn.Sequential(*layers)
|
|
||||||
|
|
||||||
|
class RRDBNet(nn.Module):
|
||||||
|
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
||||||
|
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
||||||
|
finalact=None, gaussian_noise=False, plus=False):
|
||||||
|
super(RRDBNet, self).__init__()
|
||||||
|
n_upscale = int(math.log(upscale, 2))
|
||||||
|
if upscale == 3:
|
||||||
|
n_upscale = 1
|
||||||
|
|
||||||
class ResidualDenseBlock_5C(nn.Module):
|
self.resrgan_scale = 0
|
||||||
def __init__(self, nf=64, gc=32, bias=True):
|
if in_nc % 16 == 0:
|
||||||
super(ResidualDenseBlock_5C, self).__init__()
|
self.resrgan_scale = 1
|
||||||
# gc: growth channel, i.e. intermediate channels
|
elif in_nc != 4 and in_nc % 4 == 0:
|
||||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
self.resrgan_scale = 2
|
||||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
|
||||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
|
||||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
|
||||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
|
|
||||||
# initialization
|
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||||
|
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
||||||
|
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
||||||
|
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
||||||
|
|
||||||
def forward(self, x):
|
if upsample_mode == 'upconv':
|
||||||
x1 = self.lrelu(self.conv1(x))
|
upsample_block = upconv_block
|
||||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
elif upsample_mode == 'pixelshuffle':
|
||||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
upsample_block = pixelshuffle_block
|
||||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
else:
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
||||||
return x5 * 0.2 + x
|
if upscale == 3:
|
||||||
|
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||||
|
else:
|
||||||
|
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
||||||
|
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
||||||
|
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||||
|
|
||||||
|
outact = act(finalact) if finalact else None
|
||||||
|
|
||||||
|
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
||||||
|
*upsampler, HR_conv0, HR_conv1, outact)
|
||||||
|
|
||||||
|
def forward(self, x, outm=None):
|
||||||
|
if self.resrgan_scale == 1:
|
||||||
|
feat = pixel_unshuffle(x, scale=4)
|
||||||
|
elif self.resrgan_scale == 2:
|
||||||
|
feat = pixel_unshuffle(x, scale=2)
|
||||||
|
else:
|
||||||
|
feat = x
|
||||||
|
|
||||||
|
return self.model(feat)
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
class RRDB(nn.Module):
|
||||||
'''Residual in Residual Dense Block'''
|
"""
|
||||||
|
Residual in Residual Dense Block
|
||||||
|
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, nf, gc=32):
|
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||||
|
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||||
|
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||||
super(RRDB, self).__init__()
|
super(RRDB, self).__init__()
|
||||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
# This is for backwards compatibility with existing models
|
||||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
if nr == 3:
|
||||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||||
|
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||||
|
gaussian_noise=gaussian_noise, plus=plus)
|
||||||
|
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||||
|
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||||
|
gaussian_noise=gaussian_noise, plus=plus)
|
||||||
|
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||||
|
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||||
|
gaussian_noise=gaussian_noise, plus=plus)
|
||||||
|
else:
|
||||||
|
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||||
|
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||||
|
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
||||||
|
self.RDBs = nn.Sequential(*RDB_list)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
if hasattr(self, 'RDB1'):
|
||||||
out = self.RDB1(x)
|
out = self.RDB1(x)
|
||||||
out = self.RDB2(out)
|
out = self.RDB2(out)
|
||||||
out = self.RDB3(out)
|
out = self.RDB3(out)
|
||||||
|
else:
|
||||||
|
out = self.RDBs(x)
|
||||||
return out * 0.2 + x
|
return out * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
class RRDBNet(nn.Module):
|
class ResidualDenseBlock_5C(nn.Module):
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
"""
|
||||||
super(RRDBNet, self).__init__()
|
Residual Dense Block
|
||||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
||||||
|
Modified options that can be used:
|
||||||
|
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||||
|
- "Spectral normalization" arXiv:1802.05957
|
||||||
|
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||||
|
{Rakotonirina} and A. {Rasoanaivo}
|
||||||
|
"""
|
||||||
|
|
||||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||||
#### upsampling
|
super(ResidualDenseBlock_5C, self).__init__()
|
||||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
|
||||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
|
||||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
|
||||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.noise = GaussianNoise() if gaussian_noise else None
|
||||||
|
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
||||||
|
|
||||||
|
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||||
|
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||||
|
spectral_norm=spectral_norm)
|
||||||
|
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||||
|
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||||
|
spectral_norm=spectral_norm)
|
||||||
|
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||||
|
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||||
|
spectral_norm=spectral_norm)
|
||||||
|
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||||
|
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||||
|
spectral_norm=spectral_norm)
|
||||||
|
if mode == 'CNA':
|
||||||
|
last_act = None
|
||||||
|
else:
|
||||||
|
last_act = act_type
|
||||||
|
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
||||||
|
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
||||||
|
spectral_norm=spectral_norm)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
fea = self.conv_first(x)
|
x1 = self.conv1(x)
|
||||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||||
fea = fea + trunk
|
if self.conv1x1:
|
||||||
|
x2 = x2 + self.conv1x1(x)
|
||||||
|
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||||
|
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||||
|
if self.conv1x1:
|
||||||
|
x4 = x4 + x2
|
||||||
|
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||||
|
if self.noise:
|
||||||
|
return self.noise(x5.mul(0.2) + x)
|
||||||
|
else:
|
||||||
|
return x5 * 0.2 + x
|
||||||
|
|
||||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
|
||||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
|
||||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# ESRGANplus
|
||||||
|
####################
|
||||||
|
|
||||||
|
class GaussianNoise(nn.Module):
|
||||||
|
def __init__(self, sigma=0.1, is_relative_detach=False):
|
||||||
|
super().__init__()
|
||||||
|
self.sigma = sigma
|
||||||
|
self.is_relative_detach = is_relative_detach
|
||||||
|
self.noise = torch.tensor(0, dtype=torch.float)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.training and self.sigma != 0:
|
||||||
|
self.noise = self.noise.to(x.device)
|
||||||
|
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||||
|
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||||
|
x = x + sampled_noise
|
||||||
|
return x
|
||||||
|
|
||||||
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# SRVGGNetCompact
|
||||||
|
####################
|
||||||
|
|
||||||
|
class SRVGGNetCompact(nn.Module):
|
||||||
|
"""A compact VGG-style network structure for super-resolution.
|
||||||
|
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
||||||
|
super(SRVGGNetCompact, self).__init__()
|
||||||
|
self.num_in_ch = num_in_ch
|
||||||
|
self.num_out_ch = num_out_ch
|
||||||
|
self.num_feat = num_feat
|
||||||
|
self.num_conv = num_conv
|
||||||
|
self.upscale = upscale
|
||||||
|
self.act_type = act_type
|
||||||
|
|
||||||
|
self.body = nn.ModuleList()
|
||||||
|
# the first conv
|
||||||
|
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||||
|
# the first activation
|
||||||
|
if act_type == 'relu':
|
||||||
|
activation = nn.ReLU(inplace=True)
|
||||||
|
elif act_type == 'prelu':
|
||||||
|
activation = nn.PReLU(num_parameters=num_feat)
|
||||||
|
elif act_type == 'leakyrelu':
|
||||||
|
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
self.body.append(activation)
|
||||||
|
|
||||||
|
# the body structure
|
||||||
|
for _ in range(num_conv):
|
||||||
|
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||||
|
# activation
|
||||||
|
if act_type == 'relu':
|
||||||
|
activation = nn.ReLU(inplace=True)
|
||||||
|
elif act_type == 'prelu':
|
||||||
|
activation = nn.PReLU(num_parameters=num_feat)
|
||||||
|
elif act_type == 'leakyrelu':
|
||||||
|
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||||
|
self.body.append(activation)
|
||||||
|
|
||||||
|
# the last conv
|
||||||
|
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||||
|
# upsample
|
||||||
|
self.upsampler = nn.PixelShuffle(upscale)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = x
|
||||||
|
for i in range(0, len(self.body)):
|
||||||
|
out = self.body[i](out)
|
||||||
|
|
||||||
|
out = self.upsampler(out)
|
||||||
|
# add the nearest upsampled image, so that the network learns the residual
|
||||||
|
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||||
|
out += base
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Upsampler
|
||||||
|
####################
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||||
|
The input data is assumed to be of the form
|
||||||
|
`minibatch x channels x [optional depth] x [optional height] x width`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||||
|
super(Upsample, self).__init__()
|
||||||
|
if isinstance(scale_factor, tuple):
|
||||||
|
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||||
|
else:
|
||||||
|
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||||
|
self.mode = mode
|
||||||
|
self.size = size
|
||||||
|
self.align_corners = align_corners
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
if self.scale_factor is not None:
|
||||||
|
info = 'scale_factor=' + str(self.scale_factor)
|
||||||
|
else:
|
||||||
|
info = 'size=' + str(self.size)
|
||||||
|
info += ', mode=' + self.mode
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def pixel_unshuffle(x, scale):
|
||||||
|
""" Pixel unshuffle.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||||
|
scale (int): Downsample ratio.
|
||||||
|
Returns:
|
||||||
|
Tensor: the pixel unshuffled feature.
|
||||||
|
"""
|
||||||
|
b, c, hh, hw = x.size()
|
||||||
|
out_channel = c * (scale**2)
|
||||||
|
assert hh % scale == 0 and hw % scale == 0
|
||||||
|
h = hh // scale
|
||||||
|
w = hw // scale
|
||||||
|
x_view = x.view(b, c, h, scale, w, scale)
|
||||||
|
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||||
|
|
||||||
|
|
||||||
|
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||||
|
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||||
|
"""
|
||||||
|
Pixel shuffle layer
|
||||||
|
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||||
|
Neural Network, CVPR17)
|
||||||
|
"""
|
||||||
|
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||||
|
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||||
|
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||||
|
|
||||||
|
n = norm(norm_type, out_nc) if norm_type else None
|
||||||
|
a = act(act_type) if act_type else None
|
||||||
|
return sequential(conv, pixel_shuffle, n, a)
|
||||||
|
|
||||||
|
|
||||||
|
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||||
|
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
||||||
|
""" Upconv layer """
|
||||||
|
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
||||||
|
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
||||||
|
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
||||||
|
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
||||||
|
return sequential(upsample, conv)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Basic blocks
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
def make_layer(basic_block, num_basic_block, **kwarg):
|
||||||
|
"""Make layers by stacking the same blocks.
|
||||||
|
Args:
|
||||||
|
basic_block (nn.module): nn.module class for basic block. (block)
|
||||||
|
num_basic_block (int): number of blocks. (n_layers)
|
||||||
|
Returns:
|
||||||
|
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||||
|
"""
|
||||||
|
layers = []
|
||||||
|
for _ in range(num_basic_block):
|
||||||
|
layers.append(basic_block(**kwarg))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||||
|
""" activation helper """
|
||||||
|
act_type = act_type.lower()
|
||||||
|
if act_type == 'relu':
|
||||||
|
layer = nn.ReLU(inplace)
|
||||||
|
elif act_type in ('leakyrelu', 'lrelu'):
|
||||||
|
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||||
|
elif act_type == 'prelu':
|
||||||
|
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||||
|
elif act_type == 'tanh': # [-1, 1] range output
|
||||||
|
layer = nn.Tanh()
|
||||||
|
elif act_type == 'sigmoid': # [0, 1] range output
|
||||||
|
layer = nn.Sigmoid()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
||||||
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(nn.Module):
|
||||||
|
def __init__(self, *kwargs):
|
||||||
|
super(Identity, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x, *kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def norm(norm_type, nc):
|
||||||
|
""" Return a normalization layer """
|
||||||
|
norm_type = norm_type.lower()
|
||||||
|
if norm_type == 'batch':
|
||||||
|
layer = nn.BatchNorm2d(nc, affine=True)
|
||||||
|
elif norm_type == 'instance':
|
||||||
|
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||||
|
elif norm_type == 'none':
|
||||||
|
def norm_layer(x): return Identity()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
||||||
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
def pad(pad_type, padding):
|
||||||
|
""" padding layer helper """
|
||||||
|
pad_type = pad_type.lower()
|
||||||
|
if padding == 0:
|
||||||
|
return None
|
||||||
|
if pad_type == 'reflect':
|
||||||
|
layer = nn.ReflectionPad2d(padding)
|
||||||
|
elif pad_type == 'replicate':
|
||||||
|
layer = nn.ReplicationPad2d(padding)
|
||||||
|
elif pad_type == 'zero':
|
||||||
|
layer = nn.ZeroPad2d(padding)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
||||||
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
def get_valid_padding(kernel_size, dilation):
|
||||||
|
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
return padding
|
||||||
|
|
||||||
|
|
||||||
|
class ShortcutBlock(nn.Module):
|
||||||
|
""" Elementwise sum the output of a submodule to its input """
|
||||||
|
def __init__(self, submodule):
|
||||||
|
super(ShortcutBlock, self).__init__()
|
||||||
|
self.sub = submodule
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
output = x + self.sub(x)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
||||||
|
|
||||||
|
|
||||||
|
def sequential(*args):
|
||||||
|
""" Flatten Sequential. It unwraps nn.Sequential. """
|
||||||
|
if len(args) == 1:
|
||||||
|
if isinstance(args[0], OrderedDict):
|
||||||
|
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||||
|
return args[0] # No sequential is needed.
|
||||||
|
modules = []
|
||||||
|
for module in args:
|
||||||
|
if isinstance(module, nn.Sequential):
|
||||||
|
for submodule in module.children():
|
||||||
|
modules.append(submodule)
|
||||||
|
elif isinstance(module, nn.Module):
|
||||||
|
modules.append(module)
|
||||||
|
return nn.Sequential(*modules)
|
||||||
|
|
||||||
|
|
||||||
|
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
||||||
|
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||||
|
spectral_norm=False):
|
||||||
|
""" Conv layer with padding, normalization, activation """
|
||||||
|
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
||||||
|
padding = get_valid_padding(kernel_size, dilation)
|
||||||
|
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||||
|
padding = padding if pad_type == 'zero' else 0
|
||||||
|
|
||||||
|
if convtype=='PartialConv2D':
|
||||||
|
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
|
elif convtype=='DeformConv2D':
|
||||||
|
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
|
elif convtype=='Conv3D':
|
||||||
|
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
|
else:
|
||||||
|
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
|
|
||||||
|
if spectral_norm:
|
||||||
|
c = nn.utils.spectral_norm(c)
|
||||||
|
|
||||||
|
a = act(act_type) if act_type else None
|
||||||
|
if 'CNA' in mode:
|
||||||
|
n = norm(norm_type, out_nc) if norm_type else None
|
||||||
|
return sequential(p, c, n, a)
|
||||||
|
elif mode == 'NAC':
|
||||||
|
if norm_type is None and act_type is not None:
|
||||||
|
a = act(act_type, inplace=False)
|
||||||
|
n = norm(norm_type, in_nc) if norm_type else None
|
||||||
|
return sequential(n, a, p, c)
|
||||||
|
|
Loading…
Reference in a new issue