[Python-checkins] cpython (merge 3.3 -> default): Issue #16408: Fix file descriptors not being closed in error conditions in the zipfile module.

antoine.pitrou python-checkins at python.org
Sat Nov 17 23:56:14 CET 2012


http://hg.python.org/cpython/rev/779e8f31dd30
changeset:   80484:779e8f31dd30
parent:      80481:5f2624db78bd
parent:      80483:27cb1a3d57c8
user:        Antoine Pitrou <solipsis at pitrou.net>
date:        Sat Nov 17 23:54:40 2012 +0100
summary:
  Issue #16408: Fix file descriptors not being closed in error conditions in the zipfile module.
Patch by Serhiy Storchaka.

files:
  Lib/zipfile.py |  425 +++++++++++++++++-------------------
  Misc/NEWS      |    3 +
  2 files changed, 207 insertions(+), 221 deletions(-)


diff --git a/Lib/zipfile.py b/Lib/zipfile.py
--- a/Lib/zipfile.py
+++ b/Lib/zipfile.py
@@ -906,30 +906,34 @@
             self.fp = file
             self.filename = getattr(file, 'name', None)
 
-        if key == 'r':
-            self._GetContents()
-        elif key == 'w':
-            # set the modified flag so central directory gets written
-            # even if no files are added to the archive
-            self._didModify = True
-        elif key == 'a':
-            try:
-                # See if file is a zip file
+        try:
+            if key == 'r':
                 self._RealGetContents()
-                # seek to start of directory and overwrite
-                self.fp.seek(self.start_dir, 0)
-            except BadZipFile:
-                # file is not a zip file, just append
-                self.fp.seek(0, 2)
-
+            elif key == 'w':
                 # set the modified flag so central directory gets written
                 # even if no files are added to the archive
                 self._didModify = True
-        else:
+            elif key == 'a':
+                try:
+                    # See if file is a zip file
+                    self._RealGetContents()
+                    # seek to start of directory and overwrite
+                    self.fp.seek(self.start_dir, 0)
+                except BadZipFile:
+                    # file is not a zip file, just append
+                    self.fp.seek(0, 2)
+
+                    # set the modified flag so central directory gets written
+                    # even if no files are added to the archive
+                    self._didModify = True
+            else:
+                raise RuntimeError('Mode must be "r", "w" or "a"')
+        except:
+            fp = self.fp
+            self.fp = None
             if not self._filePassed:
-                self.fp.close()
-                self.fp = None
-            raise RuntimeError('Mode must be "r", "w" or "a"')
+                fp.close()
+            raise
 
     def __enter__(self):
         return self
@@ -937,17 +941,6 @@
     def __exit__(self, type, value, traceback):
         self.close()
 
-    def _GetContents(self):
-        """Read the directory, making sure we close the file if the format
-        is bad."""
-        try:
-            self._RealGetContents()
-        except BadZipFile:
-            if not self._filePassed:
-                self.fp.close()
-                self.fp = None
-            raise
-
     def _RealGetContents(self):
         """Read in the table of contents for the ZIP file."""
         fp = self.fp
@@ -1049,9 +1042,9 @@
             try:
                 # Read by chunks, to avoid an OverflowError or a
                 # MemoryError with very large embedded files.
-                f = self.open(zinfo.filename, "r")
-                while f.read(chunk_size):     # Check CRC-32
-                    pass
+                with self.open(zinfo.filename, "r") as f:
+                    while f.read(chunk_size):     # Check CRC-32
+                        pass
             except BadZipFile:
                 return zinfo.filename
 
@@ -1113,84 +1106,78 @@
         else:
             zef_file = io.open(self.filename, 'rb')
 
-        # Make sure we have an info object
-        if isinstance(name, ZipInfo):
-            # 'name' is already an info object
-            zinfo = name
-        else:
-            # Get info object for name
-            try:
+        try:
+            # Make sure we have an info object
+            if isinstance(name, ZipInfo):
+                # 'name' is already an info object
+                zinfo = name
+            else:
+                # Get info object for name
                 zinfo = self.getinfo(name)
-            except KeyError:
-                if not self._filePassed:
-                    zef_file.close()
-                raise
-        zef_file.seek(zinfo.header_offset, 0)
+            zef_file.seek(zinfo.header_offset, 0)
 
-        # Skip the file header:
-        fheader = zef_file.read(sizeFileHeader)
-        if fheader[0:4] != stringFileHeader:
-            raise BadZipFile("Bad magic number for file header")
+            # Skip the file header:
+            fheader = zef_file.read(sizeFileHeader)
+            if fheader[0:4] != stringFileHeader:
+                raise BadZipFile("Bad magic number for file header")
 
-        fheader = struct.unpack(structFileHeader, fheader)
-        fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
-        if fheader[_FH_EXTRA_FIELD_LENGTH]:
-            zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
+            fheader = struct.unpack(structFileHeader, fheader)
+            fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
+            if fheader[_FH_EXTRA_FIELD_LENGTH]:
+                zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
 
-        if zinfo.flag_bits & 0x20:
-            # Zip 2.7: compressed patched data
-            raise NotImplementedError("compressed patched data (flag bit 5)")
+            if zinfo.flag_bits & 0x20:
+                # Zip 2.7: compressed patched data
+                raise NotImplementedError("compressed patched data (flag bit 5)")
 
-        if zinfo.flag_bits & 0x40:
-            # strong encryption
-            raise NotImplementedError("strong encryption (flag bit 6)")
+            if zinfo.flag_bits & 0x40:
+                # strong encryption
+                raise NotImplementedError("strong encryption (flag bit 6)")
 
-        if zinfo.flag_bits & 0x800:
-            # UTF-8 filename
-            fname_str = fname.decode("utf-8")
-        else:
-            fname_str = fname.decode("cp437")
+            if zinfo.flag_bits & 0x800:
+                # UTF-8 filename
+                fname_str = fname.decode("utf-8")
+            else:
+                fname_str = fname.decode("cp437")
 
-        if fname_str != zinfo.orig_filename:
+            if fname_str != zinfo.orig_filename:
+                raise BadZipFile(
+                    'File name in directory %r and header %r differ.'
+                    % (zinfo.orig_filename, fname))
+
+            # check for encrypted flag & handle password
+            is_encrypted = zinfo.flag_bits & 0x1
+            zd = None
+            if is_encrypted:
+                if not pwd:
+                    pwd = self.pwd
+                if not pwd:
+                    raise RuntimeError("File %s is encrypted, password "
+                                       "required for extraction" % name)
+
+                zd = _ZipDecrypter(pwd)
+                # The first 12 bytes in the cypher stream is an encryption header
+                #  used to strengthen the algorithm. The first 11 bytes are
+                #  completely random, while the 12th contains the MSB of the CRC,
+                #  or the MSB of the file time depending on the header type
+                #  and is used to check the correctness of the password.
+                header = zef_file.read(12)
+                h = list(map(zd, header[0:12]))
+                if zinfo.flag_bits & 0x8:
+                    # compare against the file type from extended local headers
+                    check_byte = (zinfo._raw_time >> 8) & 0xff
+                else:
+                    # compare against the CRC otherwise
+                    check_byte = (zinfo.CRC >> 24) & 0xff
+                if h[11] != check_byte:
+                    raise RuntimeError("Bad password for file", name)
+
+            return ZipExtFile(zef_file, mode, zinfo, zd,
+                              close_fileobj=not self._filePassed)
+        except:
             if not self._filePassed:
                 zef_file.close()
-            raise BadZipFile(
-                  'File name in directory %r and header %r differ.'
-                  % (zinfo.orig_filename, fname))
-
-        # check for encrypted flag & handle password
-        is_encrypted = zinfo.flag_bits & 0x1
-        zd = None
-        if is_encrypted:
-            if not pwd:
-                pwd = self.pwd
-            if not pwd:
-                if not self._filePassed:
-                    zef_file.close()
-                raise RuntimeError("File %s is encrypted, "
-                                   "password required for extraction" % name)
-
-            zd = _ZipDecrypter(pwd)
-            # The first 12 bytes in the cypher stream is an encryption header
-            #  used to strengthen the algorithm. The first 11 bytes are
-            #  completely random, while the 12th contains the MSB of the CRC,
-            #  or the MSB of the file time depending on the header type
-            #  and is used to check the correctness of the password.
-            header = zef_file.read(12)
-            h = list(map(zd, header[0:12]))
-            if zinfo.flag_bits & 0x8:
-                # compare against the file type from extended local headers
-                check_byte = (zinfo._raw_time >> 8) & 0xff
-            else:
-                # compare against the CRC otherwise
-                check_byte = (zinfo.CRC >> 24) & 0xff
-            if h[11] != check_byte:
-                if not self._filePassed:
-                    zef_file.close()
-                raise RuntimeError("Bad password for file", name)
-
-        return ZipExtFile(zef_file, mode, zinfo, zd,
-                          close_fileobj=not self._filePassed)
+            raise
 
     def extract(self, member, path=None, pwd=None):
         """Extract a member from the archive to the current working directory,
@@ -1247,11 +1234,9 @@
                 os.mkdir(targetpath)
             return targetpath
 
-        source = self.open(member, pwd=pwd)
-        target = open(targetpath, "wb")
-        shutil.copyfileobj(source, target)
-        source.close()
-        target.close()
+        with self.open(member, pwd=pwd) as source, \
+             open(targetpath, "wb") as target:
+            shutil.copyfileobj(source, target)
 
         return targetpath
 
@@ -1411,105 +1396,107 @@
         if self.fp is None:
             return
 
-        if self.mode in ("w", "a") and self._didModify: # write ending records
-            count = 0
-            pos1 = self.fp.tell()
-            for zinfo in self.filelist:         # write central directory
-                count = count + 1
-                dt = zinfo.date_time
-                dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2]
-                dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2)
-                extra = []
-                if zinfo.file_size > ZIP64_LIMIT \
-                        or zinfo.compress_size > ZIP64_LIMIT:
-                    extra.append(zinfo.file_size)
-                    extra.append(zinfo.compress_size)
-                    file_size = 0xffffffff
-                    compress_size = 0xffffffff
-                else:
-                    file_size = zinfo.file_size
-                    compress_size = zinfo.compress_size
+        try:
+            if self.mode in ("w", "a") and self._didModify: # write ending records
+                count = 0
+                pos1 = self.fp.tell()
+                for zinfo in self.filelist:         # write central directory
+                    count = count + 1
+                    dt = zinfo.date_time
+                    dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2]
+                    dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2)
+                    extra = []
+                    if zinfo.file_size > ZIP64_LIMIT \
+                            or zinfo.compress_size > ZIP64_LIMIT:
+                        extra.append(zinfo.file_size)
+                        extra.append(zinfo.compress_size)
+                        file_size = 0xffffffff
+                        compress_size = 0xffffffff
+                    else:
+                        file_size = zinfo.file_size
+                        compress_size = zinfo.compress_size
 
-                if zinfo.header_offset > ZIP64_LIMIT:
-                    extra.append(zinfo.header_offset)
-                    header_offset = 0xffffffff
-                else:
-                    header_offset = zinfo.header_offset
+                    if zinfo.header_offset > ZIP64_LIMIT:
+                        extra.append(zinfo.header_offset)
+                        header_offset = 0xffffffff
+                    else:
+                        header_offset = zinfo.header_offset
 
-                extra_data = zinfo.extra
-                min_version = 0
-                if extra:
-                    # Append a ZIP64 field to the extra's
-                    extra_data = struct.pack(
-                            '<HH' + 'Q'*len(extra),
-                            1, 8*len(extra), *extra) + extra_data
+                    extra_data = zinfo.extra
+                    min_version = 0
+                    if extra:
+                        # Append a ZIP64 field to the extra's
+                        extra_data = struct.pack(
+                                '<HH' + 'Q'*len(extra),
+                                1, 8*len(extra), *extra) + extra_data
 
-                    min_version = ZIP64_VERSION
+                        min_version = ZIP64_VERSION
 
-                if zinfo.compress_type == ZIP_BZIP2:
-                    min_version = max(BZIP2_VERSION, min_version)
-                elif zinfo.compress_type == ZIP_LZMA:
-                    min_version = max(LZMA_VERSION, min_version)
+                    if zinfo.compress_type == ZIP_BZIP2:
+                        min_version = max(BZIP2_VERSION, min_version)
+                    elif zinfo.compress_type == ZIP_LZMA:
+                        min_version = max(LZMA_VERSION, min_version)
 
-                extract_version = max(min_version, zinfo.extract_version)
-                create_version = max(min_version, zinfo.create_version)
-                try:
-                    filename, flag_bits = zinfo._encodeFilenameFlags()
-                    centdir = struct.pack(structCentralDir,
-                        stringCentralDir, create_version,
-                        zinfo.create_system, extract_version, zinfo.reserved,
-                        flag_bits, zinfo.compress_type, dostime, dosdate,
-                        zinfo.CRC, compress_size, file_size,
-                        len(filename), len(extra_data), len(zinfo.comment),
-                        0, zinfo.internal_attr, zinfo.external_attr,
-                        header_offset)
-                except DeprecationWarning:
-                    print((structCentralDir, stringCentralDir, create_version,
-                        zinfo.create_system, extract_version, zinfo.reserved,
-                        zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
-                        zinfo.CRC, compress_size, file_size,
-                        len(zinfo.filename), len(extra_data), len(zinfo.comment),
-                        0, zinfo.internal_attr, zinfo.external_attr,
-                        header_offset), file=sys.stderr)
-                    raise
-                self.fp.write(centdir)
-                self.fp.write(filename)
-                self.fp.write(extra_data)
-                self.fp.write(zinfo.comment)
+                    extract_version = max(min_version, zinfo.extract_version)
+                    create_version = max(min_version, zinfo.create_version)
+                    try:
+                        filename, flag_bits = zinfo._encodeFilenameFlags()
+                        centdir = struct.pack(structCentralDir,
+                            stringCentralDir, create_version,
+                            zinfo.create_system, extract_version, zinfo.reserved,
+                            flag_bits, zinfo.compress_type, dostime, dosdate,
+                            zinfo.CRC, compress_size, file_size,
+                            len(filename), len(extra_data), len(zinfo.comment),
+                            0, zinfo.internal_attr, zinfo.external_attr,
+                            header_offset)
+                    except DeprecationWarning:
+                        print((structCentralDir, stringCentralDir, create_version,
+                            zinfo.create_system, extract_version, zinfo.reserved,
+                            zinfo.flag_bits, zinfo.compress_type, dostime, dosdate,
+                            zinfo.CRC, compress_size, file_size,
+                            len(zinfo.filename), len(extra_data), len(zinfo.comment),
+                            0, zinfo.internal_attr, zinfo.external_attr,
+                            header_offset), file=sys.stderr)
+                        raise
+                    self.fp.write(centdir)
+                    self.fp.write(filename)
+                    self.fp.write(extra_data)
+                    self.fp.write(zinfo.comment)
 
-            pos2 = self.fp.tell()
-            # Write end-of-zip-archive record
-            centDirCount = count
-            centDirSize = pos2 - pos1
-            centDirOffset = pos1
-            if (centDirCount >= ZIP_FILECOUNT_LIMIT or
-                centDirOffset > ZIP64_LIMIT or
-                centDirSize > ZIP64_LIMIT):
-                # Need to write the ZIP64 end-of-archive records
-                zip64endrec = struct.pack(
-                        structEndArchive64, stringEndArchive64,
-                        44, 45, 45, 0, 0, centDirCount, centDirCount,
-                        centDirSize, centDirOffset)
-                self.fp.write(zip64endrec)
+                pos2 = self.fp.tell()
+                # Write end-of-zip-archive record
+                centDirCount = count
+                centDirSize = pos2 - pos1
+                centDirOffset = pos1
+                if (centDirCount >= ZIP_FILECOUNT_LIMIT or
+                    centDirOffset > ZIP64_LIMIT or
+                    centDirSize > ZIP64_LIMIT):
+                    # Need to write the ZIP64 end-of-archive records
+                    zip64endrec = struct.pack(
+                            structEndArchive64, stringEndArchive64,
+                            44, 45, 45, 0, 0, centDirCount, centDirCount,
+                            centDirSize, centDirOffset)
+                    self.fp.write(zip64endrec)
 
-                zip64locrec = struct.pack(
-                        structEndArchive64Locator,
-                        stringEndArchive64Locator, 0, pos2, 1)
-                self.fp.write(zip64locrec)
-                centDirCount = min(centDirCount, 0xFFFF)
-                centDirSize = min(centDirSize, 0xFFFFFFFF)
-                centDirOffset = min(centDirOffset, 0xFFFFFFFF)
+                    zip64locrec = struct.pack(
+                            structEndArchive64Locator,
+                            stringEndArchive64Locator, 0, pos2, 1)
+                    self.fp.write(zip64locrec)
+                    centDirCount = min(centDirCount, 0xFFFF)
+                    centDirSize = min(centDirSize, 0xFFFFFFFF)
+                    centDirOffset = min(centDirOffset, 0xFFFFFFFF)
 
-            endrec = struct.pack(structEndArchive, stringEndArchive,
-                                 0, 0, centDirCount, centDirCount,
-                                 centDirSize, centDirOffset, len(self._comment))
-            self.fp.write(endrec)
-            self.fp.write(self._comment)
-            self.fp.flush()
-
-        if not self._filePassed:
-            self.fp.close()
-        self.fp = None
+                endrec = struct.pack(structEndArchive, stringEndArchive,
+                                    0, 0, centDirCount, centDirCount,
+                                    centDirSize, centDirOffset, len(self._comment))
+                self.fp.write(endrec)
+                self.fp.write(self._comment)
+                self.fp.flush()
+        finally:
+            fp = self.fp
+            self.fp = None
+            if not self._filePassed:
+                fp.close()
 
 
 class PyZipFile(ZipFile):
@@ -1676,16 +1663,15 @@
         if len(args) != 2:
             print(USAGE)
             sys.exit(1)
-        zf = ZipFile(args[1], 'r')
-        zf.printdir()
-        zf.close()
+        with ZipFile(args[1], 'r') as zf:
+            zf.printdir()
 
     elif args[0] == '-t':
         if len(args) != 2:
             print(USAGE)
             sys.exit(1)
-        zf = ZipFile(args[1], 'r')
-        badfile = zf.testzip()
+        with ZipFile(args[1], 'r') as zf:
+            badfile = zf.testzip()
         if badfile:
             print("The following enclosed file is corrupted: {!r}".format(badfile))
         print("Done testing")
@@ -1695,20 +1681,19 @@
             print(USAGE)
             sys.exit(1)
 
-        zf = ZipFile(args[1], 'r')
-        out = args[2]
-        for path in zf.namelist():
-            if path.startswith('./'):
-                tgt = os.path.join(out, path[2:])
-            else:
-                tgt = os.path.join(out, path)
+        with ZipFile(args[1], 'r') as zf:
+            out = args[2]
+            for path in zf.namelist():
+                if path.startswith('./'):
+                    tgt = os.path.join(out, path[2:])
+                else:
+                    tgt = os.path.join(out, path)
 
-            tgtdir = os.path.dirname(tgt)
-            if not os.path.exists(tgtdir):
-                os.makedirs(tgtdir)
-            with open(tgt, 'wb') as fp:
-                fp.write(zf.read(path))
-        zf.close()
+                tgtdir = os.path.dirname(tgt)
+                if not os.path.exists(tgtdir):
+                    os.makedirs(tgtdir)
+                with open(tgt, 'wb') as fp:
+                    fp.write(zf.read(path))
 
     elif args[0] == '-c':
         if len(args) < 3:
@@ -1724,11 +1709,9 @@
                             os.path.join(path, nm), os.path.join(zippath, nm))
             # else: ignore
 
-        zf = ZipFile(args[1], 'w', allowZip64=True)
-        for src in args[2:]:
-            addToZip(zf, src, os.path.basename(src))
-
-        zf.close()
+        with ZipFile(args[1], 'w', allowZip64=True) as zf:
+            for src in args[2:]:
+                addToZip(zf, src, os.path.basename(src))
 
 if __name__ == "__main__":
     main()
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -130,6 +130,9 @@
 Library
 -------
 
+- Issue #16408: Fix file descriptors not being closed in error conditions
+  in the zipfile module.  Patch by Serhiy Storchaka.
+
 - Issue #14631: Add a new :class:`weakref.WeakMethod` to simulate weak
   references to bound methods.
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list