aboutsummaryrefslogtreecommitdiff
path: root/modules/safe.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/safe.py')
-rw-r--r--modules/safe.py58
1 files changed, 48 insertions, 10 deletions
diff --git a/modules/safe.py b/modules/safe.py
index 348a24fc..10460ad0 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -23,11 +23,18 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler):
+ extra_handler = None
+
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
def find_class(self, module, name):
+ if self.extra_handler is not None:
+ res = self.extra_handler(module, name)
+ if res is not None:
+ return res
+
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
@@ -52,32 +59,37 @@ class RestrictedUnpickler(pickle.Unpickler):
return set
# Forbid everything else.
- raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
+ raise Exception(f"global '{module}/{name}' is forbidden")
-allowed_zip_names = ["archive/data.pkl", "archive/version"]
-allowed_zip_names_re = re.compile(r"^archive/data/\d+$")
-
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
for name in names:
- if name in allowed_zip_names:
- continue
if allowed_zip_names_re.match(name):
continue
raise Exception(f"bad file inside {filename}: {name}")
-def check_pt(filename):
+def check_pt(filename, extra_handler):
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
-
- with z.open('archive/data.pkl') as file:
+
+ # find filename of data.pkl in zip file: '<directory name>/data.pkl'
+ data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+ if len(data_pkl_filenames) == 0:
+ raise Exception(f"data.pkl not found in {filename}")
+ if len(data_pkl_filenames) > 1:
+ raise Exception(f"Multiple data.pkl found in {filename}")
+ with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
@@ -85,16 +97,42 @@ def check_pt(filename):
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
+ unpickler.extra_handler = extra_handler
for i in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
+ return load_with_extra(filename, *args, **kwargs)
+
+
+def load_with_extra(filename, extra_handler=None, *args, **kwargs):
+ """
+ this functon is intended to be used by extensions that want to load models with
+ some extra classes in them that the usual unpickler would find suspicious.
+
+ Use the extra_handler argument to specify a function that takes module and field name as text,
+ and returns that field's value:
+
+ ```python
+ def extra(module, name):
+ if module == 'collections' and name == 'OrderedDict':
+ return collections.OrderedDict
+
+ return None
+
+ safe.load_with_extra('model.pt', extra_handler=extra)
+ ```
+
+ The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
+ definitely unsafe.
+ """
+
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
- check_pt(filename)
+ check_pt(filename, extra_handler)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)