make swinir actually useful

This commit is contained in:
C43H66N12O12S2 2022-09-20 16:36:20 +03:00 committed by AUTOMATIC1111
parent 7267b7d2d9
commit 948eff4b3c

View file

@ -12,7 +12,13 @@ import modules.images
from modules.shared import cmd_opts, opts, device
from modules.swinir_arch import SwinIR as net
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
def load_model(task = "realsr", large_model = True, model_path=next(os.listdir(cmd_opts.esrgan_models_path))):
def load_model(task = "realsr", large_model = True, model_path="C:/sd/ESRGANn/4x-large.pth", scale=4):
try:
modules.shared.sd_upscalers.append(UpscalerSwin("McSwinnySwin"))
except Exception:
print(f"Error loading ESRGAN model", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if not large_model:
# use 'nearest+conv' to avoid block artifacts
model = net(upscale=scale, in_chans=3, img_size=64, window_size=8,
@ -26,12 +32,16 @@ def load_model(task = "realsr", large_model = True, model_path=next(os.listdir(c
mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
pretrained_model = torch.load(model_path)
model.load_state_dict(pretrained_model, strict=True)
model.load_state_dict(pretrained_model["params_ema"], strict=True)
return model.half().to(device)
def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, window_size = 8, scale = 4):
img = cv2.imread(img, cv2.IMREAD_COLOR).astype(np.float16) / 255.
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device)
model = load_model()
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
@ -45,7 +55,7 @@ def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, w
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
return output
return Image.fromarray(output, 'RGB')
def inference(img, model, tile, tile_overlap, window_size, scale):
@ -72,3 +82,11 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
output = E.div_(W)
return output
class UpscalerSwin(modules.images.Upscaler):
def __init__(self, title):
self.name = title
def do_upscale(self, img):
img = upscale(img)
return img