[Python-checkins] r84978 - in python/branches/py3k/Lib: tarfile.py test/test_tarfile.py

antoine.pitrou python-checkins at python.org
Thu Sep 23 20:36:47 CEST 2010


Author: antoine.pitrou
Date: Thu Sep 23 20:36:46 2010
New Revision: 84978

Log:
Try to fix test_tarfile issues on Windows buildbots by closing file
objects explicitly instead of letting them linger on.



Modified:
   python/branches/py3k/Lib/tarfile.py
   python/branches/py3k/Lib/test/test_tarfile.py

Modified: python/branches/py3k/Lib/tarfile.py
==============================================================================
--- python/branches/py3k/Lib/tarfile.py	(original)
+++ python/branches/py3k/Lib/tarfile.py	Thu Sep 23 20:36:46 2010
@@ -1764,14 +1764,19 @@
 
         if fileobj is None:
             fileobj = bltn_open(name, mode + "b")
+            extfileobj = False
+        else:
+            extfileobj = True
 
         try:
             t = cls.taropen(name, mode,
                 gzip.GzipFile(name, mode, compresslevel, fileobj),
                 **kwargs)
         except IOError:
+            if not extfileobj:
+                fileobj.close()
             raise ReadError("not a gzip file")
-        t._extfileobj = False
+        t._extfileobj = extfileobj
         return t
 
     @classmethod
@@ -1795,6 +1800,7 @@
         try:
             t = cls.taropen(name, mode, fileobj, **kwargs)
         except (IOError, EOFError):
+            fileobj.close()
             raise ReadError("not a bzip2 file")
         t._extfileobj = False
         return t

Modified: python/branches/py3k/Lib/test/test_tarfile.py
==============================================================================
--- python/branches/py3k/Lib/test/test_tarfile.py	(original)
+++ python/branches/py3k/Lib/test/test_tarfile.py	Thu Sep 23 20:36:46 2010
@@ -59,10 +59,10 @@
     def test_fileobj_readlines(self):
         self.tar.extract("ustar/regtype", TEMPDIR)
         tarinfo = self.tar.getmember("ustar/regtype")
-        fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "r")
+        with open(os.path.join(TEMPDIR, "ustar/regtype"), "r") as fobj1:
+            lines1 = fobj1.readlines()
         fobj2 = io.TextIOWrapper(self.tar.extractfile(tarinfo))
 
-        lines1 = fobj1.readlines()
         lines2 = fobj2.readlines()
         self.assertTrue(lines1 == lines2,
                 "fileobj.readlines() failed")
@@ -75,18 +75,17 @@
     def test_fileobj_iter(self):
         self.tar.extract("ustar/regtype", TEMPDIR)
         tarinfo = self.tar.getmember("ustar/regtype")
-        fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU")
+        with open(os.path.join(TEMPDIR, "ustar/regtype"), "rU") as fobj1:
+            lines1 = fobj1.readlines()
         fobj2 = self.tar.extractfile(tarinfo)
-        lines1 = fobj1.readlines()
         lines2 = list(io.TextIOWrapper(fobj2))
         self.assertTrue(lines1 == lines2,
                      "fileobj.__iter__() failed")
 
     def test_fileobj_seek(self):
         self.tar.extract("ustar/regtype", TEMPDIR)
-        fobj = open(os.path.join(TEMPDIR, "ustar/regtype"), "rb")
-        data = fobj.read()
-        fobj.close()
+        with open(os.path.join(TEMPDIR, "ustar/regtype"), "rb") as fobj:
+            data = fobj.read()
 
         tarinfo = self.tar.getmember("ustar/regtype")
         fobj = self.tar.extractfile(tarinfo)
@@ -161,19 +160,24 @@
         # This test checks if tarfile.open() is able to open an empty tar
         # archive successfully. Note that an empty tar archive is not the
         # same as an empty file!
-        tarfile.open(tmpname, self.mode.replace("r", "w")).close()
+        with tarfile.open(tmpname, self.mode.replace("r", "w")):
+            pass
         try:
             tar = tarfile.open(tmpname, self.mode)
             tar.getnames()
         except tarfile.ReadError:
             self.fail("tarfile.open() failed on empty archive")
-        self.assertListEqual(tar.getmembers(), [])
+        else:
+            self.assertListEqual(tar.getmembers(), [])
+        finally:
+            tar.close()
 
     def test_null_tarfile(self):
         # Test for issue6123: Allow opening empty archives.
         # This test guarantees that tarfile.open() does not treat an empty
         # file as an empty tar archive.
-        open(tmpname, "wb").close()
+        with open(tmpname, "wb"):
+            pass
         self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, self.mode)
         self.assertRaises(tarfile.ReadError, tarfile.open, tmpname)
 
@@ -189,33 +193,36 @@
         for char in (b'\0', b'a'):
             # Test if EOFHeaderError ('\0') and InvalidHeaderError ('a')
             # are ignored correctly.
-            fobj = _open(tmpname, "wb")
-            fobj.write(char * 1024)
-            fobj.write(tarfile.TarInfo("foo").tobuf())
-            fobj.close()
+            with _open(tmpname, "wb") as fobj:
+                fobj.write(char * 1024)
+                fobj.write(tarfile.TarInfo("foo").tobuf())
 
             tar = tarfile.open(tmpname, mode="r", ignore_zeros=True)
-            self.assertListEqual(tar.getnames(), ["foo"],
+            try:
+                self.assertListEqual(tar.getnames(), ["foo"],
                     "ignore_zeros=True should have skipped the %r-blocks" % char)
-            tar.close()
+            finally:
+                tar.close()
 
 
 class MiscReadTest(CommonReadTest):
 
     def test_no_name_argument(self):
-        fobj = open(self.tarname, "rb")
-        tar = tarfile.open(fileobj=fobj, mode=self.mode)
-        self.assertEqual(tar.name, os.path.abspath(fobj.name))
+        with open(self.tarname, "rb") as fobj:
+            tar = tarfile.open(fileobj=fobj, mode=self.mode)
+            self.assertEqual(tar.name, os.path.abspath(fobj.name))
 
     def test_no_name_attribute(self):
-        data = open(self.tarname, "rb").read()
+        with open(self.tarname, "rb") as fobj:
+            data = fobj.read()
         fobj = io.BytesIO(data)
         self.assertRaises(AttributeError, getattr, fobj, "name")
         tar = tarfile.open(fileobj=fobj, mode=self.mode)
         self.assertEqual(tar.name, None)
 
     def test_empty_name_attribute(self):
-        data = open(self.tarname, "rb").read()
+        with open(self.tarname, "rb") as fobj:
+            data = fobj.read()
         fobj = io.BytesIO(data)
         fobj.name = ""
         tar = tarfile.open(fileobj=fobj, mode=self.mode)
@@ -225,12 +232,14 @@
         # Skip the first member and store values from the second member
         # of the testtar.
         tar = tarfile.open(self.tarname, mode=self.mode)
-        tar.next()
-        t = tar.next()
-        name = t.name
-        offset = t.offset
-        data = tar.extractfile(t).read()
-        tar.close()
+        try:
+            tar.next()
+            t = tar.next()
+            name = t.name
+            offset = t.offset
+            data = tar.extractfile(t).read()
+        finally:
+            tar.close()
 
         # Open the testtar and seek to the offset of the second member.
         if self.mode.endswith(":gz"):
@@ -240,26 +249,30 @@
         else:
             _open = open
         fobj = _open(self.tarname, "rb")
-        fobj.seek(offset)
+        try:
+            fobj.seek(offset)
 
-        # Test if the tarfile starts with the second member.
-        tar = tar.open(self.tarname, mode="r:", fileobj=fobj)
-        t = tar.next()
-        self.assertEqual(t.name, name)
-        # Read to the end of fileobj and test if seeking back to the
-        # beginning works.
-        tar.getmembers()
-        self.assertEqual(tar.extractfile(t).read(), data,
-                "seek back did not work")
-        tar.close()
+            # Test if the tarfile starts with the second member.
+            tar = tar.open(self.tarname, mode="r:", fileobj=fobj)
+            t = tar.next()
+            self.assertEqual(t.name, name)
+            # Read to the end of fileobj and test if seeking back to the
+            # beginning works.
+            tar.getmembers()
+            self.assertEqual(tar.extractfile(t).read(), data,
+                    "seek back did not work")
+            tar.close()
+        finally:
+            fobj.close()
 
     def test_fail_comp(self):
         # For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file.
         if self.mode == "r:":
             return
         self.assertRaises(tarfile.ReadError, tarfile.open, tarname, self.mode)
-        fobj = open(tarname, "rb")
-        self.assertRaises(tarfile.ReadError, tarfile.open, fileobj=fobj, mode=self.mode)
+        with open(tarname, "rb") as fobj:
+            self.assertRaises(tarfile.ReadError, tarfile.open,
+                              fileobj=fobj, mode=self.mode)
 
     def test_v7_dirtype(self):
         # Test old style dirtype member (bug #1336623):
@@ -298,45 +311,51 @@
         # Test hardlink extraction (e.g. bug #857297).
         tar = tarfile.open(tarname, errorlevel=1, encoding="iso8859-1")
 
-        tar.extract("ustar/regtype", TEMPDIR)
         try:
-            tar.extract("ustar/lnktype", TEMPDIR)
-        except EnvironmentError as e:
-            if e.errno == errno.ENOENT:
-                self.fail("hardlink not extracted properly")
+            tar.extract("ustar/regtype", TEMPDIR)
+            try:
+                tar.extract("ustar/lnktype", TEMPDIR)
+            except EnvironmentError as e:
+                if e.errno == errno.ENOENT:
+                    self.fail("hardlink not extracted properly")
 
-        data = open(os.path.join(TEMPDIR, "ustar/lnktype"), "rb").read()
-        self.assertEqual(md5sum(data), md5_regtype)
+            data = open(os.path.join(TEMPDIR, "ustar/lnktype"), "rb").read()
+            self.assertEqual(md5sum(data), md5_regtype)
 
-        try:
-            tar.extract("ustar/symtype", TEMPDIR)
-        except EnvironmentError as e:
-            if e.errno == errno.ENOENT:
-                self.fail("symlink not extracted properly")
+            try:
+                tar.extract("ustar/symtype", TEMPDIR)
+            except EnvironmentError as e:
+                if e.errno == errno.ENOENT:
+                    self.fail("symlink not extracted properly")
 
-        data = open(os.path.join(TEMPDIR, "ustar/symtype"), "rb").read()
-        self.assertEqual(md5sum(data), md5_regtype)
+            data = open(os.path.join(TEMPDIR, "ustar/symtype"), "rb").read()
+            self.assertEqual(md5sum(data), md5_regtype)
+        finally:
+            tar.close()
 
     def test_extractall(self):
         # Test if extractall() correctly restores directory permissions
         # and times (see issue1735).
         tar = tarfile.open(tarname, encoding="iso8859-1")
-        directories = [t for t in tar if t.isdir()]
-        tar.extractall(TEMPDIR, directories)
-        for tarinfo in directories:
-            path = os.path.join(TEMPDIR, tarinfo.name)
-            if sys.platform != "win32":
-                # Win32 has no support for fine grained permissions.
-                self.assertEqual(tarinfo.mode & 0o777, os.stat(path).st_mode & 0o777)
-            self.assertEqual(tarinfo.mtime, os.path.getmtime(path))
-        tar.close()
+        try:
+            directories = [t for t in tar if t.isdir()]
+            tar.extractall(TEMPDIR, directories)
+            for tarinfo in directories:
+                path = os.path.join(TEMPDIR, tarinfo.name)
+                if sys.platform != "win32":
+                    # Win32 has no support for fine grained permissions.
+                    self.assertEqual(tarinfo.mode & 0o777, os.stat(path).st_mode & 0o777)
+                self.assertEqual(tarinfo.mtime, os.path.getmtime(path))
+        finally:
+            tar.close()
 
     def test_init_close_fobj(self):
         # Issue #7341: Close the internal file object in the TarFile
         # constructor in case of an error. For the test we rely on
         # the fact that opening an empty file raises a ReadError.
         empty = os.path.join(TEMPDIR, "empty")
-        open(empty, "wb").write(b"")
+        with open(empty, "wb") as fobj:
+            fobj.write(b"")
 
         try:
             tar = object.__new__(tarfile.TarFile)
@@ -347,7 +366,7 @@
             else:
                 self.fail("ReadError not raised")
         finally:
-            os.remove(empty)
+            support.unlink(empty)
 
 
 class StreamReadTest(CommonReadTest):
@@ -368,42 +387,47 @@
 
     def test_compare_members(self):
         tar1 = tarfile.open(tarname, encoding="iso8859-1")
-        tar2 = self.tar
-
-        while True:
-            t1 = tar1.next()
-            t2 = tar2.next()
-            if t1 is None:
-                break
-            self.assertTrue(t2 is not None, "stream.next() failed.")
-
-            if t2.islnk() or t2.issym():
-                self.assertRaises(tarfile.StreamError, tar2.extractfile, t2)
-                continue
-
-            v1 = tar1.extractfile(t1)
-            v2 = tar2.extractfile(t2)
-            if v1 is None:
-                continue
-            self.assertTrue(v2 is not None, "stream.extractfile() failed")
-            self.assertEqual(v1.read(), v2.read(), "stream extraction failed")
+        try:
+            tar2 = self.tar
 
-        tar1.close()
+            while True:
+                t1 = tar1.next()
+                t2 = tar2.next()
+                if t1 is None:
+                    break
+                self.assertTrue(t2 is not None, "stream.next() failed.")
+
+                if t2.islnk() or t2.issym():
+                    self.assertRaises(tarfile.StreamError, tar2.extractfile, t2)
+                    continue
+
+                v1 = tar1.extractfile(t1)
+                v2 = tar2.extractfile(t2)
+                if v1 is None:
+                    continue
+                self.assertTrue(v2 is not None, "stream.extractfile() failed")
+                self.assertEqual(v1.read(), v2.read(), "stream extraction failed")
+        finally:
+            tar1.close()
 
 
 class DetectReadTest(unittest.TestCase):
 
     def _testfunc_file(self, name, mode):
         try:
-            tarfile.open(name, mode)
+            tar = tarfile.open(name, mode)
         except tarfile.ReadError as e:
             self.fail()
+        else:
+            tar.close()
 
     def _testfunc_fileobj(self, name, mode):
         try:
-            tarfile.open(name, mode, fileobj=open(name, "rb"))
+            tar = tarfile.open(name, mode, fileobj=open(name, "rb"))
         except tarfile.ReadError as e:
             self.fail()
+        else:
+            tar.close()
 
     def _test_modes(self, testfunc):
         testfunc(tarname, "r")
@@ -579,33 +603,38 @@
 
     def test_pax_global_headers(self):
         tar = tarfile.open(tarname, encoding="iso8859-1")
-
-        tarinfo = tar.getmember("pax/regtype1")
-        self.assertEqual(tarinfo.uname, "foo")
-        self.assertEqual(tarinfo.gname, "bar")
-        self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
-
-        tarinfo = tar.getmember("pax/regtype2")
-        self.assertEqual(tarinfo.uname, "")
-        self.assertEqual(tarinfo.gname, "bar")
-        self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
-
-        tarinfo = tar.getmember("pax/regtype3")
-        self.assertEqual(tarinfo.uname, "tarfile")
-        self.assertEqual(tarinfo.gname, "tarfile")
-        self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
+        try:
+            tarinfo = tar.getmember("pax/regtype1")
+            self.assertEqual(tarinfo.uname, "foo")
+            self.assertEqual(tarinfo.gname, "bar")
+            self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
+
+            tarinfo = tar.getmember("pax/regtype2")
+            self.assertEqual(tarinfo.uname, "")
+            self.assertEqual(tarinfo.gname, "bar")
+            self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
+
+            tarinfo = tar.getmember("pax/regtype3")
+            self.assertEqual(tarinfo.uname, "tarfile")
+            self.assertEqual(tarinfo.gname, "tarfile")
+            self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
+        finally:
+            tar.close()
 
     def test_pax_number_fields(self):
         # All following number fields are read from the pax header.
         tar = tarfile.open(tarname, encoding="iso8859-1")
-        tarinfo = tar.getmember("pax/regtype4")
-        self.assertEqual(tarinfo.size, 7011)
-        self.assertEqual(tarinfo.uid, 123)
-        self.assertEqual(tarinfo.gid, 123)
-        self.assertEqual(tarinfo.mtime, 1041808783.0)
-        self.assertEqual(type(tarinfo.mtime), float)
-        self.assertEqual(float(tarinfo.pax_headers["atime"]), 1041808783.0)
-        self.assertEqual(float(tarinfo.pax_headers["ctime"]), 1041808783.0)
+        try:
+            tarinfo = tar.getmember("pax/regtype4")
+            self.assertEqual(tarinfo.size, 7011)
+            self.assertEqual(tarinfo.uid, 123)
+            self.assertEqual(tarinfo.gid, 123)
+            self.assertEqual(tarinfo.mtime, 1041808783.0)
+            self.assertEqual(type(tarinfo.mtime), float)
+            self.assertEqual(float(tarinfo.pax_headers["atime"]), 1041808783.0)
+            self.assertEqual(float(tarinfo.pax_headers["ctime"]), 1041808783.0)
+        finally:
+            tar.close()
 
 
 class WriteTestBase(unittest.TestCase):
@@ -631,52 +660,59 @@
         # a trailing '\0'.
         name = "0123456789" * 10
         tar = tarfile.open(tmpname, self.mode)
-        t = tarfile.TarInfo(name)
-        tar.addfile(t)
-        tar.close()
+        try:
+            t = tarfile.TarInfo(name)
+            tar.addfile(t)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname)
-        self.assertTrue(tar.getnames()[0] == name,
-                "failed to store 100 char filename")
-        tar.close()
+        try:
+            self.assertTrue(tar.getnames()[0] == name,
+                    "failed to store 100 char filename")
+        finally:
+            tar.close()
 
     def test_tar_size(self):
         # Test for bug #1013882.
         tar = tarfile.open(tmpname, self.mode)
-        path = os.path.join(TEMPDIR, "file")
-        fobj = open(path, "wb")
-        fobj.write(b"aaa")
-        fobj.close()
-        tar.add(path)
-        tar.close()
+        try:
+            path = os.path.join(TEMPDIR, "file")
+            with open(path, "wb") as fobj:
+                fobj.write(b"aaa")
+            tar.add(path)
+        finally:
+            tar.close()
         self.assertTrue(os.path.getsize(tmpname) > 0,
                 "tarfile is empty")
 
     # The test_*_size tests test for bug #1167128.
     def test_file_size(self):
         tar = tarfile.open(tmpname, self.mode)
+        try:
+            path = os.path.join(TEMPDIR, "file")
+            with open(path, "wb"):
+                pass
+            tarinfo = tar.gettarinfo(path)
+            self.assertEqual(tarinfo.size, 0)
 
-        path = os.path.join(TEMPDIR, "file")
-        fobj = open(path, "wb")
-        fobj.close()
-        tarinfo = tar.gettarinfo(path)
-        self.assertEqual(tarinfo.size, 0)
-
-        fobj = open(path, "wb")
-        fobj.write(b"aaa")
-        fobj.close()
-        tarinfo = tar.gettarinfo(path)
-        self.assertEqual(tarinfo.size, 3)
-
-        tar.close()
+            with open(path, "wb") as fobj:
+                fobj.write(b"aaa")
+            tarinfo = tar.gettarinfo(path)
+            self.assertEqual(tarinfo.size, 3)
+        finally:
+            tar.close()
 
     def test_directory_size(self):
         path = os.path.join(TEMPDIR, "directory")
         os.mkdir(path)
         try:
             tar = tarfile.open(tmpname, self.mode)
-            tarinfo = tar.gettarinfo(path)
-            self.assertEqual(tarinfo.size, 0)
+            try:
+                tarinfo = tar.gettarinfo(path)
+                self.assertEqual(tarinfo.size, 0)
+            finally:
+                tar.close()
         finally:
             os.rmdir(path)
 
@@ -684,16 +720,18 @@
         if hasattr(os, "link"):
             link = os.path.join(TEMPDIR, "link")
             target = os.path.join(TEMPDIR, "link_target")
-            fobj = open(target, "wb")
-            fobj.write(b"aaa")
-            fobj.close()
+            with open(target, "wb") as fobj:
+                fobj.write(b"aaa")
             os.link(target, link)
             try:
                 tar = tarfile.open(tmpname, self.mode)
-                # Record the link target in the inodes list.
-                tar.gettarinfo(target)
-                tarinfo = tar.gettarinfo(link)
-                self.assertEqual(tarinfo.size, 0)
+                try:
+                    # Record the link target in the inodes list.
+                    tar.gettarinfo(target)
+                    tarinfo = tar.gettarinfo(link)
+                    self.assertEqual(tarinfo.size, 0)
+                finally:
+                    tar.close()
             finally:
                 os.remove(target)
                 os.remove(link)
@@ -704,26 +742,30 @@
         os.symlink("link_target", path)
         try:
             tar = tarfile.open(tmpname, self.mode)
-            tarinfo = tar.gettarinfo(path)
-            self.assertEqual(tarinfo.size, 0)
+            try:
+                tarinfo = tar.gettarinfo(path)
+                self.assertEqual(tarinfo.size, 0)
+            finally:
+                tar.close()
         finally:
             os.remove(path)
 
     def test_add_self(self):
         # Test for #1257255.
         dstname = os.path.abspath(tmpname)
-
         tar = tarfile.open(tmpname, self.mode)
-        self.assertTrue(tar.name == dstname, "archive name must be absolute")
-
-        tar.add(dstname)
-        self.assertTrue(tar.getnames() == [], "added the archive to itself")
-
-        cwd = os.getcwd()
-        os.chdir(TEMPDIR)
-        tar.add(dstname)
-        os.chdir(cwd)
-        self.assertTrue(tar.getnames() == [], "added the archive to itself")
+        try:
+            self.assertTrue(tar.name == dstname, "archive name must be absolute")
+            tar.add(dstname)
+            self.assertTrue(tar.getnames() == [], "added the archive to itself")
+
+            cwd = os.getcwd()
+            os.chdir(TEMPDIR)
+            tar.add(dstname)
+            os.chdir(cwd)
+            self.assertTrue(tar.getnames() == [], "added the archive to itself")
+        finally:
+            tar.close()
 
     def test_exclude(self):
         tempdir = os.path.join(TEMPDIR, "exclude")
@@ -736,14 +778,19 @@
             exclude = os.path.isfile
 
             tar = tarfile.open(tmpname, self.mode, encoding="iso8859-1")
-            with support.check_warnings(("use the filter argument",
-                                         DeprecationWarning)):
-                tar.add(tempdir, arcname="empty_dir", exclude=exclude)
-            tar.close()
+            try:
+                with support.check_warnings(("use the filter argument",
+                                             DeprecationWarning)):
+                    tar.add(tempdir, arcname="empty_dir", exclude=exclude)
+            finally:
+                tar.close()
 
             tar = tarfile.open(tmpname, "r")
-            self.assertEqual(len(tar.getmembers()), 1)
-            self.assertEqual(tar.getnames()[0], "empty_dir")
+            try:
+                self.assertEqual(len(tar.getmembers()), 1)
+                self.assertEqual(tar.getnames()[0], "empty_dir")
+            finally:
+                tar.close()
         finally:
             shutil.rmtree(tempdir)
 
@@ -763,15 +810,19 @@
                 return tarinfo
 
             tar = tarfile.open(tmpname, self.mode, encoding="iso8859-1")
-            tar.add(tempdir, arcname="empty_dir", filter=filter)
-            tar.close()
+            try:
+                tar.add(tempdir, arcname="empty_dir", filter=filter)
+            finally:
+                tar.close()
 
             tar = tarfile.open(tmpname, "r")
-            for tarinfo in tar:
-                self.assertEqual(tarinfo.uid, 123)
-                self.assertEqual(tarinfo.uname, "foo")
-            self.assertEqual(len(tar.getmembers()), 3)
-            tar.close()
+            try:
+                for tarinfo in tar:
+                    self.assertEqual(tarinfo.uid, 123)
+                    self.assertEqual(tarinfo.uname, "foo")
+                self.assertEqual(len(tar.getmembers()), 3)
+            finally:
+                tar.close()
         finally:
             shutil.rmtree(tempdir)
 
@@ -789,12 +840,16 @@
             os.mkdir(foo)
 
         tar = tarfile.open(tmpname, self.mode)
-        tar.add(foo, arcname=path)
-        tar.close()
+        try:
+            tar.add(foo, arcname=path)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname, "r")
-        t = tar.next()
-        tar.close()
+        try:
+            t = tar.next()
+        finally:
+            tar.close()
 
         if not dir:
             os.remove(foo)
@@ -832,16 +887,18 @@
         cwd = os.getcwd()
         os.chdir(TEMPDIR)
         try:
-            open("foo", "w").close()
-
             tar = tarfile.open(tmpname, self.mode)
-            tar.add(".")
-            tar.close()
+            try:
+                tar.add(".")
+            finally:
+                tar.close()
 
             tar = tarfile.open(tmpname, "r")
-            for t in tar:
-                self.assert_(t.name == "." or t.name.startswith("./"))
-            tar.close()
+            try:
+                for t in tar:
+                    self.assert_(t.name == "." or t.name.startswith("./"))
+            finally:
+                tar.close()
         finally:
             os.chdir(cwd)
 
@@ -856,19 +913,18 @@
         tar.close()
 
         if self.mode.endswith("gz"):
-            fobj = gzip.GzipFile(tmpname)
-            data = fobj.read()
-            fobj.close()
+            with gzip.GzipFile(tmpname) as fobj:
+                data = fobj.read()
         elif self.mode.endswith("bz2"):
             dec = bz2.BZ2Decompressor()
-            data = open(tmpname, "rb").read()
+            with open(tmpname, "rb") as fobj:
+                data = fobj.read()
             data = dec.decompress(data)
             self.assertTrue(len(dec.unused_data) == 0,
                     "found trailing data")
         else:
-            fobj = open(tmpname, "rb")
-            data = fobj.read()
-            fobj.close()
+            with open(tmpname, "rb") as fobj:
+                data = fobj.read()
 
         self.assertTrue(data.count(b"\0") == tarfile.RECORDSIZE,
                          "incorrect zero padding")
@@ -923,23 +979,27 @@
             tarinfo.type = tarfile.LNKTYPE
 
         tar = tarfile.open(tmpname, "w")
-        tar.format = tarfile.GNU_FORMAT
-        tar.addfile(tarinfo)
-
-        v1 = self._calc_size(name, link)
-        v2 = tar.offset
-        self.assertTrue(v1 == v2, "GNU longname/longlink creation failed")
+        try:
+            tar.format = tarfile.GNU_FORMAT
+            tar.addfile(tarinfo)
 
-        tar.close()
+            v1 = self._calc_size(name, link)
+            v2 = tar.offset
+            self.assertTrue(v1 == v2, "GNU longname/longlink creation failed")
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname)
-        member = tar.next()
-        self.assertIsNotNone(member,
-                "unable to read longname member")
-        self.assertEqual(tarinfo.name, member.name,
-                "unable to read longname member")
-        self.assertEqual(tarinfo.linkname, member.linkname,
-                "unable to read longname member")
+        try:
+            member = tar.next()
+            self.assertIsNotNone(member,
+                    "unable to read longname member")
+            self.assertEqual(tarinfo.name, member.name,
+                    "unable to read longname member")
+            self.assertEqual(tarinfo.linkname, member.linkname,
+                    "unable to read longname member")
+        finally:
+            tar.close()
 
     def test_longname_1023(self):
         self._test(("longnam/" * 127) + "longnam")
@@ -979,9 +1039,8 @@
         self.foo = os.path.join(TEMPDIR, "foo")
         self.bar = os.path.join(TEMPDIR, "bar")
 
-        fobj = open(self.foo, "wb")
-        fobj.write(b"foo")
-        fobj.close()
+        with open(self.foo, "wb") as fobj:
+            fobj.write(b"foo")
 
         os.link(self.foo, self.bar)
 
@@ -990,8 +1049,8 @@
 
     def tearDown(self):
         self.tar.close()
-        os.remove(self.foo)
-        os.remove(self.bar)
+        support.unlink(self.foo)
+        support.unlink(self.bar)
 
     def test_add_twice(self):
         # The same name will be added as a REGTYPE every
@@ -1022,16 +1081,21 @@
             tarinfo.type = tarfile.LNKTYPE
 
         tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT)
-        tar.addfile(tarinfo)
-        tar.close()
+        try:
+            tar.addfile(tarinfo)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname)
-        if link:
-            l = tar.getmembers()[0].linkname
-            self.assertTrue(link == l, "PAX longlink creation failed")
-        else:
-            n = tar.getmembers()[0].name
-            self.assertTrue(name == n, "PAX longname creation failed")
+        try:
+            if link:
+                l = tar.getmembers()[0].linkname
+                self.assertTrue(link == l, "PAX longlink creation failed")
+            else:
+                n = tar.getmembers()[0].name
+                self.assertTrue(name == n, "PAX longname creation failed")
+        finally:
+            tar.close()
 
     def test_pax_global_header(self):
         pax_headers = {
@@ -1043,23 +1107,27 @@
 
         tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT,
                 pax_headers=pax_headers)
-        tar.addfile(tarfile.TarInfo("test"))
-        tar.close()
+        try:
+            tar.addfile(tarfile.TarInfo("test"))
+        finally:
+            tar.close()
 
         # Test if the global header was written correctly.
         tar = tarfile.open(tmpname, encoding="iso8859-1")
-        self.assertEqual(tar.pax_headers, pax_headers)
-        self.assertEqual(tar.getmembers()[0].pax_headers, pax_headers)
-
-        # Test if all the fields are strings.
-        for key, val in tar.pax_headers.items():
-            self.assertTrue(type(key) is not bytes)
-            self.assertTrue(type(val) is not bytes)
-            if key in tarfile.PAX_NUMBER_FIELDS:
-                try:
-                    tarfile.PAX_NUMBER_FIELDS[key](val)
-                except (TypeError, ValueError):
-                    self.fail("unable to convert pax header field")
+        try:
+            self.assertEqual(tar.pax_headers, pax_headers)
+            self.assertEqual(tar.getmembers()[0].pax_headers, pax_headers)
+            # Test if all the fields are strings.
+            for key, val in tar.pax_headers.items():
+                self.assertTrue(type(key) is not bytes)
+                self.assertTrue(type(val) is not bytes)
+                if key in tarfile.PAX_NUMBER_FIELDS:
+                    try:
+                        tarfile.PAX_NUMBER_FIELDS[key](val)
+                    except (TypeError, ValueError):
+                        self.fail("unable to convert pax header field")
+        finally:
+            tar.close()
 
     def test_pax_extended_header(self):
         # The fields from the pax header have priority over the
@@ -1067,18 +1135,23 @@
         pax_headers = {"path": "foo", "uid": "123"}
 
         tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT, encoding="iso8859-1")
-        t = tarfile.TarInfo()
-        t.name = "\xe4\xf6\xfc" # non-ASCII
-        t.uid = 8**8 # too large
-        t.pax_headers = pax_headers
-        tar.addfile(t)
-        tar.close()
+        try:
+            t = tarfile.TarInfo()
+            t.name = "\xe4\xf6\xfc" # non-ASCII
+            t.uid = 8**8 # too large
+            t.pax_headers = pax_headers
+            tar.addfile(t)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname, encoding="iso8859-1")
-        t = tar.getmembers()[0]
-        self.assertEqual(t.pax_headers, pax_headers)
-        self.assertEqual(t.name, "foo")
-        self.assertEqual(t.uid, 123)
+        try:
+            t = tar.getmembers()[0]
+            self.assertEqual(t.pax_headers, pax_headers)
+            self.assertEqual(t.name, "foo")
+            self.assertEqual(t.uid, 123)
+        finally:
+            tar.close()
 
 
 class UstarUnicodeTest(unittest.TestCase):
@@ -1096,13 +1169,17 @@
 
     def _test_unicode_filename(self, encoding):
         tar = tarfile.open(tmpname, "w", format=self.format, encoding=encoding, errors="strict")
-        name = "\xe4\xf6\xfc"
-        tar.addfile(tarfile.TarInfo(name))
-        tar.close()
+        try:
+            name = "\xe4\xf6\xfc"
+            tar.addfile(tarfile.TarInfo(name))
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname, encoding=encoding)
-        self.assertEqual(tar.getmembers()[0].name, name)
-        tar.close()
+        try:
+            self.assertEqual(tar.getmembers()[0].name, name)
+        finally:
+            tar.close()
 
     def test_unicode_filename_error(self):
         if self.format == tarfile.PAX_FORMAT:
@@ -1110,23 +1187,28 @@
             return
 
         tar = tarfile.open(tmpname, "w", format=self.format, encoding="ascii", errors="strict")
-        tarinfo = tarfile.TarInfo()
+        try:
+            tarinfo = tarfile.TarInfo()
 
-        tarinfo.name = "\xe4\xf6\xfc"
-        self.assertRaises(UnicodeError, tar.addfile, tarinfo)
+            tarinfo.name = "\xe4\xf6\xfc"
+            self.assertRaises(UnicodeError, tar.addfile, tarinfo)
 
-        tarinfo.name = "foo"
-        tarinfo.uname = "\xe4\xf6\xfc"
-        self.assertRaises(UnicodeError, tar.addfile, tarinfo)
+            tarinfo.name = "foo"
+            tarinfo.uname = "\xe4\xf6\xfc"
+            self.assertRaises(UnicodeError, tar.addfile, tarinfo)
+        finally:
+            tar.close()
 
     def test_unicode_argument(self):
         tar = tarfile.open(tarname, "r", encoding="iso8859-1", errors="strict")
-        for t in tar:
-            self.assertTrue(type(t.name) is str)
-            self.assertTrue(type(t.linkname) is str)
-            self.assertTrue(type(t.uname) is str)
-            self.assertTrue(type(t.gname) is str)
-        tar.close()
+        try:
+            for t in tar:
+                self.assertTrue(type(t.name) is str)
+                self.assertTrue(type(t.linkname) is str)
+                self.assertTrue(type(t.uname) is str)
+                self.assertTrue(type(t.gname) is str)
+        finally:
+            tar.close()
 
     def test_uname_unicode(self):
         t = tarfile.TarInfo("foo")
@@ -1134,19 +1216,24 @@
         t.gname = "\xe4\xf6\xfc"
 
         tar = tarfile.open(tmpname, mode="w", format=self.format, encoding="iso8859-1")
-        tar.addfile(t)
-        tar.close()
+        try:
+            tar.addfile(t)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname, encoding="iso8859-1")
-        t = tar.getmember("foo")
-        self.assertEqual(t.uname, "\xe4\xf6\xfc")
-        self.assertEqual(t.gname, "\xe4\xf6\xfc")
-
-        if self.format != tarfile.PAX_FORMAT:
-            tar = tarfile.open(tmpname, encoding="ascii")
+        try:
             t = tar.getmember("foo")
-            self.assertEqual(t.uname, "\udce4\udcf6\udcfc")
-            self.assertEqual(t.gname, "\udce4\udcf6\udcfc")
+            self.assertEqual(t.uname, "\xe4\xf6\xfc")
+            self.assertEqual(t.gname, "\xe4\xf6\xfc")
+
+            if self.format != tarfile.PAX_FORMAT:
+                tar = tarfile.open(tmpname, encoding="ascii")
+                t = tar.getmember("foo")
+                self.assertEqual(t.uname, "\udce4\udcf6\udcfc")
+                self.assertEqual(t.gname, "\udce4\udcf6\udcfc")
+        finally:
+            tar.close()
 
 
 class GNUUnicodeTest(UstarUnicodeTest):
@@ -1189,22 +1276,20 @@
             os.remove(self.tarname)
 
     def _add_testfile(self, fileobj=None):
-        tar = tarfile.open(self.tarname, "a", fileobj=fileobj)
-        tar.addfile(tarfile.TarInfo("bar"))
-        tar.close()
+        with tarfile.open(self.tarname, "a", fileobj=fileobj) as tar:
+            tar.addfile(tarfile.TarInfo("bar"))
 
     def _create_testtar(self, mode="w:"):
-        src = tarfile.open(tarname, encoding="iso8859-1")
-        t = src.getmember("ustar/regtype")
-        t.name = "foo"
-        f = src.extractfile(t)
-        tar = tarfile.open(self.tarname, mode)
-        tar.addfile(t, f)
-        tar.close()
+        with tarfile.open(tarname, encoding="iso8859-1") as src:
+            t = src.getmember("ustar/regtype")
+            t.name = "foo"
+            f = src.extractfile(t)
+            with tarfile.open(self.tarname, mode) as tar:
+                tar.addfile(t, f)
 
     def _test(self, names=["bar"], fileobj=None):
-        tar = tarfile.open(self.tarname, fileobj=fileobj)
-        self.assertEqual(tar.getnames(), names)
+        with tarfile.open(self.tarname, fileobj=fileobj) as tar:
+            self.assertEqual(tar.getnames(), names)
 
     def test_non_existing(self):
         self._add_testfile()
@@ -1223,7 +1308,8 @@
 
     def test_fileobj(self):
         self._create_testtar()
-        data = open(self.tarname, "rb").read()
+        with open(self.tarname, "rb") as fobj:
+            data = fobj.read()
         fobj = io.BytesIO(data)
         self._add_testfile(fobj)
         fobj.seek(0)
@@ -1249,7 +1335,8 @@
     # Append mode is supposed to fail if the tarfile to append to
     # does not end with a zero block.
     def _test_error(self, data):
-        open(self.tarname, "wb").write(data)
+        with open(self.tarname, "wb") as fobj:
+            fobj.write(data)
         self.assertRaises(tarfile.ReadError, self._add_testfile)
 
     def test_null(self):
@@ -1390,15 +1477,14 @@
     def test_fileobj(self):
         # Test that __exit__() did not close the external file
         # object.
-        fobj = open(tmpname, "wb")
-        try:
-            with tarfile.open(fileobj=fobj, mode="w") as tar:
-                raise Exception
-        except:
-            pass
-        self.assertFalse(fobj.closed, "external file object was closed")
-        self.assertTrue(tar.closed, "context manager failed")
-        fobj.close()
+        with open(tmpname, "wb") as fobj:
+            try:
+                with tarfile.open(fileobj=fobj, mode="w") as tar:
+                    raise Exception
+            except:
+                pass
+            self.assertFalse(fobj.closed, "external file object was closed")
+            self.assertTrue(tar.closed, "context manager failed")
 
 
 class LinkEmulationTest(ReadTest):
@@ -1495,6 +1581,7 @@
 
 
 def test_main():
+    support.unlink(TEMPDIR)
     os.makedirs(TEMPDIR)
 
     tests = [
@@ -1523,15 +1610,14 @@
     else:
         tests.append(LinkEmulationTest)
 
-    fobj = open(tarname, "rb")
-    data = fobj.read()
-    fobj.close()
+    with open(tarname, "rb") as fobj:
+        data = fobj.read()
 
     if gzip:
         # Create testtar.tar.gz and add gzip-specific tests.
-        tar = gzip.open(gzipname, "wb")
-        tar.write(data)
-        tar.close()
+        support.unlink(gzipname)
+        with gzip.open(gzipname, "wb") as tar:
+            tar.write(data)
 
         tests += [
             GzipMiscReadTest,
@@ -1543,9 +1629,12 @@
 
     if bz2:
         # Create testtar.tar.bz2 and add bz2-specific tests.
+        support.unlink(bz2name)
         tar = bz2.BZ2File(bz2name, "wb")
-        tar.write(data)
-        tar.close()
+        try:
+            tar.write(data)
+        finally:
+            tar.close()
 
         tests += [
             Bz2MiscReadTest,


More information about the Python-checkins mailing list