add a helpful message when user puts RealESRGAN model into ESRGAN directory.
This commit is contained in:
parent
62ce77e245
commit
ad02b249f5
1 changed files with 8 additions and 5 deletions
|
@ -14,17 +14,20 @@ import modules.images
|
|||
|
||||
def load_model(filename):
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
if torch.has_mps:
|
||||
map_l = 'cpu'
|
||||
else:
|
||||
map_l = None
|
||||
pretrained_net = torch.load(filename, map_location=map_l)
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
||||
if 'conv_first.weight' in pretrained_net:
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
return crt_model
|
||||
|
||||
if 'model.0.weight' not in pretrained_net:
|
||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
|
||||
if is_realesrgan:
|
||||
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
||||
else:
|
||||
raise Exception("The file is not a ESRGAN model.")
|
||||
|
||||
crt_net = crt_model.state_dict()
|
||||
load_net_clean = {}
|
||||
for k, v in pretrained_net.items():
|
||||
|
|
Loading…
Reference in a new issue