From e46147786914484b422899ee7154ae1685d96ae5 Mon Sep 17 00:00:00 2001 From: SmirkingFace <116507648+smirkingface@users.noreply.github.com> Date: Fri, 2 Dec 2022 11:12:13 +0100 Subject: Fixed safe.py for pytorch 1.13 ckpt files --- modules/safe.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/safe.py b/modules/safe.py index a9209e38..10460ad0 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -62,14 +62,12 @@ class RestrictedUnpickler(pickle.Unpickler): raise Exception(f"global '{module}/{name}' is forbidden") -allowed_zip_names = ["archive/data.pkl", "archive/version"] -allowed_zip_names_re = re.compile(r"^archive/data/\d+$") - +# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/' +allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$") +data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") def check_zip_filenames(filename, names): for name in names: - if name in allowed_zip_names: - continue if allowed_zip_names_re.match(name): continue @@ -82,8 +80,14 @@ def check_pt(filename, extra_handler): # new pytorch format is a zip file with zipfile.ZipFile(filename) as z: check_zip_filenames(filename, z.namelist()) - - with z.open('archive/data.pkl') as file: + + # find filename of data.pkl in zip file: '/data.pkl' + data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] + if len(data_pkl_filenames) == 0: + raise Exception(f"data.pkl not found in {filename}") + if len(data_pkl_filenames) > 1: + raise Exception(f"Multiple data.pkl found in {filename}") + with z.open(data_pkl_filenames[0]) as file: unpickler = RestrictedUnpickler(file) unpickler.extra_handler = extra_handler unpickler.load() -- cgit v1.2.1