[Python-checkins] cpython: asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(), call_at() and run_in_executor() now raise a TypeError if the callback is a coroutine function.

victor.stinner python-checkins at python.org
Tue Feb 11 11:37:20 CET 2014


http://hg.python.org/cpython/rev/3ba4742a6fde
changeset:   89145:3ba4742a6fde
user:        Victor Stinner <victor.stinner at gmail.com>
date:        Tue Feb 11 11:34:30 2014 +0100
summary:
  asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(),
call_at() and run_in_executor() now raise a TypeError if the callback is a
coroutine function.

files:
  Lib/asyncio/base_events.py                    |   6 +++
  Lib/asyncio/test_utils.py                     |   5 ++-
  Lib/test/test_asyncio/test_base_events.py     |  18 ++++++++++
  Lib/test/test_asyncio/test_proactor_events.py |   2 +-
  Lib/test/test_asyncio/test_selector_events.py |   9 ++--
  Lib/test/test_asyncio/test_tasks.py           |  12 ++---
  6 files changed, 39 insertions(+), 13 deletions(-)


diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -227,6 +227,8 @@
 
     def call_at(self, when, callback, *args):
         """Like call_later(), but uses an absolute time."""
+        if tasks.iscoroutinefunction(callback):
+            raise TypeError("coroutines cannot be used with call_at()")
         timer = events.TimerHandle(when, callback, args)
         heapq.heappush(self._scheduled, timer)
         return timer
@@ -241,6 +243,8 @@
         Any positional arguments after the callback will be passed to
         the callback when it is called.
         """
+        if tasks.iscoroutinefunction(callback):
+            raise TypeError("coroutines cannot be used with call_soon()")
         handle = events.Handle(callback, args)
         self._ready.append(handle)
         return handle
@@ -252,6 +256,8 @@
         return handle
 
     def run_in_executor(self, executor, callback, *args):
+        if tasks.iscoroutinefunction(callback):
+            raise TypeError("coroutines cannot be used with run_in_executor()")
         if isinstance(callback, events.Handle):
             assert not args
             assert not isinstance(callback, events.TimerHandle)
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py
--- a/Lib/asyncio/test_utils.py
+++ b/Lib/asyncio/test_utils.py
@@ -135,7 +135,7 @@
         if name.startswith('__') and name.endswith('__'):
             # skip magic names
             continue
-        dct[name] = unittest.mock.Mock(return_value=None)
+        dct[name] = MockCallback(return_value=None)
     return type('TestProtocol', (base,) + base.__bases__, dct)()
 
 
@@ -274,3 +274,6 @@
 
     def _write_to_self(self):
         pass
+
+def MockCallback(**kwargs):
+    return unittest.mock.Mock(spec=['__call__'], **kwargs)
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -567,6 +567,7 @@
 
         m_socket.getaddrinfo.return_value = [
             (2, 1, 6, '', ('127.0.0.1', 10100))]
+        m_socket.getaddrinfo._is_coroutine = False
         m_sock = m_socket.socket.return_value = unittest.mock.Mock()
         m_sock.bind.side_effect = Err
 
@@ -577,6 +578,7 @@
     @unittest.mock.patch('asyncio.base_events.socket')
     def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
         m_socket.getaddrinfo.return_value = []
+        m_socket.getaddrinfo._is_coroutine = False
 
         coro = self.loop.create_datagram_endpoint(
             MyDatagramProto, local_addr=('localhost', 0))
@@ -681,6 +683,22 @@
                                                 unittest.mock.ANY,
                                                 MyProto, sock, None, None)
 
+    def test_call_coroutine(self):
+        @asyncio.coroutine
+        def coroutine_function():
+            pass
+
+        with self.assertRaises(TypeError):
+            self.loop.call_soon(coroutine_function)
+        with self.assertRaises(TypeError):
+            self.loop.call_soon_threadsafe(coroutine_function)
+        with self.assertRaises(TypeError):
+            self.loop.call_later(60, coroutine_function)
+        with self.assertRaises(TypeError):
+            self.loop.call_at(self.loop.time() + 60, coroutine_function)
+        with self.assertRaises(TypeError):
+            self.loop.run_in_executor(None, coroutine_function)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
--- a/Lib/test/test_asyncio/test_proactor_events.py
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -402,7 +402,7 @@
             NotImplementedError, BaseProactorEventLoop, self.proactor)
 
     def test_make_socket_transport(self):
-        tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock())
+        tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
         self.assertIsInstance(tr, _ProactorSocketTransport)
 
     def test_loop_self_reading(self):
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -44,8 +44,8 @@
     def test_make_socket_transport(self):
         m = unittest.mock.Mock()
         self.loop.add_reader = unittest.mock.Mock()
-        self.assertIsInstance(
-            self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
+        transport = self.loop._make_socket_transport(m, asyncio.Protocol())
+        self.assertIsInstance(transport, _SelectorSocketTransport)
 
     @unittest.skipIf(ssl is None, 'No ssl module')
     def test_make_ssl_transport(self):
@@ -54,8 +54,9 @@
         self.loop.add_writer = unittest.mock.Mock()
         self.loop.remove_reader = unittest.mock.Mock()
         self.loop.remove_writer = unittest.mock.Mock()
-        self.assertIsInstance(
-            self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
+        waiter = asyncio.Future(loop=self.loop)
+        transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
+        self.assertIsInstance(transport, _SelectorSslTransport)
 
     @unittest.mock.patch('asyncio.selector_events.ssl', None)
     def test_make_ssl_transport_without_ssl_error(self):
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
--- a/Lib/test/test_asyncio/test_tasks.py
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -2,8 +2,6 @@
 
 import gc
 import unittest
-import unittest.mock
-from unittest.mock import Mock
 
 import asyncio
 from asyncio import test_utils
@@ -1358,7 +1356,7 @@
     def _check_success(self, **kwargs):
         a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
         fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         b.set_result(1)
         a.set_result(2)
@@ -1380,7 +1378,7 @@
     def test_one_exception(self):
         a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
         fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         exc = ZeroDivisionError()
         a.set_result(1)
@@ -1399,7 +1397,7 @@
         a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
         fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
                              return_exceptions=True)
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         exc = ZeroDivisionError()
         exc2 = RuntimeError()
@@ -1460,7 +1458,7 @@
     def test_one_cancellation(self):
         a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
         fut = asyncio.gather(a, b, c, d, e)
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         a.set_result(1)
         b.cancel()
@@ -1479,7 +1477,7 @@
         a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
                             for i in range(6)]
         fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         a.set_result(1)
         zde = ZeroDivisionError()

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list