become even stricter with pickles
no pickle shall pass thank you again, RyotaK
This commit is contained in:
parent
a05c824384
commit
66b7d7584f
1 changed files with 17 additions and 0 deletions
|
@ -10,6 +10,7 @@ import torch
|
||||||
import numpy
|
import numpy
|
||||||
import _codecs
|
import _codecs
|
||||||
import zipfile
|
import zipfile
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||||
|
@ -54,11 +55,27 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||||
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
||||||
|
|
||||||
|
|
||||||
|
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
||||||
|
allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
|
||||||
|
|
||||||
|
|
||||||
|
def check_zip_filenames(filename, names):
|
||||||
|
for name in names:
|
||||||
|
if name in allowed_zip_names:
|
||||||
|
continue
|
||||||
|
if allowed_zip_names_re.match(name):
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise Exception(f"bad file inside {filename}: {name}")
|
||||||
|
|
||||||
|
|
||||||
def check_pt(filename):
|
def check_pt(filename):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# new pytorch format is a zip file
|
# new pytorch format is a zip file
|
||||||
with zipfile.ZipFile(filename) as z:
|
with zipfile.ZipFile(filename) as z:
|
||||||
|
check_zip_filenames(filename, z.namelist())
|
||||||
|
|
||||||
with z.open('archive/data.pkl') as file:
|
with z.open('archive/data.pkl') as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
unpickler = RestrictedUnpickler(file)
|
||||||
unpickler.load()
|
unpickler.load()
|
||||||
|
|
Loading…
Reference in a new issue