become even stricter with pickles

no pickle shall pass
thank you again, RyotaK
This commit is contained in:
AUTOMATIC 2022-10-11 17:03:00 +03:00
parent a05c824384
commit 66b7d7584f

View file

@ -10,6 +10,7 @@ import torch
import numpy import numpy
import _codecs import _codecs
import zipfile import zipfile
import re
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
@ -54,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler):
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"]
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
def check_zip_filenames(filename, names):
for name in names:
if name in allowed_zip_names:
continue
if allowed_zip_names_re.match(name):
continue
raise Exception(f"bad file inside {filename}: {name}")
def check_pt(filename): def check_pt(filename):
try: try:
# new pytorch format is a zip file # new pytorch format is a zip file
with zipfile.ZipFile(filename) as z: with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
with z.open('archive/data.pkl') as file: with z.open('archive/data.pkl') as file:
unpickler = RestrictedUnpickler(file) unpickler = RestrictedUnpickler(file)
unpickler.load() unpickler.load()