[Python-checkins] cpython: asyncio: Refactor ssl transport ready loop (Nikolay Kim).

guido.van.rossum python-checkins at python.org
Fri Nov 1 22:25:51 CET 2013


http://hg.python.org/cpython/rev/10a518b26ed1
changeset:   86828:10a518b26ed1
user:        Guido van Rossum <guido at dropbox.com>
date:        Fri Nov 01 14:18:02 2013 -0700
summary:
  asyncio: Refactor ssl transport ready loop (Nikolay Kim).

files:
  Lib/asyncio/selector_events.py                |   94 +++---
  Lib/test/test_asyncio/test_selector_events.py |  134 ++++++---
  2 files changed, 136 insertions(+), 92 deletions(-)


diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -286,7 +286,7 @@
                 err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
                 if err != 0:
                     # Jump to the except clause below.
-                    raise OSError(err, 'Connect call failed')
+                    raise OSError(err, 'Connect call failed %s' % (address,))
         except (BlockingIOError, InterruptedError):
             self.add_writer(fd, self._sock_connect, fut, True, sock, address)
         except Exception as exc:
@@ -413,7 +413,7 @@
             try:
                 self._protocol.pause_writing()
             except Exception:
-                tulip_log.exception('pause_writing() failed')
+                logger.exception('pause_writing() failed')
 
     def _maybe_resume_protocol(self):
         if (self._protocol_paused and
@@ -422,7 +422,7 @@
             try:
                 self._protocol.resume_writing()
             except Exception:
-                tulip_log.exception('resume_writing() failed')
+                logger.exception('resume_writing() failed')
 
     def set_write_buffer_limits(self, high=None, low=None):
         if high is None:
@@ -635,15 +635,16 @@
                            compression=self._sock.compression(),
                            )
 
-        self._loop.add_reader(self._sock_fd, self._on_ready)
-        self._loop.add_writer(self._sock_fd, self._on_ready)
+        self._read_wants_write = False
+        self._write_wants_read = False
+        self._loop.add_reader(self._sock_fd, self._read_ready)
         self._loop.call_soon(self._protocol.connection_made, self)
         if self._waiter is not None:
             self._loop.call_soon(self._waiter.set_result, None)
 
     def pause_reading(self):
         # XXX This is a bit icky, given the comment at the top of
-        # _on_ready().  Is it possible to evoke a deadlock?  I don't
+        # _read_ready().  Is it possible to evoke a deadlock?  I don't
         # know, although it doesn't look like it; write() will still
         # accept more data for the buffer and eventually the app will
         # call resume_reading() again, and things will flow again.
@@ -658,41 +659,55 @@
         self._paused = False
         if self._closing:
             return
-        self._loop.add_reader(self._sock_fd, self._on_ready)
+        self._loop.add_reader(self._sock_fd, self._read_ready)
 
-    def _on_ready(self):
-        # Because of renegotiations (?), there's no difference between
-        # readable and writable.  We just try both.  XXX This may be
-        # incorrect; we probably need to keep state about what we
-        # should do next.
+    def _read_ready(self):
+        if self._write_wants_read:
+            self._write_wants_read = False
+            self._write_ready()
 
-        # First try reading.
-        if not self._closing and not self._paused:
-            try:
-                data = self._sock.recv(self.max_size)
-            except (BlockingIOError, InterruptedError,
-                    ssl.SSLWantReadError, ssl.SSLWantWriteError):
-                pass
-            except Exception as exc:
-                self._fatal_error(exc)
+            if self._buffer:
+                self._loop.add_writer(self._sock_fd, self._write_ready)
+
+        try:
+            data = self._sock.recv(self.max_size)
+        except (BlockingIOError, InterruptedError, ssl.SSLWantReadError):
+            pass
+        except ssl.SSLWantWriteError:
+            self._read_wants_write = True
+            self._loop.remove_reader(self._sock_fd)
+            self._loop.add_writer(self._sock_fd, self._write_ready)
+        except Exception as exc:
+            self._fatal_error(exc)
+        else:
+            if data:
+                self._protocol.data_received(data)
             else:
-                if data:
-                    self._protocol.data_received(data)
-                else:
-                    try:
-                        self._protocol.eof_received()
-                    finally:
-                        self.close()
+                try:
+                    self._protocol.eof_received()
+                finally:
+                    self.close()
 
-        # Now try writing, if there's anything to write.
+    def _write_ready(self):
+        if self._read_wants_write:
+            self._read_wants_write = False
+            self._read_ready()
+
+            if not (self._paused or self._closing):
+                self._loop.add_reader(self._sock_fd, self._read_ready)
+
         if self._buffer:
             data = b''.join(self._buffer)
             self._buffer.clear()
             try:
                 n = self._sock.send(data)
             except (BlockingIOError, InterruptedError,
-                    ssl.SSLWantReadError, ssl.SSLWantWriteError):
+                    ssl.SSLWantWriteError):
                 n = 0
+            except ssl.SSLWantReadError:
+                n = 0
+                self._loop.remove_writer(self._sock_fd)
+                self._write_wants_read = True
             except Exception as exc:
                 self._loop.remove_writer(self._sock_fd)
                 self._fatal_error(exc)
@@ -701,11 +716,12 @@
             if n < len(data):
                 self._buffer.append(data[n:])
 
-            self._maybe_resume_protocol()  # May append to buffer.
+        self._maybe_resume_protocol()  # May append to buffer.
 
-        if self._closing and not self._buffer:
+        if not self._buffer:
             self._loop.remove_writer(self._sock_fd)
-            self._call_connection_lost(None)
+            if self._closing:
+                self._call_connection_lost(None)
 
     def write(self, data):
         assert isinstance(data, bytes), repr(type(data))
@@ -718,20 +734,16 @@
             self._conn_lost += 1
             return
 
-        # We could optimize, but the callback can do this for now.
+        if not self._buffer:
+            self._loop.add_writer(self._sock_fd, self._write_ready)
+
+        # Add it to the buffer.
         self._buffer.append(data)
         self._maybe_pause_protocol()
 
     def can_write_eof(self):
         return False
 
-    def close(self):
-        if self._closing:
-            return
-        self._closing = True
-        self._conn_lost += 1
-        self._loop.remove_reader(self._sock_fd)
-
 
 class _SelectorDatagramTransport(_SelectorTransport):
 
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -1003,8 +1003,7 @@
             self.loop, self.sock, self.protocol, self.sslcontext,
             waiter=waiter)
         self.assertTrue(self.sslsock.do_handshake.called)
-        self.loop.assert_reader(1, tr._on_ready)
-        self.loop.assert_writer(1, tr._on_ready)
+        self.loop.assert_reader(1, tr._read_ready)
         test_utils.run_briefly(self.loop)
         self.assertIsNone(waiter.result())
 
@@ -1047,13 +1046,13 @@
     def test_pause_resume_reading(self):
         tr = self._make_one()
         self.assertFalse(tr._paused)
-        self.loop.assert_reader(1, tr._on_ready)
+        self.loop.assert_reader(1, tr._read_ready)
         tr.pause_reading()
         self.assertTrue(tr._paused)
         self.assertFalse(1 in self.loop.readers)
         tr.resume_reading()
         self.assertFalse(tr._paused)
-        self.loop.assert_reader(1, tr._on_ready)
+        self.loop.assert_reader(1, tr._read_ready)
 
     def test_write_no_data(self):
         transport = self._make_one()
@@ -1084,140 +1083,173 @@
         transport.write(b'data')
         m_log.warning.assert_called_with('socket.send() raised exception.')
 
-    def test_on_ready_recv(self):
+    def test_read_ready_recv(self):
         self.sslsock.recv.return_value = b'data'
         transport = self._make_one()
-        transport._on_ready()
+        transport._read_ready()
         self.assertTrue(self.sslsock.recv.called)
         self.assertEqual((b'data',), self.protocol.data_received.call_args[0])
 
-    def test_on_ready_recv_eof(self):
+    def test_read_ready_write_wants_read(self):
+        self.loop.add_writer = unittest.mock.Mock()
+        self.sslsock.recv.side_effect = BlockingIOError
+        transport = self._make_one()
+        transport._write_wants_read = True
+        transport._write_ready = unittest.mock.Mock()
+        transport._buffer.append(b'data')
+        transport._read_ready()
+
+        self.assertFalse(transport._write_wants_read)
+        transport._write_ready.assert_called_with()
+        self.loop.add_writer.assert_called_with(
+            transport._sock_fd, transport._write_ready)
+
+    def test_read_ready_recv_eof(self):
         self.sslsock.recv.return_value = b''
         transport = self._make_one()
         transport.close = unittest.mock.Mock()
-        transport._on_ready()
+        transport._read_ready()
         transport.close.assert_called_with()
         self.protocol.eof_received.assert_called_with()
 
-    def test_on_ready_recv_conn_reset(self):
+    def test_read_ready_recv_conn_reset(self):
         err = self.sslsock.recv.side_effect = ConnectionResetError()
         transport = self._make_one()
         transport._force_close = unittest.mock.Mock()
-        transport._on_ready()
+        transport._read_ready()
         transport._force_close.assert_called_with(err)
 
-    def test_on_ready_recv_retry(self):
+    def test_read_ready_recv_retry(self):
         self.sslsock.recv.side_effect = ssl.SSLWantReadError
         transport = self._make_one()
-        transport._on_ready()
+        transport._read_ready()
         self.assertTrue(self.sslsock.recv.called)
         self.assertFalse(self.protocol.data_received.called)
 
-        self.sslsock.recv.side_effect = ssl.SSLWantWriteError
-        transport._on_ready()
-        self.assertFalse(self.protocol.data_received.called)
-
         self.sslsock.recv.side_effect = BlockingIOError
-        transport._on_ready()
+        transport._read_ready()
         self.assertFalse(self.protocol.data_received.called)
 
         self.sslsock.recv.side_effect = InterruptedError
-        transport._on_ready()
+        transport._read_ready()
         self.assertFalse(self.protocol.data_received.called)
 
-    def test_on_ready_recv_exc(self):
+    def test_read_ready_recv_write(self):
+        self.loop.remove_reader = unittest.mock.Mock()
+        self.loop.add_writer = unittest.mock.Mock()
+        self.sslsock.recv.side_effect = ssl.SSLWantWriteError
+        transport = self._make_one()
+        transport._read_ready()
+        self.assertFalse(self.protocol.data_received.called)
+        self.assertTrue(transport._read_wants_write)
+
+        self.loop.remove_reader.assert_called_with(transport._sock_fd)
+        self.loop.add_writer.assert_called_with(
+            transport._sock_fd, transport._write_ready)
+
+    def test_read_ready_recv_exc(self):
         err = self.sslsock.recv.side_effect = OSError()
         transport = self._make_one()
         transport._fatal_error = unittest.mock.Mock()
-        transport._on_ready()
+        transport._read_ready()
         transport._fatal_error.assert_called_with(err)
 
-    def test_on_ready_send(self):
-        self.sslsock.recv.side_effect = ssl.SSLWantReadError
+    def test_write_ready_send(self):
         self.sslsock.send.return_value = 4
         transport = self._make_one()
         transport._buffer = collections.deque([b'data'])
-        transport._on_ready()
+        transport._write_ready()
         self.assertEqual(collections.deque(), transport._buffer)
         self.assertTrue(self.sslsock.send.called)
 
-    def test_on_ready_send_none(self):
-        self.sslsock.recv.side_effect = ssl.SSLWantReadError
+    def test_write_ready_send_none(self):
         self.sslsock.send.return_value = 0
         transport = self._make_one()
         transport._buffer = collections.deque([b'data1', b'data2'])
-        transport._on_ready()
+        transport._write_ready()
         self.assertTrue(self.sslsock.send.called)
         self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
 
-    def test_on_ready_send_partial(self):
-        self.sslsock.recv.side_effect = ssl.SSLWantReadError
+    def test_write_ready_send_partial(self):
         self.sslsock.send.return_value = 2
         transport = self._make_one()
         transport._buffer = collections.deque([b'data1', b'data2'])
-        transport._on_ready()
+        transport._write_ready()
         self.assertTrue(self.sslsock.send.called)
         self.assertEqual(collections.deque([b'ta1data2']), transport._buffer)
 
-    def test_on_ready_send_closing_partial(self):
-        self.sslsock.recv.side_effect = ssl.SSLWantReadError
+    def test_write_ready_send_closing_partial(self):
         self.sslsock.send.return_value = 2
         transport = self._make_one()
         transport._buffer = collections.deque([b'data1', b'data2'])
-        transport._on_ready()
+        transport._write_ready()
         self.assertTrue(self.sslsock.send.called)
         self.assertFalse(self.sslsock.close.called)
 
-    def test_on_ready_send_closing(self):
-        self.sslsock.recv.side_effect = ssl.SSLWantReadError
+    def test_write_ready_send_closing(self):
         self.sslsock.send.return_value = 4
         transport = self._make_one()
         transport.close()
         transport._buffer = collections.deque([b'data'])
-        transport._on_ready()
+        transport._write_ready()
         self.assertFalse(self.loop.writers)
         self.protocol.connection_lost.assert_called_with(None)
 
-    def test_on_ready_send_closing_empty_buffer(self):
-        self.sslsock.recv.side_effect = ssl.SSLWantReadError
+    def test_write_ready_send_closing_empty_buffer(self):
         self.sslsock.send.return_value = 4
         transport = self._make_one()
         transport.close()
         transport._buffer = collections.deque()
-        transport._on_ready()
+        transport._write_ready()
         self.assertFalse(self.loop.writers)
         self.protocol.connection_lost.assert_called_with(None)
 
-    def test_on_ready_send_retry(self):
-        self.sslsock.recv.side_effect = ssl.SSLWantReadError
-
+    def test_write_ready_send_retry(self):
         transport = self._make_one()
         transport._buffer = collections.deque([b'data'])
 
-        self.sslsock.send.side_effect = ssl.SSLWantReadError
-        transport._on_ready()
-        self.assertTrue(self.sslsock.send.called)
-        self.assertEqual(collections.deque([b'data']), transport._buffer)
-
         self.sslsock.send.side_effect = ssl.SSLWantWriteError
-        transport._on_ready()
+        transport._write_ready()
         self.assertEqual(collections.deque([b'data']), transport._buffer)
 
         self.sslsock.send.side_effect = BlockingIOError()
-        transport._on_ready()
+        transport._write_ready()
         self.assertEqual(collections.deque([b'data']), transport._buffer)
 
-    def test_on_ready_send_exc(self):
-        self.sslsock.recv.side_effect = ssl.SSLWantReadError
+    def test_write_ready_send_read(self):
+        transport = self._make_one()
+        transport._buffer = collections.deque([b'data'])
+
+        self.loop.remove_writer = unittest.mock.Mock()
+        self.sslsock.send.side_effect = ssl.SSLWantReadError
+        transport._write_ready()
+        self.assertFalse(self.protocol.data_received.called)
+        self.assertTrue(transport._write_wants_read)
+        self.loop.remove_writer.assert_called_with(transport._sock_fd)
+
+    def test_write_ready_send_exc(self):
         err = self.sslsock.send.side_effect = OSError()
 
         transport = self._make_one()
         transport._buffer = collections.deque([b'data'])
         transport._fatal_error = unittest.mock.Mock()
-        transport._on_ready()
+        transport._write_ready()
         transport._fatal_error.assert_called_with(err)
         self.assertEqual(collections.deque(), transport._buffer)
 
+    def test_write_ready_read_wants_write(self):
+        self.loop.add_reader = unittest.mock.Mock()
+        self.sslsock.send.side_effect = BlockingIOError
+        transport = self._make_one()
+        transport._read_wants_write = True
+        transport._read_ready = unittest.mock.Mock()
+        transport._write_ready()
+
+        self.assertFalse(transport._read_wants_write)
+        transport._read_ready.assert_called_with()
+        self.loop.add_reader.assert_called_with(
+            transport._sock_fd, transport._read_ready)
+
     def test_write_eof(self):
         tr = self._make_one()
         self.assertFalse(tr.can_write_eof())

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list