add load_with_extra function for modules to load checkpoints with extended whitelist
This commit is contained in:
parent
9cd1a66648
commit
6e4de5b442
1 changed files with 37 additions and 3 deletions
|
@ -23,11 +23,18 @@ def encode(*args):
|
||||||
|
|
||||||
|
|
||||||
class RestrictedUnpickler(pickle.Unpickler):
|
class RestrictedUnpickler(pickle.Unpickler):
|
||||||
|
extra_handler = None
|
||||||
|
|
||||||
def persistent_load(self, saved_id):
|
def persistent_load(self, saved_id):
|
||||||
assert saved_id[0] == 'storage'
|
assert saved_id[0] == 'storage'
|
||||||
return TypedStorage()
|
return TypedStorage()
|
||||||
|
|
||||||
def find_class(self, module, name):
|
def find_class(self, module, name):
|
||||||
|
if self.extra_handler is not None:
|
||||||
|
res = self.extra_handler(module, name)
|
||||||
|
if res is not None:
|
||||||
|
return res
|
||||||
|
|
||||||
if module == 'collections' and name == 'OrderedDict':
|
if module == 'collections' and name == 'OrderedDict':
|
||||||
return getattr(collections, name)
|
return getattr(collections, name)
|
||||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||||
|
@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
||||||
return set
|
return set
|
||||||
|
|
||||||
# Forbid everything else.
|
# Forbid everything else.
|
||||||
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
raise Exception(f"global '{module}/{name}' is forbidden")
|
||||||
|
|
||||||
|
|
||||||
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
allowed_zip_names = ["archive/data.pkl", "archive/version"]
|
||||||
|
@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
|
||||||
raise Exception(f"bad file inside {filename}: {name}")
|
raise Exception(f"bad file inside {filename}: {name}")
|
||||||
|
|
||||||
|
|
||||||
def check_pt(filename):
|
def check_pt(filename, extra_handler):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# new pytorch format is a zip file
|
# new pytorch format is a zip file
|
||||||
|
@ -78,6 +85,7 @@ def check_pt(filename):
|
||||||
|
|
||||||
with z.open('archive/data.pkl') as file:
|
with z.open('archive/data.pkl') as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
unpickler = RestrictedUnpickler(file)
|
||||||
|
unpickler.extra_handler = extra_handler
|
||||||
unpickler.load()
|
unpickler.load()
|
||||||
|
|
||||||
except zipfile.BadZipfile:
|
except zipfile.BadZipfile:
|
||||||
|
@ -85,16 +93,42 @@ def check_pt(filename):
|
||||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
unpickler = RestrictedUnpickler(file)
|
||||||
|
unpickler.extra_handler = extra_handler
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
unpickler.load()
|
unpickler.load()
|
||||||
|
|
||||||
|
|
||||||
def load(filename, *args, **kwargs):
|
def load(filename, *args, **kwargs):
|
||||||
|
return load_with_extra(filename, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
this functon is intended to be used by extensions that want to load models with
|
||||||
|
some extra classes in them that the usual unpickler would find suspicious.
|
||||||
|
|
||||||
|
Use the extra_handler argument to specify a function that takes module and field name as text,
|
||||||
|
and returns that field's value:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def extra(module, name):
|
||||||
|
if module == 'collections' and name == 'OrderedDict':
|
||||||
|
return collections.OrderedDict
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
safe.load_with_extra('model.pt', extra_handler=extra)
|
||||||
|
```
|
||||||
|
|
||||||
|
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
|
||||||
|
definitely unsafe.
|
||||||
|
"""
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not shared.cmd_opts.disable_safe_unpickle:
|
if not shared.cmd_opts.disable_safe_unpickle:
|
||||||
check_pt(filename)
|
check_pt(filename, extra_handler)
|
||||||
|
|
||||||
except pickle.UnpicklingError:
|
except pickle.UnpicklingError:
|
||||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||||
|
|
Loading…
Reference in a new issue