[Python-checkins] bpo-45046: Support context managers in unittest (GH-28045)

serhiy-storchaka webhook-mailer at python.org
Sun May 8 10:49:21 EDT 2022


https://github.com/python/cpython/commit/086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0
commit: 086c6b1b0fe8d47ebd15512d7bdcb64c60a360f0
branch: main
author: Serhiy Storchaka <storchaka at gmail.com>
committer: serhiy-storchaka <storchaka at gmail.com>
date: 2022-05-08T17:49:09+03:00
summary:

bpo-45046: Support context managers in unittest (GH-28045)

Add methods enterContext() and enterClassContext() in TestCase.
Add method enterAsyncContext() in IsolatedAsyncioTestCase.
Add function enterModuleContext().

files:
A Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst
M Doc/library/unittest.rst
M Doc/whatsnew/3.11.rst
M Lib/distutils/tests/test_build_ext.py
M Lib/test/test__osx_support.py
M Lib/test/test_argparse.py
M Lib/test/test_getopt.py
M Lib/test/test_gettext.py
M Lib/test/test_global.py
M Lib/test/test_importlib/source/test_finder.py
M Lib/test/test_importlib/test_namespace_pkgs.py
M Lib/test/test_logging.py
M Lib/test/test_nntplib.py
M Lib/test/test_peg_generator/test_c_parser.py
M Lib/test/test_poll.py
M Lib/test/test_posix.py
M Lib/test/test_set.py
M Lib/test/test_socket.py
M Lib/test/test_ssl.py
M Lib/test/test_tempfile.py
M Lib/test/test_urllib.py
M Lib/unittest/__init__.py
M Lib/unittest/async_case.py
M Lib/unittest/case.py
M Lib/unittest/test/test_async_case.py
M Lib/unittest/test/test_runner.py

diff --git a/Doc/library/unittest.rst b/Doc/library/unittest.rst
index 9b8b75acce514..f6bcba06d90ee 100644
--- a/Doc/library/unittest.rst
+++ b/Doc/library/unittest.rst
@@ -1495,6 +1495,16 @@ Test cases
       .. versionadded:: 3.1
 
 
+   .. method:: enterContext(cm)
+
+      Enter the supplied :term:`context manager`.  If successful, also
+      add its :meth:`~object.__exit__` method as a cleanup function by
+      :meth:`addCleanup` and return the result of the
+      :meth:`~object.__enter__` method.
+
+      .. versionadded:: 3.11
+
+
    .. method:: doCleanups()
 
       This method is called unconditionally after :meth:`tearDown`, or
@@ -1510,6 +1520,7 @@ Test cases
 
       .. versionadded:: 3.1
 
+
    .. classmethod:: addClassCleanup(function, /, *args, **kwargs)
 
       Add a function to be called after :meth:`tearDownClass` to cleanup
@@ -1524,6 +1535,16 @@ Test cases
       .. versionadded:: 3.8
 
 
+   .. classmethod:: enterClassContext(cm)
+
+      Enter the supplied :term:`context manager`.  If successful, also
+      add its :meth:`~object.__exit__` method as a cleanup function by
+      :meth:`addClassCleanup` and return the result of the
+      :meth:`~object.__enter__` method.
+
+      .. versionadded:: 3.11
+
+
    .. classmethod:: doClassCleanups()
 
       This method is called unconditionally after :meth:`tearDownClass`, or
@@ -1571,6 +1592,16 @@ Test cases
 
       This method accepts a coroutine that can be used as a cleanup function.
 
+   .. coroutinemethod:: enterAsyncContext(cm)
+
+      Enter the supplied :term:`asynchronous context manager`.  If successful,
+      also add its :meth:`~object.__aexit__` method as a cleanup function by
+      :meth:`addAsyncCleanup` and return the result of the
+      :meth:`~object.__aenter__` method.
+
+      .. versionadded:: 3.11
+
+
    .. method:: run(result=None)
 
       Sets up a new event loop to run the test, collecting the result into
@@ -2465,6 +2496,16 @@ To add cleanup code that must be run even in the case of an exception, use
    .. versionadded:: 3.8
 
 
+.. function:: enterModuleContext(cm)
+
+   Enter the supplied :term:`context manager`.  If successful, also
+   add its :meth:`~object.__exit__` method as a cleanup function by
+   :func:`addModuleCleanup` and return the result of the
+   :meth:`~object.__enter__` method.
+
+   .. versionadded:: 3.11
+
+
 .. function:: doModuleCleanups()
 
    This function is called unconditionally after :func:`tearDownModule`, or
@@ -2480,6 +2521,7 @@ To add cleanup code that must be run even in the case of an exception, use
 
    .. versionadded:: 3.8
 
+
 Signal Handling
 ---------------
 
diff --git a/Doc/whatsnew/3.11.rst b/Doc/whatsnew/3.11.rst
index c4e8e6f9a1051..defaeebc7739a 100644
--- a/Doc/whatsnew/3.11.rst
+++ b/Doc/whatsnew/3.11.rst
@@ -758,6 +758,18 @@ unicodedata
 * The Unicode database has been updated to version 14.0.0. (:issue:`45190`).
 
 
+unittest
+--------
+
+* Added methods :meth:`~unittest.TestCase.enterContext` and
+  :meth:`~unittest.TestCase.enterClassContext` of class
+  :class:`~unittest.TestCase`, method
+  :meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of
+  class :class:`~unittest.IsolatedAsyncioTestCase` and function
+  :func:`unittest.enterModuleContext`.
+  (Contributed by Serhiy Storchaka in :issue:`45046`.)
+
+
 venv
 ----
 
diff --git a/Lib/distutils/tests/test_build_ext.py b/Lib/distutils/tests/test_build_ext.py
index 031897bd2734f..4ebeafecef03c 100644
--- a/Lib/distutils/tests/test_build_ext.py
+++ b/Lib/distutils/tests/test_build_ext.py
@@ -41,9 +41,7 @@ def setUp(self):
         # bpo-30132: On Windows, a .pdb file may be created in the current
         # working directory. Create a temporary working directory to cleanup
         # everything at the end of the test.
-        change_cwd = os_helper.change_cwd(self.tmp_dir)
-        change_cwd.__enter__()
-        self.addCleanup(change_cwd.__exit__, None, None, None)
+        self.enterContext(os_helper.change_cwd(self.tmp_dir))
 
     def tearDown(self):
         import site
diff --git a/Lib/test/test__osx_support.py b/Lib/test/test__osx_support.py
index 907ae27d529b5..4a14cb352138e 100644
--- a/Lib/test/test__osx_support.py
+++ b/Lib/test/test__osx_support.py
@@ -19,8 +19,7 @@ def setUp(self):
         self.maxDiff = None
         self.prog_name = 'bogus_program_xxxx'
         self.temp_path_dir = os.path.abspath(os.getcwd())
-        self.env = os_helper.EnvironmentVarGuard()
-        self.addCleanup(self.env.__exit__)
+        self.env = self.enterContext(os_helper.EnvironmentVarGuard())
         for cv in ('CFLAGS', 'LDFLAGS', 'CPPFLAGS',
                             'BASECFLAGS', 'BLDSHARED', 'LDSHARED', 'CC',
                             'CXX', 'PY_CFLAGS', 'PY_LDFLAGS', 'PY_CPPFLAGS',
diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py
index 8509deb93f1e2..273db45c00f7a 100644
--- a/Lib/test/test_argparse.py
+++ b/Lib/test/test_argparse.py
@@ -41,9 +41,8 @@ def setUp(self):
         # The tests assume that line wrapping occurs at 80 columns, but this
         # behaviour can be overridden by setting the COLUMNS environment
         # variable.  To ensure that this width is used, set COLUMNS to 80.
-        env = os_helper.EnvironmentVarGuard()
+        env = self.enterContext(os_helper.EnvironmentVarGuard())
         env['COLUMNS'] = '80'
-        self.addCleanup(env.__exit__)
 
 
 class TempDirMixin(object):
@@ -3428,9 +3427,8 @@ class TestShortColumns(HelpTestCase):
     but we don't want any exceptions thrown in such cases. Only ugly representation.
     '''
     def setUp(self):
-        env = os_helper.EnvironmentVarGuard()
+        env = self.enterContext(os_helper.EnvironmentVarGuard())
         env.set("COLUMNS", '15')
-        self.addCleanup(env.__exit__)
 
     parser_signature            = TestHelpBiggerOptionals.parser_signature
     argument_signatures         = TestHelpBiggerOptionals.argument_signatures
diff --git a/Lib/test/test_getopt.py b/Lib/test/test_getopt.py
index 9261276ebb972..64b9ce01e05ea 100644
--- a/Lib/test/test_getopt.py
+++ b/Lib/test/test_getopt.py
@@ -11,14 +11,10 @@
 
 class GetoptTests(unittest.TestCase):
     def setUp(self):
-        self.env = EnvironmentVarGuard()
+        self.env = self.enterContext(EnvironmentVarGuard())
         if "POSIXLY_CORRECT" in self.env:
             del self.env["POSIXLY_CORRECT"]
 
-    def tearDown(self):
-        self.env.__exit__()
-        del self.env
-
     def assertError(self, *args, **kwargs):
         self.assertRaises(getopt.GetoptError, *args, **kwargs)
 
diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py
index 467652a41f0cd..1608d1b18e98f 100644
--- a/Lib/test/test_gettext.py
+++ b/Lib/test/test_gettext.py
@@ -117,6 +117,7 @@
 
 class GettextBaseTest(unittest.TestCase):
     def setUp(self):
+        self.addCleanup(os_helper.rmtree, os.path.split(LOCALEDIR)[0])
         if not os.path.isdir(LOCALEDIR):
             os.makedirs(LOCALEDIR)
         with open(MOFILE, 'wb') as fp:
@@ -129,14 +130,10 @@ def setUp(self):
             fp.write(base64.decodebytes(UMO_DATA))
         with open(MMOFILE, 'wb') as fp:
             fp.write(base64.decodebytes(MMO_DATA))
-        self.env = os_helper.EnvironmentVarGuard()
+        self.env = self.enterContext(os_helper.EnvironmentVarGuard())
         self.env['LANGUAGE'] = 'xx'
         gettext._translations.clear()
 
-    def tearDown(self):
-        self.env.__exit__()
-        del self.env
-        os_helper.rmtree(os.path.split(LOCALEDIR)[0])
 
 GNU_MO_DATA_ISSUE_17898 = b'''\
 3hIElQAAAAABAAAAHAAAACQAAAAAAAAAAAAAAAAAAAAsAAAAggAAAC0AAAAAUGx1cmFsLUZvcm1z
diff --git a/Lib/test/test_global.py b/Lib/test/test_global.py
index d0bde3fd040e6..f5b38c25ea072 100644
--- a/Lib/test/test_global.py
+++ b/Lib/test/test_global.py
@@ -9,14 +9,9 @@
 class GlobalTests(unittest.TestCase):
 
     def setUp(self):
-        self._warnings_manager = check_warnings()
-        self._warnings_manager.__enter__()
+        self.enterContext(check_warnings())
         warnings.filterwarnings("error", module="<test string>")
 
-    def tearDown(self):
-        self._warnings_manager.__exit__(None, None, None)
-
-
     def test1(self):
         prog_text_1 = """\
 def wrong1():
@@ -54,9 +49,7 @@ def test4(self):
 
 
 def setUpModule():
-    cm = warnings.catch_warnings()
-    cm.__enter__()
-    unittest.addModuleCleanup(cm.__exit__, None, None, None)
+    unittest.enterModuleContext(warnings.catch_warnings())
     warnings.filterwarnings("error", module="<test string>")
 
 
diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py
index 6a23e9d50f6ff..bed9d56dca84e 100644
--- a/Lib/test/test_importlib/source/test_finder.py
+++ b/Lib/test/test_importlib/source/test_finder.py
@@ -157,21 +157,12 @@ def test_dir_removal_handling(self):
     def test_no_read_directory(self):
         # Issue #16730
         tempdir = tempfile.TemporaryDirectory()
+        self.enterContext(tempdir)
+        # Since we muck with the permissions, we want to set them back to
+        # their original values to make sure the directory can be properly
+        # cleaned up.
         original_mode = os.stat(tempdir.name).st_mode
-        def cleanup(tempdir):
-            """Cleanup function for the temporary directory.
-
-            Since we muck with the permissions, we want to set them back to
-            their original values to make sure the directory can be properly
-            cleaned up.
-
-            """
-            os.chmod(tempdir.name, original_mode)
-            # If this is not explicitly called then the __del__ method is used,
-            # but since already mucking around might as well explicitly clean
-            # up.
-            tempdir.__exit__(None, None, None)
-        self.addCleanup(cleanup, tempdir)
+        self.addCleanup(os.chmod, tempdir.name, original_mode)
         os.chmod(tempdir.name, stat.S_IWUSR | stat.S_IXUSR)
         finder = self.get_finder(tempdir.name)
         found = self._find(finder, 'doesnotexist')
diff --git a/Lib/test/test_importlib/test_namespace_pkgs.py b/Lib/test/test_importlib/test_namespace_pkgs.py
index 2ea41b7a4c5c3..cd08498545e80 100644
--- a/Lib/test/test_importlib/test_namespace_pkgs.py
+++ b/Lib/test/test_importlib/test_namespace_pkgs.py
@@ -65,12 +65,7 @@ def setUp(self):
         self.resolved_paths = [
             os.path.join(self.root, path) for path in self.paths
         ]
-        self.ctx = namespace_tree_context(path=self.resolved_paths)
-        self.ctx.__enter__()
-
-    def tearDown(self):
-        # TODO: will we ever want to pass exc_info to __exit__?
-        self.ctx.__exit__(None, None, None)
+        self.enterContext(namespace_tree_context(path=self.resolved_paths))
 
 
 class SingleNamespacePackage(NamespacePackageTest):
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index 5d4ddedd059fc..e69afae484aa7 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -5650,9 +5650,7 @@ def test__all__(self):
 # why the test does this, but in any case we save the current locale
 # first and restore it at the end.
 def setUpModule():
-    cm = support.run_with_locale('LC_ALL', '')
-    cm.__enter__()
-    unittest.addModuleCleanup(cm.__exit__, None, None, None)
+    unittest.enterModuleContext(support.run_with_locale('LC_ALL', ''))
 
 
 if __name__ == "__main__":
diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py
index 9812c05519351..31a02f86abb00 100644
--- a/Lib/test/test_nntplib.py
+++ b/Lib/test/test_nntplib.py
@@ -1593,8 +1593,7 @@ def setUp(self):
         self.background.start()
         self.addCleanup(self.background.join)
 
-        self.nntp = NNTP(socket_helper.HOST, port, usenetrc=False).__enter__()
-        self.addCleanup(self.nntp.__exit__, None, None, None)
+        self.nntp = self.enterContext(NNTP(socket_helper.HOST, port, usenetrc=False))
 
     def run_server(self, sock):
         # Could be generalized to handle more commands in separate methods
diff --git a/Lib/test/test_peg_generator/test_c_parser.py b/Lib/test/test_peg_generator/test_c_parser.py
index 13b83a9db9eb3..d25bc112cfdc4 100644
--- a/Lib/test/test_peg_generator/test_c_parser.py
+++ b/Lib/test/test_peg_generator/test_c_parser.py
@@ -96,9 +96,7 @@ def setUp(self):
             self.skipTest("The %r command is not found" % cmd)
         self.old_cwd = os.getcwd()
         self.tmp_path = tempfile.mkdtemp(dir=self.tmp_base)
-        change_cwd = os_helper.change_cwd(self.tmp_path)
-        change_cwd.__enter__()
-        self.addCleanup(change_cwd.__exit__, None, None, None)
+        self.enterContext(os_helper.change_cwd(self.tmp_path))
 
     def tearDown(self):
         os.chdir(self.old_cwd)
diff --git a/Lib/test/test_poll.py b/Lib/test/test_poll.py
index 7d542b5cfd783..02165a0244ddf 100644
--- a/Lib/test/test_poll.py
+++ b/Lib/test/test_poll.py
@@ -128,8 +128,7 @@ def test_poll2(self):
         cmd = 'for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done'
         proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
                                 bufsize=0)
-        proc.__enter__()
-        self.addCleanup(proc.__exit__, None, None, None)
+        self.enterContext(proc)
         p = proc.stdout
         pollster = select.poll()
         pollster.register( p, select.POLLIN )
diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py
index f44b8d0403ff2..28e5e90297e24 100644
--- a/Lib/test/test_posix.py
+++ b/Lib/test/test_posix.py
@@ -53,19 +53,13 @@ class PosixTester(unittest.TestCase):
 
     def setUp(self):
         # create empty file
+        self.addCleanup(os_helper.unlink, os_helper.TESTFN)
         with open(os_helper.TESTFN, "wb"):
             pass
-        self.teardown_files = [ os_helper.TESTFN ]
-        self._warnings_manager = warnings_helper.check_warnings()
-        self._warnings_manager.__enter__()
+        self.enterContext(warnings_helper.check_warnings())
         warnings.filterwarnings('ignore', '.* potential security risk .*',
                                 RuntimeWarning)
 
-    def tearDown(self):
-        for teardown_file in self.teardown_files:
-            os_helper.unlink(teardown_file)
-        self._warnings_manager.__exit__(None, None, None)
-
     def testNoArgFunctions(self):
         # test posix functions which take no arguments and have
         # no side-effects which we need to cleanup (e.g., fork, wait, abort)
@@ -973,8 +967,8 @@ def test_lchflags_symlink(self):
 
         self.assertTrue(hasattr(testfn_st, 'st_flags'))
 
+        self.addCleanup(os_helper.unlink, _DUMMY_SYMLINK)
         os.symlink(os_helper.TESTFN, _DUMMY_SYMLINK)
-        self.teardown_files.append(_DUMMY_SYMLINK)
         dummy_symlink_st = os.lstat(_DUMMY_SYMLINK)
 
         def chflags_nofollow(path, flags):
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 3b57517a86101..43f23dbbf9bf7 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -1022,8 +1022,7 @@ def test_repr(self):
 
 class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
     def setUp(self):
-        self._warning_filters = warnings_helper.check_warnings()
-        self._warning_filters.__enter__()
+        self.enterContext(warnings_helper.check_warnings())
         warnings.simplefilter('ignore', BytesWarning)
         self.case   = "string and bytes set"
         self.values = ["a", "b", b"a", b"b"]
@@ -1031,9 +1030,6 @@ def setUp(self):
         self.dup    = set(self.values)
         self.length = 4
 
-    def tearDown(self):
-        self._warning_filters.__exit__(None, None, None)
-
     def test_repr(self):
         self.check_repr_against_values()
 
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 613363722cf02..1aaa9e44f90c6 100755
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -338,9 +338,7 @@ def serverExplicitReady(self):
         self.server_ready.set()
 
     def _setUp(self):
-        self.wait_threads = threading_helper.wait_threads_exit()
-        self.wait_threads.__enter__()
-        self.addCleanup(self.wait_threads.__exit__, None, None, None)
+        self.enterContext(threading_helper.wait_threads_exit())
 
         self.server_ready = threading.Event()
         self.client_ready = threading.Event()
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 0eb8d18b3561e..fed76378726c9 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -1999,9 +1999,8 @@ def setUp(self):
         self.server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
         self.server_context.load_cert_chain(SIGNED_CERTFILE)
         server = ThreadedEchoServer(context=self.server_context)
+        self.enterContext(server)
         self.server_addr = (HOST, server.port)
-        server.__enter__()
-        self.addCleanup(server.__exit__, None, None, None)
 
     def test_connect(self):
         with test_wrap_socket(socket.socket(socket.AF_INET),
@@ -3713,8 +3712,7 @@ def _recvfrom_into():
 
     def test_recv_zero(self):
         server = ThreadedEchoServer(CERTFILE)
-        server.__enter__()
-        self.addCleanup(server.__exit__, None, None)
+        self.enterContext(server)
         s = socket.create_connection((HOST, server.port))
         self.addCleanup(s.close)
         s = test_wrap_socket(s, suppress_ragged_eofs=False)
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index a05f3c84ccfc9..f056e5ccb17f9 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -90,14 +90,10 @@ class BaseTestCase(unittest.TestCase):
     b_check = re.compile(br"^[a-z0-9_-]{8}$")
 
     def setUp(self):
-        self._warnings_manager = warnings_helper.check_warnings()
-        self._warnings_manager.__enter__()
+        self.enterContext(warnings_helper.check_warnings())
         warnings.filterwarnings("ignore", category=RuntimeWarning,
                                 message="mktemp", module=__name__)
 
-    def tearDown(self):
-        self._warnings_manager.__exit__(None, None, None)
-
     def nameCheck(self, name, dir, pre, suf):
         (ndir, nbase) = os.path.split(name)
         npre  = nbase[:len(pre)]
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index 82f1d9dc2e7bb..bc6e74c291ac1 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -232,17 +232,12 @@ class ProxyTests(unittest.TestCase):
 
     def setUp(self):
         # Records changes to env vars
-        self.env = os_helper.EnvironmentVarGuard()
+        self.env = self.enterContext(os_helper.EnvironmentVarGuard())
         # Delete all proxy related env vars
         for k in list(os.environ):
             if 'proxy' in k.lower():
                 self.env.unset(k)
 
-    def tearDown(self):
-        # Restore all proxy related env vars
-        self.env.__exit__()
-        del self.env
-
     def test_getproxies_environment_keep_no_proxies(self):
         self.env.set('NO_PROXY', 'localhost')
         proxies = urllib.request.getproxies_environment()
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py
index eda951ce73e66..005d23f6d00ec 100644
--- a/Lib/unittest/__init__.py
+++ b/Lib/unittest/__init__.py
@@ -49,7 +49,7 @@ def testMultiply(self):
            'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless',
            'expectedFailure', 'TextTestResult', 'installHandler',
            'registerResult', 'removeResult', 'removeHandler',
-           'addModuleCleanup', 'doModuleCleanups']
+           'addModuleCleanup', 'doModuleCleanups', 'enterModuleContext']
 
 # Expose obsolete functions for backwards compatibility
 # bpo-5846: Deprecated in Python 3.11, scheduled for removal in Python 3.13.
@@ -59,7 +59,8 @@ def testMultiply(self):
 
 from .result import TestResult
 from .case import (addModuleCleanup, TestCase, FunctionTestCase, SkipTest, skip,
-                   skipIf, skipUnless, expectedFailure, doModuleCleanups)
+                   skipIf, skipUnless, expectedFailure, doModuleCleanups,
+                   enterModuleContext)
 from .suite import BaseTestSuite, TestSuite
 from .loader import TestLoader, defaultTestLoader
 from .main import TestProgram, main
diff --git a/Lib/unittest/async_case.py b/Lib/unittest/async_case.py
index 85b938fb293af..a90eed98f8714 100644
--- a/Lib/unittest/async_case.py
+++ b/Lib/unittest/async_case.py
@@ -58,6 +58,26 @@ def addAsyncCleanup(self, func, /, *args, **kwargs):
         # 3. Regular "def func()" that returns awaitable object
         self.addCleanup(*(func, *args), **kwargs)
 
+    async def enterAsyncContext(self, cm):
+        """Enters the supplied asynchronous context manager.
+
+        If successful, also adds its __aexit__ method as a cleanup
+        function and returns the result of the __aenter__ method.
+        """
+        # We look up the special methods on the type to match the with
+        # statement.
+        cls = type(cm)
+        try:
+            enter = cls.__aenter__
+            exit = cls.__aexit__
+        except AttributeError:
+            raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
+                            f"not support the asynchronous context manager protocol"
+                           ) from None
+        result = await enter(cm)
+        self.addAsyncCleanup(exit, cm, None, None, None)
+        return result
+
     def _callSetUp(self):
         self._asyncioTestContext.run(self.setUp)
         self._callAsync(self.asyncSetUp)
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 55770c06d7c5d..ffc8f19ddd38d 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -102,12 +102,31 @@ def _id(obj):
     return obj
 
 
+def _enter_context(cm, addcleanup):
+    # We look up the special methods on the type to match the with
+    # statement.
+    cls = type(cm)
+    try:
+        enter = cls.__enter__
+        exit = cls.__exit__
+    except AttributeError:
+        raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
+                        f"not support the context manager protocol") from None
+    result = enter(cm)
+    addcleanup(exit, cm, None, None, None)
+    return result
+
+
 _module_cleanups = []
 def addModuleCleanup(function, /, *args, **kwargs):
     """Same as addCleanup, except the cleanup items are called even if
     setUpModule fails (unlike tearDownModule)."""
     _module_cleanups.append((function, args, kwargs))
 
+def enterModuleContext(cm):
+    """Same as enterContext, but module-wide."""
+    return _enter_context(cm, addModuleCleanup)
+
 
 def doModuleCleanups():
     """Execute all module cleanup functions. Normally called for you after
@@ -426,12 +445,25 @@ def addCleanup(self, function, /, *args, **kwargs):
         Cleanup items are called even if setUp fails (unlike tearDown)."""
         self._cleanups.append((function, args, kwargs))
 
+    def enterContext(self, cm):
+        """Enters the supplied context manager.
+
+        If successful, also adds its __exit__ method as a cleanup
+        function and returns the result of the __enter__ method.
+        """
+        return _enter_context(cm, self.addCleanup)
+
     @classmethod
     def addClassCleanup(cls, function, /, *args, **kwargs):
         """Same as addCleanup, except the cleanup items are called even if
         setUpClass fails (unlike tearDownClass)."""
         cls._class_cleanups.append((function, args, kwargs))
 
+    @classmethod
+    def enterClassContext(cls, cm):
+        """Same as enterContext, but class-wide."""
+        return _enter_context(cm, cls.addClassCleanup)
+
     def setUp(self):
         "Hook method for setting up the test fixture before exercising it."
         pass
diff --git a/Lib/unittest/test/test_async_case.py b/Lib/unittest/test/test_async_case.py
index 1b910a44eea0d..beadcac070b43 100644
--- a/Lib/unittest/test/test_async_case.py
+++ b/Lib/unittest/test/test_async_case.py
@@ -14,6 +14,29 @@ def tearDownModule():
     asyncio.set_event_loop_policy(None)
 
 
+class TestCM:
+    def __init__(self, ordering, enter_result=None):
+        self.ordering = ordering
+        self.enter_result = enter_result
+
+    async def __aenter__(self):
+        self.ordering.append('enter')
+        return self.enter_result
+
+    async def __aexit__(self, *exc_info):
+        self.ordering.append('exit')
+
+
+class LacksEnterAndExit:
+    pass
+class LacksEnter:
+    async def __aexit__(self, *exc_info):
+        pass
+class LacksExit:
+    async def __aenter__(self):
+        pass
+
+
 VAR = contextvars.ContextVar('VAR', default=())
 
 
@@ -337,6 +360,36 @@ async def coro():
         output = test.run()
         self.assertTrue(cancelled)
 
+    def test_enterAsyncContext(self):
+        events = []
+
+        class Test(unittest.IsolatedAsyncioTestCase):
+            async def test_func(slf):
+                slf.addAsyncCleanup(events.append, 'cleanup1')
+                cm = TestCM(events, 42)
+                self.assertEqual(await slf.enterAsyncContext(cm), 42)
+                slf.addAsyncCleanup(events.append, 'cleanup2')
+                events.append('test')
+
+        test = Test('test_func')
+        output = test.run()
+        self.assertTrue(output.wasSuccessful(), output)
+        self.assertEqual(events, ['enter', 'test', 'cleanup2', 'exit', 'cleanup1'])
+
+    def test_enterAsyncContext_arg_errors(self):
+        class Test(unittest.IsolatedAsyncioTestCase):
+            async def test_func(slf):
+                with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+                    await slf.enterAsyncContext(LacksEnterAndExit())
+                with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+                    await slf.enterAsyncContext(LacksEnter())
+                with self.assertRaisesRegex(TypeError, 'asynchronous context manager'):
+                    await slf.enterAsyncContext(LacksExit())
+
+        test = Test('test_func')
+        output = test.run()
+        self.assertTrue(output.wasSuccessful())
+
     def test_debug_cleanup_same_loop(self):
         class Test(unittest.IsolatedAsyncioTestCase):
             async def asyncSetUp(self):
diff --git a/Lib/unittest/test/test_runner.py b/Lib/unittest/test/test_runner.py
index 18062ae5a5871..d3488b40e82bd 100644
--- a/Lib/unittest/test/test_runner.py
+++ b/Lib/unittest/test/test_runner.py
@@ -46,6 +46,29 @@ def cleanup(ordering, blowUp=False):
         raise Exception('CleanUpExc')
 
 
+class TestCM:
+    def __init__(self, ordering, enter_result=None):
+        self.ordering = ordering
+        self.enter_result = enter_result
+
+    def __enter__(self):
+        self.ordering.append('enter')
+        return self.enter_result
+
+    def __exit__(self, *exc_info):
+        self.ordering.append('exit')
+
+
+class LacksEnterAndExit:
+    pass
+class LacksEnter:
+    def __exit__(self, *exc_info):
+        pass
+class LacksExit:
+    def __enter__(self):
+        pass
+
+
 class TestCleanUp(unittest.TestCase):
     def testCleanUp(self):
         class TestableTest(unittest.TestCase):
@@ -173,6 +196,39 @@ def cleanup2():
         self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2'])
 
 
+    def test_enterContext(self):
+        class TestableTest(unittest.TestCase):
+            def testNothing(self):
+                pass
+
+        test = TestableTest('testNothing')
+        cleanups = []
+
+        test.addCleanup(cleanups.append, 'cleanup1')
+        cm = TestCM(cleanups, 42)
+        self.assertEqual(test.enterContext(cm), 42)
+        test.addCleanup(cleanups.append, 'cleanup2')
+
+        self.assertTrue(test.doCleanups())
+        self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+    def test_enterContext_arg_errors(self):
+        class TestableTest(unittest.TestCase):
+            def testNothing(self):
+                pass
+
+        test = TestableTest('testNothing')
+
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            test.enterContext(LacksEnterAndExit())
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            test.enterContext(LacksEnter())
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            test.enterContext(LacksExit())
+
+        self.assertEqual(test._cleanups, [])
+
+
 class TestClassCleanup(unittest.TestCase):
     def test_addClassCleanUp(self):
         class TestableTest(unittest.TestCase):
@@ -451,6 +507,35 @@ def tearDownClass(cls):
         self.assertEqual(ordering,
                          ['setUpClass', 'test', 'tearDownClass', 'cleanup_good'])
 
+    def test_enterClassContext(self):
+        class TestableTest(unittest.TestCase):
+            def testNothing(self):
+                pass
+
+        cleanups = []
+
+        TestableTest.addClassCleanup(cleanups.append, 'cleanup1')
+        cm = TestCM(cleanups, 42)
+        self.assertEqual(TestableTest.enterClassContext(cm), 42)
+        TestableTest.addClassCleanup(cleanups.append, 'cleanup2')
+
+        TestableTest.doClassCleanups()
+        self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+    def test_enterClassContext_arg_errors(self):
+        class TestableTest(unittest.TestCase):
+            def testNothing(self):
+                pass
+
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            TestableTest.enterClassContext(LacksEnterAndExit())
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            TestableTest.enterClassContext(LacksEnter())
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            TestableTest.enterClassContext(LacksExit())
+
+        self.assertEqual(TestableTest._class_cleanups, [])
+
 
 class TestModuleCleanUp(unittest.TestCase):
     def test_add_and_do_ModuleCleanup(self):
@@ -1000,6 +1085,31 @@ def tearDown(self):
                           'cleanup2',  'setUp2', 'test2', 'tearDown2',
                           'cleanup3', 'tearDownModule', 'cleanup1'])
 
+    def test_enterModuleContext(self):
+        cleanups = []
+
+        unittest.addModuleCleanup(cleanups.append, 'cleanup1')
+        cm = TestCM(cleanups, 42)
+        self.assertEqual(unittest.enterModuleContext(cm), 42)
+        unittest.addModuleCleanup(cleanups.append, 'cleanup2')
+
+        unittest.case.doModuleCleanups()
+        self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1'])
+
+    def test_enterModuleContext_arg_errors(self):
+        class TestableTest(unittest.TestCase):
+            def testNothing(self):
+                pass
+
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            unittest.enterModuleContext(LacksEnterAndExit())
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            unittest.enterModuleContext(LacksEnter())
+        with self.assertRaisesRegex(TypeError, 'the context manager'):
+            unittest.enterModuleContext(LacksExit())
+
+        self.assertEqual(unittest.case._module_cleanups, [])
+
 
 class Test_TextTestRunner(unittest.TestCase):
     """Tests for TextTestRunner."""
diff --git a/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst b/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst
new file mode 100644
index 0000000000000..8072afaf445c5
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2021-08-29-19-59-16.bpo-45046.eGq0NC.rst
@@ -0,0 +1,7 @@
+Add support for context managers in :mod:`unittest`: methods
+:meth:`~unittest.TestCase.enterContext` and
+:meth:`~unittest.TestCase.enterClassContext` of class
+:class:`~unittest.TestCase`, method
+:meth:`~unittest.IsolatedAsyncioTestCase.enterAsyncContext` of class
+:class:`~unittest.IsolatedAsyncioTestCase` and function
+:func:`unittest.enterModuleContext`.



More information about the Python-checkins mailing list