[Python-checkins] cpython (merge 3.4 -> default): (Merge 3.4) asyncio, Tulip issue 205: Fix a race condition in

victor.stinner python-checkins at python.org
Sun Aug 31 15:10:23 CEST 2014


http://hg.python.org/cpython/rev/28cbbe2ce104
changeset:   92280:28cbbe2ce104
parent:      92278:9138d60db0e4
parent:      92279:ad67f66a5f3c
user:        Victor Stinner <victor.stinner at gmail.com>
date:        Sun Aug 31 15:08:21 2014 +0200
summary:
  (Merge 3.4) asyncio, Tulip issue 205: Fix a race condition in
BaseSelectorEventLoop.sock_connect()

There is a race condition in create_connection() when it is used with wait_for()
to implement a timeout. sock_connect() registers the file descriptor of the socket
to be notified of write events (if connect() raises BlockingIOError). When
create_connection() is cancelled with a TimeoutError, the sock_connect() coroutine
gets the exception, but it doesn't unregister the file descriptor for write
events. create_connection() gets the TimeoutError and closes the socket.

If you call create_connection() again, the new socket will likely get the same
file descriptor, which is still registered in the selector. When sock_connect()
calls add_writer(), it tries to modify the existing entry instead of creating a new one.

This issue was originally reported in the Trollius project, but the bug actually
comes from Tulip (Trollius is based on Tulip):
https://bitbucket.org/enovance/trollius/issue/15/after-timeouterror-on-wait_for

This change fixes the race condition. It also makes sock_connect() more
reliable (and portable) if sock.connect() raises an InterruptedError.

files:
  Lib/asyncio/selector_events.py                |  44 ++++-
  Lib/test/test_asyncio/test_selector_events.py |  74 +++++++--
  2 files changed, 83 insertions(+), 35 deletions(-)


diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -8,6 +8,7 @@
 
 import collections
 import errno
+import functools
 import socket
 try:
     import ssl
@@ -345,26 +346,43 @@
         except ValueError as err:
             fut.set_exception(err)
         else:
-            self._sock_connect(fut, False, sock, address)
+            self._sock_connect(fut, sock, address)
         return fut
 
-    def _sock_connect(self, fut, registered, sock, address):
+    def _sock_connect(self, fut, sock, address):
         fd = sock.fileno()
-        if registered:
-            self.remove_writer(fd)
+        try:
+            while True:
+                try:
+                    sock.connect(address)
+                except InterruptedError:
+                    continue
+                else:
+                    break
+        except BlockingIOError:
+            fut.add_done_callback(functools.partial(self._sock_connect_done,
+                                                    sock))
+            self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
+        except Exception as exc:
+            fut.set_exception(exc)
+        else:
+            fut.set_result(None)
+
+    def _sock_connect_done(self, sock, fut):
+        self.remove_writer(sock.fileno())
+
+    def _sock_connect_cb(self, fut, sock, address):
         if fut.cancelled():
             return
+
         try:
-            if not registered:
-                # First time around.
-                sock.connect(address)
-            else:
-                err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
-                if err != 0:
-                    # Jump to the except clause below.
-                    raise OSError(err, 'Connect call failed %s' % (address,))
+            err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+            if err != 0:
+                # Jump to any except clause below.
+                raise OSError(err, 'Connect call failed %s' % (address,))
         except (BlockingIOError, InterruptedError):
-            self.add_writer(fd, self._sock_connect, fut, True, sock, address)
+            # socket is still registered, the callback will be retried later
+            pass
         except Exception as exc:
             fut.set_exception(exc)
         else:
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -40,8 +40,9 @@
 class BaseSelectorEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
-        selector = mock.Mock()
-        self.loop = TestBaseSelectorEventLoop(selector)
+        self.selector = mock.Mock()
+        self.selector.select.return_value = []
+        self.loop = TestBaseSelectorEventLoop(self.selector)
         self.set_event_loop(self.loop, cleanup=False)
 
     def test_make_socket_transport(self):
@@ -303,63 +304,92 @@
         f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
         self.assertIsInstance(f, asyncio.Future)
         self.assertEqual(
-            (f, False, sock, ('127.0.0.1', 8080)),
+            (f, sock, ('127.0.0.1', 8080)),
             self.loop._sock_connect.call_args[0])
 
+    def test_sock_connect_timeout(self):
+        # Tulip issue #205: sock_connect() must unregister the socket on
+        # timeout error
+
+        # prepare mocks
+        self.loop.add_writer = mock.Mock()
+        self.loop.remove_writer = mock.Mock()
+        sock = test_utils.mock_nonblocking_socket()
+        sock.connect.side_effect = BlockingIOError
+
+        # first call to sock_connect() registers the socket
+        fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
+        self.assertTrue(sock.connect.called)
+        self.assertTrue(self.loop.add_writer.called)
+        self.assertEqual(len(fut._callbacks), 1)
+
+        # on timeout, the socket must be unregistered
+        sock.connect.reset_mock()
+        fut.set_exception(asyncio.TimeoutError)
+        with self.assertRaises(asyncio.TimeoutError):
+            self.loop.run_until_complete(fut)
+        self.assertTrue(self.loop.remove_writer.called)
+
     def test__sock_connect(self):
         f = asyncio.Future(loop=self.loop)
 
         sock = mock.Mock()
         sock.fileno.return_value = 10
 
-        self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+        self.loop._sock_connect(f, sock, ('127.0.0.1', 8080))
         self.assertTrue(f.done())
         self.assertIsNone(f.result())
         self.assertTrue(sock.connect.called)
 
-    def test__sock_connect_canceled_fut(self):
+    def test__sock_connect_cb_cancelled_fut(self):
         sock = mock.Mock()
+        self.loop.remove_writer = mock.Mock()
 
         f = asyncio.Future(loop=self.loop)
         f.cancel()
 
-        self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
-        self.assertFalse(sock.connect.called)
+        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
+        self.assertFalse(sock.getsockopt.called)
 
-    def test__sock_connect_unregister(self):
+    def test__sock_connect_writer(self):
+        # check that the fd is registered and then unregistered
+        self.loop._process_events = mock.Mock()
+        self.loop.add_writer = mock.Mock()
+        self.loop.remove_writer = mock.Mock()
+
         sock = mock.Mock()
         sock.fileno.return_value = 10
+        sock.connect.side_effect = BlockingIOError
+        sock.getsockopt.return_value = 0
+        address = ('127.0.0.1', 8080)
 
         f = asyncio.Future(loop=self.loop)
-        f.cancel()
+        self.loop._sock_connect(f, sock, address)
+        self.assertTrue(self.loop.add_writer.called)
+        self.assertEqual(10, self.loop.add_writer.call_args[0][0])
 
-        self.loop.remove_writer = mock.Mock()
-        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+        self.loop._sock_connect_cb(f, sock, address)
+        # need to run the event loop to execute _sock_connect_done() callback
+        self.loop.run_until_complete(f)
         self.assertEqual((10,), self.loop.remove_writer.call_args[0])
 
-    def test__sock_connect_tryagain(self):
+    def test__sock_connect_cb_tryagain(self):
         f = asyncio.Future(loop=self.loop)
         sock = mock.Mock()
         sock.fileno.return_value = 10
         sock.getsockopt.return_value = errno.EAGAIN
 
-        self.loop.add_writer = mock.Mock()
-        self.loop.remove_writer = mock.Mock()
+        # check that the exception is handled
+        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
 
-        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
-        self.assertEqual(
-            (10, self.loop._sock_connect, f,
-             True, sock, ('127.0.0.1', 8080)),
-            self.loop.add_writer.call_args[0])
-
-    def test__sock_connect_exception(self):
+    def test__sock_connect_cb_exception(self):
         f = asyncio.Future(loop=self.loop)
         sock = mock.Mock()
         sock.fileno.return_value = 10
         sock.getsockopt.return_value = errno.ENOTCONN
 
         self.loop.remove_writer = mock.Mock()
-        self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+        self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080))
         self.assertIsInstance(f.exception(), OSError)
 
     def test_sock_accept(self):

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list