Re-implement universal model loading
This commit is contained in:
parent
bfb7f15d46
commit
740070ea9c
12 changed files with 449 additions and 134 deletions
|
@ -5,22 +5,28 @@ import traceback
|
|||
import cv2
|
||||
import torch
|
||||
|
||||
from modules import shared, devices
|
||||
from modules.paths import script_path
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import script_path, models_path
|
||||
import modules.shared
|
||||
import modules.face_restoration
|
||||
from importlib import reload
|
||||
|
||||
# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
|
||||
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||
# codeformer people made a choice to include modified basicsr library to their project, which makes
|
||||
# it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN.
|
||||
# I am making a choice to include some files from codeformer to work around this issue.
|
||||
|
||||
pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dir = "Codeformer"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
|
||||
have_codeformer = False
|
||||
codeformer = None
|
||||
|
||||
def setup_codeformer():
|
||||
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
path = modules.paths.paths.get("CodeFormer", None)
|
||||
if path is None:
|
||||
return
|
||||
|
@ -44,16 +50,22 @@ def setup_codeformer():
|
|||
def name(self):
|
||||
return "CodeFormer"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, dirname):
|
||||
self.net = None
|
||||
self.face_helper = None
|
||||
self.cmd_dir = dirname
|
||||
|
||||
def create_models(self):
|
||||
|
||||
if self.net is not None and self.face_helper is not None:
|
||||
self.net.to(devices.device_codeformer)
|
||||
return self.net, self.face_helper
|
||||
|
||||
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir)
|
||||
if len(model_paths) != 0:
|
||||
ckpt_path = model_paths[0]
|
||||
else:
|
||||
print("Unable to load codeformer model.")
|
||||
return None, None
|
||||
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
||||
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
|
||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||
|
@ -74,6 +86,9 @@ def setup_codeformer():
|
|||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
self.create_models()
|
||||
if self.net is None or self.face_helper is None:
|
||||
return np_image
|
||||
|
||||
self.face_helper.clean_all()
|
||||
self.face_helper.read_image(np_image)
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
|
@ -114,7 +129,7 @@ def setup_codeformer():
|
|||
have_codeformer = True
|
||||
|
||||
global codeformer
|
||||
codeformer = FaceRestorerCodeFormer()
|
||||
codeformer = FaceRestorerCodeFormer(dirname)
|
||||
shared.face_restorers.append(codeformer)
|
||||
|
||||
except Exception:
|
||||
|
|
|
@ -5,15 +5,35 @@ import traceback
|
|||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.esrgam_model_arch as arch
|
||||
from modules import shared
|
||||
from modules.shared import opts
|
||||
from modules.devices import has_mps
|
||||
import modules.images
|
||||
from modules import shared
|
||||
from modules import shared, modelloader
|
||||
from modules.devices import has_mps
|
||||
from modules.paths import models_path
|
||||
from modules.shared import opts
|
||||
|
||||
model_dir = "ESRGAN"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||
model_name = "ESRGAN_x4.pth"
|
||||
|
||||
|
||||
def load_model(filename):
|
||||
def load_model(path: str, name: str):
|
||||
global model_path
|
||||
global model_url
|
||||
global model_dir
|
||||
global model_name
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (model_dir, filename))
|
||||
return None
|
||||
print("Loading %s from %s" % (model_dir, filename))
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
@ -118,24 +138,30 @@ def esrgan_upscale(model, img):
|
|||
class UpscalerESRGAN(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.model = load_model(filename)
|
||||
self.filename = filename
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = self.model.to(shared.device)
|
||||
model = load_model(self.filename, self.name)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
|
||||
def load_models(dirname):
|
||||
for file in os.listdir(dirname):
|
||||
path = os.path.join(dirname, file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
|
||||
if extension != '.pt' and extension != '.pth':
|
||||
continue
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
global model_name
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
|
||||
if len(model_paths) == 0:
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
|
||||
for file in model_paths:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
|
||||
except Exception:
|
||||
print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
|
||||
print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
|
|
@ -36,6 +36,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
|||
|
||||
outputs = []
|
||||
for image, image_name in zip(imageArr, imageNameArr):
|
||||
if image is None:
|
||||
return outputs, "Please select an input image.", ''
|
||||
existing_pnginfo = image.info or {}
|
||||
|
||||
image = image.convert("RGB")
|
||||
|
|
|
@ -7,33 +7,20 @@ from modules import shared, devices
|
|||
from modules.shared import cmd_opts
|
||||
from modules.paths import script_path
|
||||
import modules.face_restoration
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
def gfpgan_model_path():
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
filemask = 'GFPGAN*.pth'
|
||||
|
||||
if cmd_opts.gfpgan_model is not None:
|
||||
return cmd_opts.gfpgan_model
|
||||
|
||||
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
|
||||
|
||||
filename = None
|
||||
for place in places:
|
||||
filename = next(iter(glob(os.path.join(place, filemask))), None)
|
||||
if filename is not None:
|
||||
break
|
||||
|
||||
return filename
|
||||
|
||||
model_dir = "GFPGAN"
|
||||
cmd_dir = None
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
|
||||
loaded_gfpgan_model = None
|
||||
|
||||
|
||||
def gfpgan():
|
||||
global loaded_gfpgan_model
|
||||
|
||||
global model_path
|
||||
if loaded_gfpgan_model is not None:
|
||||
loaded_gfpgan_model.gfpgan.to(shared.device)
|
||||
return loaded_gfpgan_model
|
||||
|
@ -41,7 +28,15 @@ def gfpgan():
|
|||
if gfpgan_constructor is None:
|
||||
return None
|
||||
|
||||
model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
models = modelloader.load_models(model_path, model_url, cmd_dir)
|
||||
if len(models) != 0:
|
||||
latest_file = max(models, key=os.path.getctime)
|
||||
model_file = latest_file
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2,
|
||||
bg_upsampler=None)
|
||||
model.gfpgan.to(shared.device)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
|
@ -50,7 +45,8 @@ def gfpgan():
|
|||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
model = gfpgan()
|
||||
|
||||
if model is None:
|
||||
return np_image
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
@ -64,19 +60,21 @@ def gfpgan_fix_faces(np_image):
|
|||
have_gfpgan = False
|
||||
gfpgan_constructor = None
|
||||
|
||||
def setup_gfpgan():
|
||||
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
try:
|
||||
gfpgan_model_path()
|
||||
|
||||
if os.path.exists(cmd_opts.gfpgan_dir):
|
||||
sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
|
||||
from gfpgan import GFPGANer
|
||||
|
||||
from modules.gfpgan_model_arch import GFPGANerr
|
||||
global cmd_dir
|
||||
global have_gfpgan
|
||||
have_gfpgan = True
|
||||
|
||||
global gfpgan_constructor
|
||||
gfpgan_constructor = GFPGANer
|
||||
|
||||
cmd_dir = dirname
|
||||
have_gfpgan = True
|
||||
gfpgan_constructor = GFPGANerr
|
||||
|
||||
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
||||
def name(self):
|
||||
|
|
150
modules/gfpgan_model_arch.py
Normal file
150
modules/gfpgan_model_arch.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original...
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
|
||||
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class GFPGANerr():
|
||||
"""Helper for restoration with GFPGAN.
|
||||
|
||||
It will detect and crop faces, and then resize the faces to 512x512.
|
||||
GFPGAN is used to restored the resized faces.
|
||||
The background is upsampled with the bg_upsampler.
|
||||
Finally, the faces will be pasted back to the upsample background image.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
|
||||
upscale (float): The upscale of the final output. Default: 2.
|
||||
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
|
||||
self.upscale = upscale
|
||||
self.bg_upsampler = bg_upsampler
|
||||
|
||||
# initialize model
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||
# initialize the GFP-GAN
|
||||
if arch == 'clean':
|
||||
self.gfpgan = GFPGANv1Clean(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=False,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'bilinear':
|
||||
self.gfpgan = GFPGANBilinear(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=False,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'original':
|
||||
self.gfpgan = GFPGANv1(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'RestoreFormer':
|
||||
from gfpgan.archs.restoreformer_arch import RestoreFormer
|
||||
self.gfpgan = RestoreFormer()
|
||||
# initialize face helper
|
||||
self.face_helper = FaceRestoreHelper(
|
||||
upscale,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=self.device,
|
||||
model_rootpath=model_dir)
|
||||
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir=model_dir, progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
else:
|
||||
keyname = 'params'
|
||||
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
|
||||
self.gfpgan.eval()
|
||||
self.gfpgan = self.gfpgan.to(self.device)
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if has_aligned: # the inputs are already aligned
|
||||
img = cv2.resize(img, (512, 512))
|
||||
self.face_helper.cropped_faces = [img]
|
||||
else:
|
||||
self.face_helper.read_image(img)
|
||||
# get face landmarks for each face
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
|
||||
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
|
||||
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
|
||||
# align and warp each face
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
# face restoration
|
||||
for cropped_face in self.face_helper.cropped_faces:
|
||||
# prepare data
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
|
||||
|
||||
try:
|
||||
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
|
||||
# convert to image
|
||||
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
||||
except RuntimeError as error:
|
||||
print(f'\tFailed inference for GFPGAN: {error}.')
|
||||
restored_face = cropped_face
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
self.face_helper.add_restored_face(restored_face)
|
||||
|
||||
if not has_aligned and paste_back:
|
||||
# upsample the background
|
||||
if self.bg_upsampler is not None:
|
||||
# Now only support RealESRGAN for upsampling background
|
||||
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
|
||||
else:
|
||||
bg_img = None
|
||||
|
||||
self.face_helper.get_inverse_affine(None)
|
||||
# paste each restored face to the input image
|
||||
restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
|
||||
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
|
||||
else:
|
||||
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
|
@ -3,11 +3,14 @@ import sys
|
|||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from modules import shared, images, modelloader, paths
|
||||
from modules.paths import models_path
|
||||
|
||||
import modules.images
|
||||
from modules import shared
|
||||
from modules.paths import script_path
|
||||
model_dir = "LDSR"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
cmd_path = None
|
||||
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||
|
||||
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
|
||||
|
||||
|
@ -25,28 +28,32 @@ class UpscalerLDSR(modules.images.Upscaler):
|
|||
return upscale_with_ldsr(img)
|
||||
|
||||
|
||||
def add_lsdr():
|
||||
modules.shared.sd_upscalers.append(UpscalerLDSR(100))
|
||||
def setup_model(dirname):
|
||||
global cmd_path
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
cmd_path = dirname
|
||||
shared.sd_upscalers.append(UpscalerLDSR(100))
|
||||
|
||||
|
||||
def setup_ldsr():
|
||||
path = modules.paths.paths.get("LDSR", None)
|
||||
def prepare_ldsr():
|
||||
path = paths.paths.get("LDSR", None)
|
||||
if path is None:
|
||||
return
|
||||
global have_ldsr
|
||||
global LDSR_obj
|
||||
try:
|
||||
from LDSR import LDSR
|
||||
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||
repo_path = 'latent-diffusion/experiments/pretrained_models/'
|
||||
model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path),
|
||||
progress=True, file_name="model.chkpt")
|
||||
yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path),
|
||||
progress=True, file_name="project.yaml")
|
||||
have_ldsr = True
|
||||
LDSR_obj = LDSR(model_path, yaml_path)
|
||||
|
||||
model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"])
|
||||
yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"])
|
||||
if len(model_files) != 0 and len(yaml_files) != 0:
|
||||
model_file = model_files[0]
|
||||
yaml_file = yaml_files[0]
|
||||
have_ldsr = True
|
||||
LDSR_obj = LDSR(model_file, yaml_file)
|
||||
else:
|
||||
return
|
||||
|
||||
except Exception:
|
||||
print("Error importing LDSR:", file=sys.stderr)
|
||||
|
@ -55,7 +62,7 @@ def setup_ldsr():
|
|||
|
||||
|
||||
def upscale_with_ldsr(image):
|
||||
setup_ldsr()
|
||||
prepare_ldsr()
|
||||
if not have_ldsr or LDSR_obj is None:
|
||||
return image
|
||||
|
||||
|
|
65
modules/modelloader.py
Normal file
65
modules/modelloader.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
import os
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None,
|
||||
ext_filter=None) -> list:
|
||||
"""
|
||||
A one-and done loader to try finding the desired models in specified directories.
|
||||
|
||||
@param dl_name: The file name to use for downloading a model. If not specified, it will be used from the URL.
|
||||
@param model_url: If specified, attempt to download model from the given URL.
|
||||
@param model_path: The location to store/find models in.
|
||||
@param command_path: A command-line argument to search for models in first.
|
||||
@param existing: An array of existing model paths.
|
||||
@param ext_filter: An optional list of filename extensions to filter by
|
||||
@return: A list of paths containing the desired model(s)
|
||||
"""
|
||||
if ext_filter is None:
|
||||
ext_filter = []
|
||||
if existing is None:
|
||||
existing = []
|
||||
try:
|
||||
places = []
|
||||
if command_path is not None and command_path != model_path:
|
||||
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
||||
if os.path.exists(pretrained_path):
|
||||
places.append(pretrained_path)
|
||||
elif os.path.exists(command_path):
|
||||
places.append(command_path)
|
||||
places.append(model_path)
|
||||
for place in places:
|
||||
if os.path.exists(place):
|
||||
for file in os.listdir(place):
|
||||
if os.path.isdir(file):
|
||||
continue
|
||||
if len(ext_filter) != 0:
|
||||
model_name, extension = os.path.splitext(file)
|
||||
if extension not in ext_filter:
|
||||
continue
|
||||
if file not in existing:
|
||||
path = os.path.join(place, file)
|
||||
existing.append(path)
|
||||
if model_url is not None:
|
||||
if dl_name is not None:
|
||||
model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True)
|
||||
else:
|
||||
model_file = load_file_from_url(url=model_url, model_dir=model_path, progress=True)
|
||||
|
||||
if os.path.exists(model_file) and os.path.isfile(model_file) and model_file not in existing:
|
||||
existing.append(model_file)
|
||||
except:
|
||||
pass
|
||||
return existing
|
||||
|
||||
|
||||
def friendly_name(file: str):
|
||||
if "http" in file:
|
||||
file = urlparse(file).path
|
||||
|
||||
file = os.path.basename(file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
model_name = model_name.replace("_", " ").title()
|
||||
return model_name
|
|
@ -3,9 +3,10 @@ import os
|
|||
import sys
|
||||
|
||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
models_path = os.path.join(script_path, "models")
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
# search for directory of stable diffsuion in following palces
|
||||
# search for directory of stable diffusion in following places
|
||||
sd_path = None
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
||||
for possible_sd_path in possible_sd_paths:
|
||||
|
|
|
@ -1,14 +1,20 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
import modules.images
|
||||
from modules.paths import models_path
|
||||
from modules.shared import cmd_opts, opts
|
||||
|
||||
model_dir = "RealESRGAN"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
cmd_dir = None
|
||||
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
|
||||
realesrgan_models = []
|
||||
have_realesrgan = False
|
||||
|
@ -17,7 +23,6 @@ have_realesrgan = False
|
|||
def get_realesrgan_models():
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
models = [
|
||||
RealesrganModelInfo(
|
||||
|
@ -59,7 +64,7 @@ def get_realesrgan_models():
|
|||
]
|
||||
return models
|
||||
except Exception as e:
|
||||
print("Error makeing Real-ESRGAN midels list:", file=sys.stderr)
|
||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
|
@ -73,10 +78,15 @@ class UpscalerRealESRGAN(modules.images.Upscaler):
|
|||
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
|
||||
|
||||
|
||||
def setup_realesrgan():
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
global realesrgan_models
|
||||
global have_realesrgan
|
||||
|
||||
if model_path != dirname:
|
||||
model_path = dirname
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
|
@ -104,6 +114,11 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
|
|||
info = realesrgan_models[RealESRGAN_model_index]
|
||||
|
||||
model = info.model()
|
||||
model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
|
||||
if not os.path.exists(model_file):
|
||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
||||
return image
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.netscale,
|
||||
model_path=info.location,
|
||||
|
|
|
@ -16,11 +16,11 @@ import modules.sd_models
|
|||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
model_path = os.path.join(script_path, 'models')
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=model_path, help="path to directory with stable diffusion checkpoints",)
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
|
@ -34,8 +34,12 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
|
|||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR'))
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model(s)", default=os.path.join(model_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model(s)", default=os.path.join(model_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN models", default=os.path.join(model_path, 'ESRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN models", default=os.path.join(model_path, 'RealESRGAN'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR models", default=os.path.join(model_path, 'SwinIR'))
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR models", default=os.path.join(model_path, 'LDSR'))
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
|
|
|
@ -1,21 +1,39 @@
|
|||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import cv2
|
||||
import os
|
||||
import contextlib
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import modules.images
|
||||
from modules.shared import cmd_opts, opts, device
|
||||
from modules.swinir_arch import SwinIR as net
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.images
|
||||
from modules import modelloader
|
||||
from modules.paths import models_path
|
||||
from modules.shared import cmd_opts, opts, device
|
||||
from modules.swinir_model_arch import SwinIR as net
|
||||
|
||||
model_dir = "SwinIR"
|
||||
model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||
model_name = "SwinIR x4"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
cmd_path = ""
|
||||
precision_scope = (
|
||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
)
|
||||
|
||||
|
||||
def load_model(filename, scale=4):
|
||||
def load_model(path, scale=4):
|
||||
global model_path
|
||||
global model_name
|
||||
if "http" in path:
|
||||
dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth")
|
||||
filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if filename is None or not os.path.exists(filename):
|
||||
return None
|
||||
model = net(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
|
@ -37,19 +55,29 @@ def load_model(filename, scale=4):
|
|||
return model
|
||||
|
||||
|
||||
def load_models(dirname):
|
||||
for file in os.listdir(dirname):
|
||||
path = os.path.join(dirname, file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
global model_name
|
||||
global cmd_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
cmd_path = dirname
|
||||
model_file = ""
|
||||
try:
|
||||
models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
|
||||
|
||||
if extension != ".pt" and extension != ".pth":
|
||||
continue
|
||||
if len(models) != 0:
|
||||
model_file = models[0]
|
||||
name = modelloader.friendly_name(model_file)
|
||||
else:
|
||||
# Add the "default" model if none are found.
|
||||
model_file = model_url
|
||||
name = model_name
|
||||
|
||||
try:
|
||||
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
|
||||
except Exception:
|
||||
print(f"Error loading SwinIR model: {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
|
||||
except Exception:
|
||||
print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
def upscale(
|
||||
|
@ -115,9 +143,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
|||
class UpscalerSwin(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.model = load_model(filename)
|
||||
self.filename = filename
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = self.model.to(device)
|
||||
model = load_model(self.filename)
|
||||
if model is None:
|
||||
return img
|
||||
model = model.to(device)
|
||||
img = upscale(img, model)
|
||||
return img
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
return img
|
47
webui.py
47
webui.py
|
@ -1,37 +1,34 @@
|
|||
import os
|
||||
import signal
|
||||
import threading
|
||||
|
||||
from modules.paths import script_path
|
||||
|
||||
import signal
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.ui
|
||||
import modules.codeformer_model as codeformer
|
||||
import modules.esrgan_model as esrgan
|
||||
import modules.extras
|
||||
import modules.face_restoration
|
||||
import modules.gfpgan_model as gfpgan
|
||||
import modules.img2img
|
||||
import modules.ldsr_model as ldsr
|
||||
import modules.lowvram
|
||||
import modules.realesrgan_model as realesrgan
|
||||
import modules.scripts
|
||||
import modules.sd_hijack
|
||||
import modules.codeformer_model
|
||||
import modules.gfpgan_model
|
||||
import modules.face_restoration
|
||||
import modules.realesrgan_model as realesrgan
|
||||
import modules.esrgan_model as esrgan
|
||||
import modules.ldsr_model as ldsr
|
||||
import modules.extras
|
||||
import modules.lowvram
|
||||
import modules.txt2img
|
||||
import modules.img2img
|
||||
import modules.swinir as swinir
|
||||
import modules.sd_models
|
||||
import modules.shared as shared
|
||||
import modules.swinir_model as swinir
|
||||
import modules.txt2img
|
||||
import modules.ui
|
||||
from modules.paths import script_path
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
|
||||
modules.codeformer_model.setup_codeformer()
|
||||
modules.gfpgan_model.setup_gfpgan()
|
||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||
|
||||
esrgan.load_models(cmd_opts.esrgan_models_path)
|
||||
swinir.load_models(cmd_opts.swinir_models_path)
|
||||
realesrgan.setup_realesrgan()
|
||||
ldsr.add_lsdr()
|
||||
esrgan.setup_model(cmd_opts.esrgan_models_path)
|
||||
swinir.setup_model(cmd_opts.swinir_models_path)
|
||||
realesrgan.setup_model(cmd_opts.realesrgan_models_path)
|
||||
ldsr.setup_model(cmd_opts.ldsr_models_path)
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue