Cleanup existing directories, fixes
This commit is contained in:
parent
740070ea9c
commit
7d5c29b674
6 changed files with 48 additions and 14 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -4,6 +4,7 @@ __pycache__
|
||||||
/venv
|
/venv
|
||||||
/tmp
|
/tmp
|
||||||
/model.ckpt
|
/model.ckpt
|
||||||
|
/models/**/*
|
||||||
/models/*.ckpt
|
/models/*.ckpt
|
||||||
/GFPGANv1.3.pth
|
/GFPGANv1.3.pth
|
||||||
/gfpgan/weights/*.pth
|
/gfpgan/weights/*.pth
|
||||||
|
|
|
@ -5,14 +5,13 @@ import traceback
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import modules.face_restoration
|
||||||
|
import modules.shared
|
||||||
from modules import shared, devices, modelloader
|
from modules import shared, devices, modelloader
|
||||||
from modules.paths import script_path, models_path
|
from modules.paths import script_path, models_path
|
||||||
import modules.shared
|
|
||||||
import modules.face_restoration
|
|
||||||
from importlib import reload
|
|
||||||
|
|
||||||
# codeformer people made a choice to include modified basicsr library to their project, which makes
|
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||||
# it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN.
|
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||||
# I am making a choice to include some files from codeformer to work around this issue.
|
# I am making a choice to include some files from codeformer to work around this issue.
|
||||||
model_dir = "Codeformer"
|
model_dir = "Codeformer"
|
||||||
model_path = os.path.join(models_path, model_dir)
|
model_path = os.path.join(models_path, model_dir)
|
||||||
|
@ -31,11 +30,6 @@ def setup_model(dirname):
|
||||||
if path is None:
|
if path is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
# both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
|
|
||||||
#stored_sys_path = sys.path
|
|
||||||
#sys.path = [path] + sys.path
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torchvision.transforms.functional import normalize
|
from torchvision.transforms.functional import normalize
|
||||||
from modules.codeformer.codeformer_arch import CodeFormer
|
from modules.codeformer.codeformer_arch import CodeFormer
|
||||||
|
@ -67,7 +61,6 @@ def setup_model(dirname):
|
||||||
print("Unable to load codeformer model.")
|
print("Unable to load codeformer model.")
|
||||||
return None, None
|
return None, None
|
||||||
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
||||||
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
|
|
||||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||||
net.load_state_dict(checkpoint)
|
net.load_state_dict(checkpoint)
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
|
@ -18,7 +18,7 @@ from modules.shared import opts
|
||||||
model_dir = "ESRGAN"
|
model_dir = "ESRGAN"
|
||||||
model_path = os.path.join(models_path, model_dir)
|
model_path = os.path.join(models_path, model_dir)
|
||||||
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||||
model_name = "ESRGAN_x4.pth"
|
model_name = "ESRGAN_x4"
|
||||||
|
|
||||||
|
|
||||||
def load_model(path: str, name: str):
|
def load_model(path: str, name: str):
|
||||||
|
@ -27,7 +27,7 @@ def load_model(path: str, name: str):
|
||||||
global model_dir
|
global model_dir
|
||||||
global model_name
|
global model_name
|
||||||
if "http" in path:
|
if "http" in path:
|
||||||
filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True)
|
filename = load_file_from_url(url=model_url, model_dir=model_path, file_name="%s.pth" % model_name, progress=True)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(filename) or filename is None:
|
if not os.path.exists(filename) or filename is None:
|
||||||
|
|
|
@ -19,7 +19,7 @@ have_ldsr = False
|
||||||
LDSR_obj = None
|
LDSR_obj = None
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(modules.images.Upscaler):
|
class UpscalerLDSR(images.Upscaler):
|
||||||
def __init__(self, steps):
|
def __init__(self, steps):
|
||||||
self.steps = steps
|
self.steps = steps
|
||||||
self.name = "LDSR"
|
self.name = "LDSR"
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
from modules.paths import script_path, models_path
|
||||||
|
|
||||||
|
|
||||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None,
|
def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None,
|
||||||
ext_filter=None) -> list:
|
ext_filter=None) -> list:
|
||||||
|
@ -63,3 +66,38 @@ def friendly_name(file: str):
|
||||||
model_name, extension = os.path.splitext(file)
|
model_name, extension = os.path.splitext(file)
|
||||||
model_name = model_name.replace("_", " ").title()
|
model_name = model_name.replace("_", " ").title()
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_models():
|
||||||
|
root_path = script_path
|
||||||
|
src_path = os.path.join(root_path, "ESRGAN")
|
||||||
|
dest_path = os.path.join(models_path, "ESRGAN")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "gfpgan")
|
||||||
|
dest_path = os.path.join(models_path, "GFPGAN")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "SwinIR")
|
||||||
|
dest_path = os.path.join(models_path, "SwinIR")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
||||||
|
dest_path = os.path.join(models_path, "LDSR")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
|
||||||
|
|
||||||
|
def move_files(src_path: str, dest_path: str):
|
||||||
|
try:
|
||||||
|
if not os.path.exists(dest_path):
|
||||||
|
os.makedirs(dest_path)
|
||||||
|
if os.path.exists(src_path):
|
||||||
|
for file in os.listdir(src_path):
|
||||||
|
if os.path.isfile(file):
|
||||||
|
fullpath = os.path.join(src_path, file)
|
||||||
|
print("Moving file: %s to %s" % (fullpath, dest_path))
|
||||||
|
try:
|
||||||
|
shutil.move(fullpath, dest_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
print("Removing folder: %s" % src_path)
|
||||||
|
shutil.rmtree(src_path, True)
|
||||||
|
except:
|
||||||
|
pass
|
2
webui.py
2
webui.py
|
@ -18,9 +18,11 @@ import modules.shared as shared
|
||||||
import modules.swinir_model as swinir
|
import modules.swinir_model as swinir
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
import modules.ui
|
import modules.ui
|
||||||
|
from modules import modelloader
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
|
|
||||||
|
modelloader.cleanup_models()
|
||||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||||
|
|
Loading…
Reference in a new issue