[Python-checkins] bpo-32314: Fix asyncio.run() to cancel running tasks on shutdown (#5262)

Yury Selivanov webhook-mailer at python.org
Sun Jan 21 14:57:02 EST 2018


https://github.com/python/cpython/commit/a4afcdfa55ddffa4b9ae3b0cf101628c7bff4102
commit: a4afcdfa55ddffa4b9ae3b0cf101628c7bff4102
branch: master
author: Yury Selivanov <yury at magic.io>
committer: GitHub <noreply at github.com>
date: 2018-01-21T14:56:59-05:00
summary:

bpo-32314: Fix asyncio.run() to cancel running tasks on shutdown (#5262)

files:
M Lib/asyncio/base_events.py
M Lib/asyncio/runners.py
M Lib/test/test_asyncio/test_runners.py
M Lib/test/test_asyncio/utils.py

diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index a10e706d504..ca9eee765e3 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -228,14 +228,9 @@ def __init__(self):
         self._coroutine_origin_tracking_enabled = False
         self._coroutine_origin_tracking_saved_depth = None
 
-        if hasattr(sys, 'get_asyncgen_hooks'):
-            # Python >= 3.6
-            # A weak set of all asynchronous generators that are
-            # being iterated by the loop.
-            self._asyncgens = weakref.WeakSet()
-        else:
-            self._asyncgens = None
-
+        # A weak set of all asynchronous generators that are
+        # being iterated by the loop.
+        self._asyncgens = weakref.WeakSet()
         # Set to True when `loop.shutdown_asyncgens` is called.
         self._asyncgens_shutdown_called = False
 
@@ -354,7 +349,7 @@ def _asyncgen_firstiter_hook(self, agen):
         """Shutdown all active asynchronous generators."""
         self._asyncgens_shutdown_called = True
 
-        if self._asyncgens is None or not len(self._asyncgens):
+        if not len(self._asyncgens):
             # If Python version is <3.6 or we don't have any asynchronous
             # generators alive.
             return
@@ -386,10 +381,10 @@ def run_forever(self):
                 'Cannot run the event loop while another loop is running')
         self._set_coroutine_origin_tracking(self._debug)
         self._thread_id = threading.get_ident()
-        if self._asyncgens is not None:
-            old_agen_hooks = sys.get_asyncgen_hooks()
-            sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
-                                   finalizer=self._asyncgen_finalizer_hook)
+
+        old_agen_hooks = sys.get_asyncgen_hooks()
+        sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
+                               finalizer=self._asyncgen_finalizer_hook)
         try:
             events._set_running_loop(self)
             while True:
@@ -401,8 +396,7 @@ def run_forever(self):
             self._thread_id = None
             events._set_running_loop(None)
             self._set_coroutine_origin_tracking(False)
-            if self._asyncgens is not None:
-                sys.set_asyncgen_hooks(*old_agen_hooks)
+            sys.set_asyncgen_hooks(*old_agen_hooks)
 
     def run_until_complete(self, future):
         """Run until the Future is done.
@@ -1374,6 +1368,7 @@ def call_exception_handler(self, context):
         - 'message': Error message;
         - 'exception' (optional): Exception object;
         - 'future' (optional): Future instance;
+        - 'task' (optional): Task instance;
         - 'handle' (optional): Handle instance;
         - 'protocol' (optional): Protocol instance;
         - 'transport' (optional): Transport instance;
diff --git a/Lib/asyncio/runners.py b/Lib/asyncio/runners.py
index 94d94097ab9..bb54b725278 100644
--- a/Lib/asyncio/runners.py
+++ b/Lib/asyncio/runners.py
@@ -2,6 +2,7 @@
 
 from . import coroutines
 from . import events
+from . import tasks
 
 
 def run(main, *, debug=False):
@@ -42,7 +43,31 @@ def run(main, *, debug=False):
         return loop.run_until_complete(main)
     finally:
         try:
+            _cancel_all_tasks(loop)
             loop.run_until_complete(loop.shutdown_asyncgens())
         finally:
             events.set_event_loop(None)
             loop.close()
+
+
+def _cancel_all_tasks(loop):
+    to_cancel = [task for task in tasks.all_tasks(loop)
+                 if not task.done()]
+    if not to_cancel:
+        return
+
+    for task in to_cancel:
+        task.cancel()
+
+    loop.run_until_complete(
+        tasks.gather(*to_cancel, loop=loop, return_exceptions=True))
+
+    for task in to_cancel:
+        if task.cancelled():
+            continue
+        if task.exception() is not None:
+            loop.call_exception_handler({
+                'message': 'unhandled exception during asyncio.run() shutdown',
+                'exception': task.exception(),
+                'task': task,
+            })
diff --git a/Lib/test/test_asyncio/test_runners.py b/Lib/test/test_asyncio/test_runners.py
index c52bd9443ea..3b58ddee443 100644
--- a/Lib/test/test_asyncio/test_runners.py
+++ b/Lib/test/test_asyncio/test_runners.py
@@ -2,6 +2,7 @@
 import unittest
 
 from unittest import mock
+from . import utils as test_utils
 
 
 class TestPolicy(asyncio.AbstractEventLoopPolicy):
@@ -98,3 +99,81 @@ def test_asyncio_run_from_running_loop(self):
         with self.assertRaisesRegex(RuntimeError,
                                     'cannot be called from a running'):
             asyncio.run(main())
+
+    def test_asyncio_run_cancels_hanging_tasks(self):
+        lo_task = None
+
+        async def leftover():
+            await asyncio.sleep(0.1)
+
+        async def main():
+            nonlocal lo_task
+            lo_task = asyncio.create_task(leftover())
+            return 123
+
+        self.assertEqual(asyncio.run(main()), 123)
+        self.assertTrue(lo_task.done())
+
+    def test_asyncio_run_reports_hanging_tasks_errors(self):
+        lo_task = None
+        call_exc_handler_mock = mock.Mock()
+
+        async def leftover():
+            try:
+                await asyncio.sleep(0.1)
+            except asyncio.CancelledError:
+                1 / 0
+
+        async def main():
+            loop = asyncio.get_running_loop()
+            loop.call_exception_handler = call_exc_handler_mock
+
+            nonlocal lo_task
+            lo_task = asyncio.create_task(leftover())
+            return 123
+
+        self.assertEqual(asyncio.run(main()), 123)
+        self.assertTrue(lo_task.done())
+
+        call_exc_handler_mock.assert_called_with({
+            'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
+            'task': lo_task,
+            'exception': test_utils.MockInstanceOf(ZeroDivisionError)
+        })
+
+    def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
+        spinner = None
+        lazyboy = None
+
+        class FancyExit(Exception):
+            pass
+
+        async def fidget():
+            while True:
+                yield 1
+                await asyncio.sleep(1)
+
+        async def spin():
+            nonlocal spinner
+            spinner = fidget()
+            try:
+                async for the_meaning_of_life in spinner:  # NoQA
+                    pass
+            except asyncio.CancelledError:
+                1 / 0
+
+        async def main():
+            loop = asyncio.get_running_loop()
+            loop.call_exception_handler = mock.Mock()
+
+            nonlocal lazyboy
+            lazyboy = asyncio.create_task(spin())
+            raise FancyExit
+
+        with self.assertRaises(FancyExit):
+            asyncio.run(main())
+
+        self.assertTrue(lazyboy.done())
+
+        self.assertIsNone(spinner.ag_frame)
+        self.assertFalse(spinner.ag_running)
diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py
index 6c809770b4a..f756ec9016f 100644
--- a/Lib/test/test_asyncio/utils.py
+++ b/Lib/test/test_asyncio/utils.py
@@ -485,6 +485,14 @@ def __eq__(self, other):
         return bool(re.search(str(self), other, re.S))
 
 
+class MockInstanceOf:
+    def __init__(self, type):
+        self._type = type
+
+    def __eq__(self, other):
+        return isinstance(other, self._type)
+
+
 def get_function_source(func):
     source = format_helpers._get_function_source(func)
     if source is None:



More information about the Python-checkins mailing list