[Python-checkins] bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269)

Lisa Roach webhook-mailer at python.org
Wed Nov 20 19:27:57 EST 2019


https://github.com/python/cpython/commit/046442d02bcc6e848e71e93e47f6cde9e279e993
commit: 046442d02bcc6e848e71e93e47f6cde9e279e993
branch: master
author: Jason Fried <fried at fb.com>
committer: Lisa Roach <lisaroach14 at gmail.com>
date: 2019-11-20T16:27:51-08:00
summary:

bpo-38857: AsyncMock fix for awaitable values and StopIteration fix [3.8] (GH-17269)

files:
A Misc/NEWS.d/next/Library/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst
A Misc/NEWS.d/next/Library/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst
M Doc/library/unittest.mock.rst
M Lib/unittest/mock.py
M Lib/unittest/test/testmock/testasync.py

diff --git a/Doc/library/unittest.mock.rst b/Doc/library/unittest.mock.rst
index 7faecff84f63f..e92f5545d3eb0 100644
--- a/Doc/library/unittest.mock.rst
+++ b/Doc/library/unittest.mock.rst
@@ -873,7 +873,7 @@ object::
     exception,
   - if ``side_effect`` is an iterable, the async function will return the
     next value of the iterable, however, if the sequence of result is
-    exhausted, ``StopIteration`` is raised immediately,
+    exhausted, ``StopAsyncIteration`` is raised immediately,
   - if ``side_effect`` is not defined, the async function will return the
     value defined by ``return_value``, hence, by default, the async function
     returns a new :class:`AsyncMock` object.
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index a48132c5b1cb5..b06e29cf01c95 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -1139,8 +1139,8 @@ def _increment_mock_call(self, /, *args, **kwargs):
             _new_parent = _new_parent._mock_new_parent
 
     def _execute_mock_call(self, /, *args, **kwargs):
-        # seperate from _increment_mock_call so that awaited functions are
-        # executed seperately from their call
+        # separate from _increment_mock_call so that awaited functions are
+        # executed separately from their call, also AsyncMock overrides this method
 
         effect = self.side_effect
         if effect is not None:
@@ -2136,29 +2136,45 @@ def __init__(self, /, *args, **kwargs):
         code_mock.co_flags = inspect.CO_COROUTINE
         self.__dict__['__code__'] = code_mock
 
-    async def _mock_call(self, /, *args, **kwargs):
-        try:
-            result = super()._mock_call(*args, **kwargs)
-        except (BaseException, StopIteration) as e:
-            side_effect = self.side_effect
-            if side_effect is not None and not callable(side_effect):
-                raise
-            return await _raise(e)
+    async def _execute_mock_call(self, /, *args, **kwargs):
+        # This is nearly just like super(), except for special handling
+        # of coroutines
 
         _call = self.call_args
+        self.await_count += 1
+        self.await_args = _call
+        self.await_args_list.append(_call)
 
-        async def proxy():
-            try:
-                if inspect.isawaitable(result):
-                    return await result
-                else:
-                    return result
-            finally:
-                self.await_count += 1
-                self.await_args = _call
-                self.await_args_list.append(_call)
+        effect = self.side_effect
+        if effect is not None:
+            if _is_exception(effect):
+                raise effect
+            elif not _callable(effect):
+                try:
+                    result = next(effect)
+                except StopIteration:
+                    # It is impossible to propagate a StopIteration
+                    # through coroutines because of PEP 479
+                    raise StopAsyncIteration
+                if _is_exception(result):
+                    raise result
+            elif asyncio.iscoroutinefunction(effect):
+                result = await effect(*args, **kwargs)
+            else:
+                result = effect(*args, **kwargs)
 
-        return await proxy()
+            if result is not DEFAULT:
+                return result
+
+        if self._mock_return_value is not DEFAULT:
+            return self.return_value
+
+        if self._mock_wraps is not None:
+            if asyncio.iscoroutinefunction(self._mock_wraps):
+                return await self._mock_wraps(*args, **kwargs)
+            return self._mock_wraps(*args, **kwargs)
+
+        return self.return_value
 
     def assert_awaited(self):
         """
@@ -2864,10 +2880,6 @@ def seal(mock):
             seal(m)
 
 
-async def _raise(exception):
-    raise exception
-
-
 class _AsyncIterator:
     """
     Wraps an iterator in an asynchronous iterator.
diff --git a/Lib/unittest/test/testmock/testasync.py b/Lib/unittest/test/testmock/testasync.py
index 0d2cdb0069ff7..149fd4deff102 100644
--- a/Lib/unittest/test/testmock/testasync.py
+++ b/Lib/unittest/test/testmock/testasync.py
@@ -358,42 +358,84 @@ def test_magicmock_lambda_spec(self):
             self.assertIsInstance(cm, MagicMock)
 
 
-class AsyncArguments(unittest.TestCase):
-    def test_add_return_value(self):
+class AsyncArguments(unittest.IsolatedAsyncioTestCase):
+    async def test_add_return_value(self):
         async def addition(self, var):
             return var + 1
 
         mock = AsyncMock(addition, return_value=10)
-        output = asyncio.run(mock(5))
+        output = await mock(5)
 
         self.assertEqual(output, 10)
 
-    def test_add_side_effect_exception(self):
+    async def test_add_side_effect_exception(self):
         async def addition(var):
             return var + 1
         mock = AsyncMock(addition, side_effect=Exception('err'))
         with self.assertRaises(Exception):
-            asyncio.run(mock(5))
+            await mock(5)
 
-    def test_add_side_effect_function(self):
+    async def test_add_side_effect_function(self):
         async def addition(var):
             return var + 1
         mock = AsyncMock(side_effect=addition)
-        result = asyncio.run(mock(5))
+        result = await mock(5)
         self.assertEqual(result, 6)
 
-    def test_add_side_effect_iterable(self):
+    async def test_add_side_effect_iterable(self):
         vals = [1, 2, 3]
         mock = AsyncMock(side_effect=vals)
         for item in vals:
-            self.assertEqual(item, asyncio.run(mock()))
-
-        with self.assertRaises(RuntimeError) as e:
-            asyncio.run(mock())
-            self.assertEqual(
-                e.exception,
-                RuntimeError('coroutine raised StopIteration')
-            )
+            self.assertEqual(item, await mock())
+
+        with self.assertRaises(StopAsyncIteration) as e:
+            await mock()
+
+    async def test_return_value_AsyncMock(self):
+        value = AsyncMock(return_value=10)
+        mock = AsyncMock(return_value=value)
+        result = await mock()
+        self.assertIs(result, value)
+
+    async def test_return_value_awaitable(self):
+        fut = asyncio.Future()
+        fut.set_result(None)
+        mock = AsyncMock(return_value=fut)
+        result = await mock()
+        self.assertIsInstance(result, asyncio.Future)
+
+    async def test_side_effect_awaitable_values(self):
+        fut = asyncio.Future()
+        fut.set_result(None)
+
+        mock = AsyncMock(side_effect=[fut])
+        result = await mock()
+        self.assertIsInstance(result, asyncio.Future)
+
+        with self.assertRaises(StopAsyncIteration):
+            await mock()
+
+    async def test_side_effect_is_AsyncMock(self):
+        effect = AsyncMock(return_value=10)
+        mock = AsyncMock(side_effect=effect)
+
+        result = await mock()
+        self.assertEqual(result, 10)
+
+    async def test_wraps_coroutine(self):
+        value = asyncio.Future()
+
+        ran = False
+        async def inner():
+            nonlocal ran
+            ran = True
+            return value
+
+        mock = AsyncMock(wraps=inner)
+        result = await mock()
+        self.assertEqual(result, value)
+        mock.assert_awaited()
+        self.assertTrue(ran)
 
 class AsyncMagicMethods(unittest.TestCase):
     def test_async_magic_methods_return_async_mocks(self):
diff --git a/Misc/NEWS.d/next/Library/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst b/Misc/NEWS.d/next/Library/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst
new file mode 100644
index 0000000000000..f28df2811fb78
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-11-19-16-28-25.bpo-38857.YPUkU9.rst
@@ -0,0 +1,4 @@
+AsyncMock fix for return values that are awaitable types.  This also covers
+side_effect iterable values that happen to be awaitable, and wraps
+callables that return an awaitable type. Previously, these awaitables were
+awaited instead of being returned as is.
diff --git a/Misc/NEWS.d/next/Library/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst b/Misc/NEWS.d/next/Library/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst
new file mode 100644
index 0000000000000..c059539a1de60
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2019-11-19-16-30-46.bpo-38859.AZUzL8.rst
@@ -0,0 +1,3 @@
+AsyncMock now raises StopAsyncIteration on the exhaustion of a side_effect
+iterable. Since PEP 479, it is impossible to raise a StopIteration exception
+from a coroutine.



More information about the Python-checkins mailing list