From 6e4de5b4422dfc0d45063b2c8c78b19f00321615 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 11:20:23 +0300 Subject: add load_with_extra function for modules to load checkpoints with extended whitelist --- modules/safe.py | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/safe.py b/modules/safe.py index 348a24fc..a9209e38 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -23,11 +23,18 @@ def encode(*args): class RestrictedUnpickler(pickle.Unpickler): + extra_handler = None + def persistent_load(self, saved_id): assert saved_id[0] == 'storage' return TypedStorage() def find_class(self, module, name): + if self.extra_handler is not None: + res = self.extra_handler(module, name) + if res is not None: + return res + if module == 'collections' and name == 'OrderedDict': return getattr(collections, name) if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: @@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler): return set # Forbid everything else. - raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") + raise Exception(f"global '{module}/{name}' is forbidden") allowed_zip_names = ["archive/data.pkl", "archive/version"] @@ -69,7 +76,7 @@ def check_zip_filenames(filename, names): raise Exception(f"bad file inside {filename}: {name}") -def check_pt(filename): +def check_pt(filename, extra_handler): try: # new pytorch format is a zip file @@ -78,6 +85,7 @@ def check_pt(filename): with z.open('archive/data.pkl') as file: unpickler = RestrictedUnpickler(file) + unpickler.extra_handler = extra_handler unpickler.load() except zipfile.BadZipfile: @@ -85,16 +93,42 @@ def check_pt(filename): # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle with open(filename, "rb") as file: unpickler = RestrictedUnpickler(file) + unpickler.extra_handler = extra_handler for i in range(5): unpickler.load() def load(filename, *args, **kwargs): + return load_with_extra(filename, *args, **kwargs) + + +def load_with_extra(filename, extra_handler=None, *args, **kwargs): + """ + this functon is intended to be used by extensions that want to load models with + some extra classes in them that the usual unpickler would find suspicious. + + Use the extra_handler argument to specify a function that takes module and field name as text, + and returns that field's value: + + ```python + def extra(module, name): + if module == 'collections' and name == 'OrderedDict': + return collections.OrderedDict + + return None + + safe.load_with_extra('model.pt', extra_handler=extra) + ``` + + The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is + definitely unsafe. + """ + from modules import shared try: if not shared.cmd_opts.disable_safe_unpickle: - check_pt(filename) + check_pt(filename, extra_handler) except pickle.UnpicklingError: print(f"Error verifying pickled file from {filename}:", file=sys.stderr) -- cgit v1.2.1