codeformer support

This commit is contained in:
AUTOMATIC 2022-09-07 12:32:28 +03:00
parent 9cb3cc3a2f
commit 6a9b33c848
13 changed files with 919 additions and 32 deletions

View file

@ -0,0 +1,276 @@
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List
from modules.codeformer.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2*in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
connect_list=['32', '64', '128', '256'],
fix_modules=['quantize','generator']):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None:
for module in fix_modules:
for param in getattr(self, module).parameters():
param.requires_grad = False
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embd = dim_embd
self.dim_mlp = dim_embd*2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = {
'16': 512,
'32': 256,
'64': 256,
'128': 128,
'256': 128,
'512': 64,
}
# after second residual block for > 16, before attn layer for ==16
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
# after first residual block for > 16, before attn layer for ==16
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = x
# ################# Transformer ###################
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
# BCHW -> BC(HW) -> (HW)BC
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
query_emb = feat_emb
# Transformer encoder
for layer in self.ft_layers:
query_emb = layer(query_emb, query_pos=pos_emb)
# output logits
logits = self.idx_pred_layer(query_emb) # (hw)bn
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
if code_only: # for training stage II
# logits doesn't need softmax before cross_entropy loss
return logits, lq_feat
# ################# Quantization ###################
# if self.training:
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
# # b(hw)c -> bc(hw) -> bchw
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
# ------------
soft_one_hot = F.softmax(logits, dim=2)
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
# preserve gradients
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
if detach_16:
quant_feat = quant_feat.detach() # for training stage III
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat

View file

@ -0,0 +1,435 @@
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
return x*torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
2 * torch.matmul(z_flattened, self.embedding.weight.t())
mean_distance = torch.mean(d)
# find closest encodings
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
# [0-1], higher score, higher confidence
min_encoding_scores = torch.exp(-min_encoding_scores/10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, {
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"min_encoding_scores": min_encoding_scores,
"mean_distance": mean_distance
}
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1,1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {
"min_encoding_indices": min_encoding_indices
}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.k = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.v = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x+h_
class Encoder(nn.Module):
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,)+tuple(ch_mult)
blocks = []
# initial convultion
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
blocks = []
# initial conv
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if self.quantizer_type == "nearest":
self.beta = beta #0.25
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_ema' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]')
else:
raise ValueError(f'Wrong params!')
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_d' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else:
raise ValueError(f'Wrong params!')
def forward(self, x):
return self.main(x)

108
modules/codeformer_model.py Normal file
View file

@ -0,0 +1,108 @@
import os
import sys
import traceback
import torch
from modules import shared
from modules.paths import script_path
import modules.shared
import modules.face_restoration
from importlib import reload
# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
def setup_codeformer():
path = modules.paths.paths.get("CodeFormer", None)
if path is None:
return
# both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
#stored_sys_path = sys.path
#sys.path = [path] + sys.path
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from modules.shared import cmd_opts
net_class = CodeFormer
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
def name(self):
return "CodeFormer"
def __init__(self):
self.net = None
self.face_helper = None
def create_models(self):
if self.net is not None and self.face_helper is not None:
return self.net, self.face_helper
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(shared.device)
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint)
net.eval()
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=shared.device)
if not cmd_opts.unload_gfpgan:
self.net = net
self.face_helper = face_helper
return net, face_helper
def restore(self, np_image):
np_image = np_image[:, :, ::-1]
net, face_helper = self.create_models()
face_helper.clean_all()
face_helper.read_image(np_image)
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(shared.device)
try:
with torch.no_grad():
output = net(cropped_face_t, w=shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
restored_img = restored_img[:, :, ::-1]
return restored_img
global have_codeformer
have_codeformer = True
shared.face_restorers.append(FaceRestorerCodeFormer())
except Exception:
print("Error setting up CodeFormer:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
# sys.path = stored_sys_path

View file

@ -0,0 +1,19 @@
from modules import shared
class FaceRestoration:
def name(self):
return "None"
def restore(self, np_image):
return np_image
def restore_faces(np_image):
face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
if len(face_restorers) == 0:
return np_image
face_restorer = face_restorers[0]
return face_restorer.restore(np_image)

View file

@ -2,12 +2,15 @@ import os
import sys import sys
import traceback import traceback
from modules.paths import script_path from modules import shared
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.shared from modules.paths import script_path
import modules.face_restoration
def gfpgan_model_path(): def gfpgan_model_path():
from modules.shared import cmd_opts
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')] places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places] files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
found = [x for x in files if os.path.exists(x)] found = [x for x in files if os.path.exists(x)]
@ -62,6 +65,19 @@ def setup_gfpgan():
global gfpgan_constructor global gfpgan_constructor
gfpgan_constructor = GFPGANer gfpgan_constructor = GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
def name(self):
return "GFPGAN"
def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception: except Exception:
print("Error setting up GFPGAN:", file=sys.stderr) print("Error setting up GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)

View file

@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import modules.images as images import modules.images as images
import modules.scripts import modules.scripts
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args): def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
is_inpaint = mode == 1 is_inpaint = mode == 1
is_loopback = mode == 2 is_loopback = mode == 2
is_upscale = mode == 3 is_upscale = mode == 3
@ -36,7 +36,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
cfg_scale=cfg_scale, cfg_scale=cfg_scale,
width=width, width=width,
height=height, height=height,
use_GFPGAN=use_GFPGAN, restore_faces=restore_faces,
tiling=tiling, tiling=tiling,
init_images=[image], init_images=[image],
mask=mask, mask=mask,

View file

@ -5,7 +5,7 @@ import sys
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(0, script_path) sys.path.insert(0, script_path)
# use current directory as SD dir if it has related files, otherwise parent dir of script as stated in guide # search for directory of stable diffsuion in following palces
sd_path = None sd_path = None
possible_sd_paths = ['.', os.path.dirname(script_path), os.path.join(script_path, 'repositories/stable-diffusion')] possible_sd_paths = ['.', os.path.dirname(script_path), os.path.join(script_path, 'repositories/stable-diffusion')]
for possible_sd_path in possible_sd_paths: for possible_sd_path in possible_sd_paths:
@ -14,14 +14,19 @@ for possible_sd_path in possible_sd_paths:
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + possible_sd_paths assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + possible_sd_paths
# add parent directory to path; this is where Stable diffusion repo should be
path_dirs = [ path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'), (sd_path, 'ldm', 'Stable Diffusion'),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers') (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
] ]
paths = {}
for d, must_exist, what in path_dirs: for d, must_exist, what in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path): if not os.path.exists(must_exist_path):
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
else: else:
sys.path.append(os.path.join(script_path, d)) d = os.path.abspath(d)
sys.path.append(d)
paths[what] = d

View file

@ -14,7 +14,7 @@ from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.gfpgan_model as gfpgan import modules.face_restoration
import modules.images as images import modules.images as images
# some of those options should not be changed at all because they would break the model, so I removed them from options. # some of those options should not be changed at all because they would break the model, so I removed them from options.
@ -29,7 +29,7 @@ def torch_gc():
class StableDiffusionProcessing: class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
@ -44,7 +44,7 @@ class StableDiffusionProcessing:
self.cfg_scale: float = cfg_scale self.cfg_scale: float = cfg_scale
self.width: int = width self.width: int = width
self.height: int = height self.height: int = height
self.use_GFPGAN: bool = use_GFPGAN self.restore_faces: bool = restore_faces
self.tiling: bool = tiling self.tiling: bool = tiling
self.do_not_save_samples: bool = do_not_save_samples self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid self.do_not_save_grid: bool = do_not_save_grid
@ -136,7 +136,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
"Sampler": samplers[p.sampler_index].name, "Sampler": samplers[p.sampler_index].name,
"CFG scale": p.cfg_scale, "CFG scale": p.cfg_scale,
"Seed": all_seeds[position_in_batch + iteration * p.batch_size], "Seed": all_seeds[position_in_batch + iteration * p.batch_size],
"GFPGAN": ("GFPGAN" if p.use_GFPGAN else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
} }
@ -193,10 +193,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
if p.use_GFPGAN: if p.restore_faces:
torch_gc() torch_gc()
x_sample = gfpgan.gfpgan_fix_faces(x_sample) x_sample = modules.face_restoration.restore_faces(x_sample)
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)

View file

@ -1,11 +1,13 @@
import argparse import argparse
import json import json
import os import os
import gradio as gr import gradio as gr
import torch import torch
import modules.artists import modules.artists
from modules.paths import script_path, sd_path from modules.paths import script_path, sd_path
import modules.codeformer_model
config_filename = "config.json" config_filename = "config.json"
@ -40,6 +42,7 @@ device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
class State: class State:
interrupted = False interrupted = False
job = "" job = ""
@ -65,6 +68,7 @@ state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
face_restorers = []
def find_any_font(): def find_any_font():
fonts = ['/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf'] fonts = ['/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf']
@ -116,6 +120,8 @@ class Options:
"upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}), "upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}),
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
} }
def __init__(self): def __init__(self):

View file

@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args): def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@ -21,7 +21,7 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
cfg_scale=cfg_scale, cfg_scale=cfg_scale,
width=width, width=width,
height=height, height=height,
use_GFPGAN=use_GFPGAN, restore_faces=restore_faces,
tiling=tiling, tiling=tiling,
) )

View file

@ -206,7 +206,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Row(): with gr.Row():
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan) restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False) tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row(): with gr.Row():
@ -253,7 +253,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
negative_prompt, negative_prompt,
steps, steps,
sampler_index, sampler_index,
use_gfpgan, restore_faces,
tiling, tiling,
batch_count, batch_count,
batch_size, batch_size,
@ -335,7 +335,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False) inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False)
with gr.Row(): with gr.Row():
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan) restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False) tiling = gr.Checkbox(label='Tiling', value=False)
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False) sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
@ -425,7 +425,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
sampler_index, sampler_index,
mask_blur, mask_blur,
inpainting_fill, inpainting_fill,
use_gfpgan, restore_faces,
tiling, tiling,
switch_mode, switch_mode,
batch_count, batch_count,
@ -521,7 +521,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
with gr.Group(): with gr.Group():
gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=0, interactive=gfpgan.have_gfpgan) face_restoration_blending = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Faces restoration visibility", value=0, interactive=len(shared.face_restorers) > 1)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
@ -534,7 +534,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
fn=run_extras, fn=run_extras,
inputs=[ inputs=[
image, image,
gfpgan_strength, face_restoration_blending,
upscaling_resize, upscaling_resize,
extras_upscaler_1, extras_upscaler_1,
extras_upscaler_2, extras_upscaler_2,
@ -585,7 +585,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
t = type(info.default) t = type(info.default)
if info.component is not None: if info.component is not None:
item = info.component(label=info.label, value=fun, **(info.component_args or {})) args = info.component_args() if callable(info.component_args) else info.component_args
item = info.component(label=info.label, value=fun, **(args or {}))
elif t == str: elif t == str:
item = gr.Textbox(label=info.label, value=fun, lines=1) item = gr.Textbox(label=info.label, value=fun, lines=1)
elif t == int: elif t == int:

View file

@ -92,6 +92,7 @@ echo Installing requirements...
%PYTHON% -m pip install -r %REQS_FILE% --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt %PYTHON% -m pip install -r %REQS_FILE% --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :update_numpy if %ERRORLEVEL% == 0 goto :update_numpy
goto :show_stdout_stderr goto :show_stdout_stderr
:update_numpy :update_numpy
%PYTHON% -m pip install -U numpy --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt %PYTHON% -m pip install -U numpy --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
@ -105,12 +106,28 @@ if %ERRORLEVEL% == 0 goto :clone_transformers
goto :show_stdout_stderr goto :show_stdout_stderr
:clone_transformers :clone_transformers
if exist repositories\taming-transformers goto :check_model if exist repositories\taming-transformers goto :clone_codeformer
echo Cloning Taming Transforming repository... echo Cloning Taming Transforming repository...
%GIT% clone https://github.com/CompVis/taming-transformers.git repositories\taming-transformers >tmp/stdout.txt 2>tmp/stderr.txt %GIT% clone https://github.com/CompVis/taming-transformers.git repositories\taming-transformers >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :clone_codeformer
goto :show_stdout_stderr
:clone_codeformer
if exist repositories\CodeFormer goto :install_codeformer_reqs
echo Cloning CodeFormer repository...
%GIT% clone https://github.com/sczhou/CodeFormer.git repositories\CodeFormer >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :install_codeformer_reqs
goto :show_stdout_stderr
:install_codeformer_reqs
%PYTHON% -c "import lpips" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_model
echo Installing requirements for CodeFormer...
%PYTHON% -m pip install -r repositories\CodeFormer\requirements.txt --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_model if %ERRORLEVEL% == 0 goto :check_model
goto :show_stdout_stderr goto :show_stdout_stderr
:check_model :check_model
dir model.ckpt >tmp/stdout.txt 2>tmp/stderr.txt dir model.ckpt >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_gfpgan if %ERRORLEVEL% == 0 goto :check_gfpgan

View file

@ -19,7 +19,9 @@ from modules.ui import plaintext_to_html
import modules.scripts import modules.scripts
import modules.processing as processing import modules.processing as processing
import modules.sd_hijack import modules.sd_hijack
import modules.gfpgan_model as gfpgan import modules.codeformer_model
import modules.gfpgan_model
import modules.face_restoration
import modules.realesrgan_model as realesrgan import modules.realesrgan_model as realesrgan
import modules.esrgan_model as esrgan import modules.esrgan_model as esrgan
import modules.images as images import modules.images as images
@ -28,10 +30,12 @@ import modules.txt2img
import modules.img2img import modules.img2img
modules.codeformer_model.setup_codeformer()
modules.gfpgan_model.setup_gfpgan()
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
esrgan.load_models(cmd_opts.esrgan_models_path) esrgan.load_models(cmd_opts.esrgan_models_path)
realesrgan.setup_realesrgan() realesrgan.setup_realesrgan()
gfpgan.setup_gfpgan()
def load_model_from_config(config, ckpt, verbose=False): def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}") print(f"Loading model from {ckpt}")
@ -54,19 +58,19 @@ def load_model_from_config(config, ckpt, verbose=False):
cached_images = {} cached_images = {}
def run_extras(image, gfpgan_strength, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): def run_extras(image, face_restoration_blending, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
processing.torch_gc() processing.torch_gc()
image = image.convert("RGB") image = image.convert("RGB")
outpath = opts.outdir_samples or opts.outdir_extras_samples outpath = opts.outdir_samples or opts.outdir_extras_samples
if gfpgan.have_gfpgan is not None and gfpgan_strength > 0: if face_restoration_blending > 0:
restored_img = gfpgan.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) restored_img = modules.face_restoration.restore_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
if gfpgan_strength < 1.0: if face_restoration_blending < 1.0:
res = Image.blend(image, res, gfpgan_strength) res = Image.blend(image, res, face_restoration_blending)
image = res image = res