[Python-checkins] cpython: asyncio: Change write buffer use to avoid O(N**2). Make write()/sendto() accept bytearray/memoryview too.

guido.van.rossum python-checkins at python.org
Wed Nov 27 23:12:55 CET 2013


http://hg.python.org/cpython/rev/80e0040d910c
changeset:   87617:80e0040d910c
user:        Guido van Rossum <guido at python.org>
date:        Wed Nov 27 14:12:48 2013 -0800
summary:
  asyncio: Change write buffer use to avoid O(N**2). Make write()/sendto() accept bytearray/memoryview too. Replace some asserts with proper exceptions.

files:
  Lib/asyncio/selector_events.py                |   82 ++-
  Lib/test/test_asyncio/test_selector_events.py |  203 +++++++--
  2 files changed, 207 insertions(+), 78 deletions(-)


diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -340,6 +340,8 @@
 
     max_size = 256 * 1024  # Buffer size passed to recv().
 
+    _buffer_factory = bytearray  # Constructs initial value for self._buffer.
+
     def __init__(self, loop, sock, protocol, extra, server=None):
         super().__init__(extra)
         self._extra['socket'] = sock
@@ -354,7 +356,7 @@
         self._sock_fd = sock.fileno()
         self._protocol = protocol
         self._server = server
-        self._buffer = collections.deque()
+        self._buffer = self._buffer_factory()
         self._conn_lost = 0  # Set when call to connection_lost scheduled.
         self._closing = False  # Set when close() called.
         self._protocol_paused = False
@@ -433,12 +435,14 @@
                 high = 4*low
         if low is None:
             low = high // 4
-        assert 0 <= low <= high, repr((low, high))
+        if not high >= low >= 0:
+            raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
+                             (high, low))
         self._high_water = high
         self._low_water = low
 
     def get_write_buffer_size(self):
-        return sum(len(data) for data in self._buffer)
+        return len(self._buffer)
 
 
 class _SelectorSocketTransport(_SelectorTransport):
@@ -455,13 +459,16 @@
             self._loop.call_soon(waiter.set_result, None)
 
     def pause_reading(self):
-        assert not self._closing, 'Cannot pause_reading() when closing'
-        assert not self._paused, 'Already paused'
+        if self._closing:
+            raise RuntimeError('Cannot pause_reading() when closing')
+        if self._paused:
+            raise RuntimeError('Already paused')
         self._paused = True
         self._loop.remove_reader(self._sock_fd)
 
     def resume_reading(self):
-        assert self._paused, 'Not paused'
+        if not self._paused:
+            raise RuntimeError('Not paused')
         self._paused = False
         if self._closing:
             return
@@ -488,8 +495,11 @@
                     self.close()
 
     def write(self, data):
-        assert isinstance(data, bytes), repr(type(data))
-        assert not self._eof, 'Cannot call write() after write_eof()'
+        if not isinstance(data, (bytes, bytearray, memoryview)):
+            raise TypeError('data argument must be byte-ish (%r)',
+                            type(data))
+        if self._eof:
+            raise RuntimeError('Cannot call write() after write_eof()')
         if not data:
             return
 
@@ -516,25 +526,23 @@
             self._loop.add_writer(self._sock_fd, self._write_ready)
 
         # Add it to the buffer.
-        self._buffer.append(data)
+        self._buffer.extend(data)
         self._maybe_pause_protocol()
 
     def _write_ready(self):
-        data = b''.join(self._buffer)
-        assert data, 'Data should not be empty'
+        assert self._buffer, 'Data should not be empty'
 
-        self._buffer.clear()  # Optimistically; may have to put it back later.
         try:
-            n = self._sock.send(data)
+            n = self._sock.send(self._buffer)
         except (BlockingIOError, InterruptedError):
-            self._buffer.append(data)  # Still need to write this.
+            pass
         except Exception as exc:
             self._loop.remove_writer(self._sock_fd)
+            self._buffer.clear()
             self._fatal_error(exc)
         else:
-            data = data[n:]
-            if data:
-                self._buffer.append(data)  # Still need to write this.
+            if n:
+                del self._buffer[:n]
             self._maybe_resume_protocol()  # May append to buffer.
             if not self._buffer:
                 self._loop.remove_writer(self._sock_fd)
@@ -556,6 +564,8 @@
 
 class _SelectorSslTransport(_SelectorTransport):
 
+    _buffer_factory = bytearray
+
     def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
                  server_side=False, server_hostname=None,
                  extra=None, server=None):
@@ -661,13 +671,16 @@
         # accept more data for the buffer and eventually the app will
         # call resume_reading() again, and things will flow again.
 
-        assert not self._closing, 'Cannot pause_reading() when closing'
-        assert not self._paused, 'Already paused'
+        if self._closing:
+            raise RuntimeError('Cannot pause_reading() when closing')
+        if self._paused:
+            raise RuntimeError('Already paused')
         self._paused = True
         self._loop.remove_reader(self._sock_fd)
 
     def resume_reading(self):
-        assert self._paused, 'Not paused'
+        if not self._paused:
+            raise ('Not paused')
         self._paused = False
         if self._closing:
             return
@@ -712,10 +725,8 @@
                 self._loop.add_reader(self._sock_fd, self._read_ready)
 
         if self._buffer:
-            data = b''.join(self._buffer)
-            self._buffer.clear()
             try:
-                n = self._sock.send(data)
+                n = self._sock.send(self._buffer)
             except (BlockingIOError, InterruptedError,
                     ssl.SSLWantWriteError):
                 n = 0
@@ -725,11 +736,12 @@
                 self._write_wants_read = True
             except Exception as exc:
                 self._loop.remove_writer(self._sock_fd)
+                self._buffer.clear()
                 self._fatal_error(exc)
                 return
 
-            if n < len(data):
-                self._buffer.append(data[n:])
+            if n:
+                del self._buffer[:n]
 
         self._maybe_resume_protocol()  # May append to buffer.
 
@@ -739,7 +751,9 @@
                 self._call_connection_lost(None)
 
     def write(self, data):
-        assert isinstance(data, bytes), repr(type(data))
+        if not isinstance(data, (bytes, bytearray, memoryview)):
+            raise TypeError('data argument must be byte-ish (%r)',
+                            type(data))
         if not data:
             return
 
@@ -753,7 +767,7 @@
             self._loop.add_writer(self._sock_fd, self._write_ready)
 
         # Add it to the buffer.
-        self._buffer.append(data)
+        self._buffer.extend(data)
         self._maybe_pause_protocol()
 
     def can_write_eof(self):
@@ -762,6 +776,8 @@
 
 class _SelectorDatagramTransport(_SelectorTransport):
 
+    _buffer_factory = collections.deque
+
     def __init__(self, loop, sock, protocol, address=None, extra=None):
         super().__init__(loop, sock, protocol, extra)
         self._address = address
@@ -784,12 +800,15 @@
             self._protocol.datagram_received(data, addr)
 
     def sendto(self, data, addr=None):
-        assert isinstance(data, bytes), repr(type(data))
+        if not isinstance(data, (bytes, bytearray, memoryview)):
+            raise TypeError('data argument must be byte-ish (%r)',
+                            type(data))
         if not data:
             return
 
-        if self._address:
-            assert addr in (None, self._address)
+        if self._address and addr not in (None, self._address):
+            raise ValueError('Invalid address: must be None or %s' %
+                             (self._address,))
 
         if self._conn_lost and self._address:
             if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
@@ -814,7 +833,8 @@
                 self._fatal_error(exc)
                 return
 
-        self._buffer.append((data, addr))
+        # Ensure that what we buffer is immutable.
+        self._buffer.append((bytes(data), addr))
         self._maybe_pause_protocol()
 
     def _sendto_ready(self):
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
--- a/Lib/test/test_asyncio/test_selector_events.py
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -32,6 +32,10 @@
         self._internal_fds += 1
 
 
+def list_to_buffer(l=()):
+    return bytearray().join(l)
+
+
 class BaseSelectorEventLoopTests(unittest.TestCase):
 
     def setUp(self):
@@ -613,7 +617,7 @@
 
     def test_close_write_buffer(self):
         tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
-        tr._buffer.append(b'data')
+        tr._buffer.extend(b'data')
         tr.close()
 
         self.assertFalse(self.loop.readers)
@@ -622,13 +626,13 @@
 
     def test_force_close(self):
         tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
-        tr._buffer.append(b'1')
+        tr._buffer.extend(b'1')
         self.loop.add_reader(7, unittest.mock.sentinel)
         self.loop.add_writer(7, unittest.mock.sentinel)
         tr._force_close(None)
 
         self.assertTrue(tr._closing)
-        self.assertEqual(tr._buffer, collections.deque())
+        self.assertEqual(tr._buffer, list_to_buffer())
         self.assertFalse(self.loop.readers)
         self.assertFalse(self.loop.writers)
 
@@ -783,21 +787,40 @@
         transport.write(data)
         self.sock.send.assert_called_with(data)
 
+    def test_write_bytearray(self):
+        data = bytearray(b'data')
+        self.sock.send.return_value = len(data)
+
+        transport = _SelectorSocketTransport(
+            self.loop, self.sock, self.protocol)
+        transport.write(data)
+        self.sock.send.assert_called_with(data)
+        self.assertEqual(data, bytearray(b'data'))  # Hasn't been mutated.
+
+    def test_write_memoryview(self):
+        data = memoryview(b'data')
+        self.sock.send.return_value = len(data)
+
+        transport = _SelectorSocketTransport(
+            self.loop, self.sock, self.protocol)
+        transport.write(data)
+        self.sock.send.assert_called_with(data)
+
     def test_write_no_data(self):
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
-        transport._buffer.append(b'data')
+        transport._buffer.extend(b'data')
         transport.write(b'')
         self.assertFalse(self.sock.send.called)
-        self.assertEqual(collections.deque([b'data']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
     def test_write_buffer(self):
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
-        transport._buffer.append(b'data1')
+        transport._buffer.extend(b'data1')
         transport.write(b'data2')
         self.assertFalse(self.sock.send.called)
-        self.assertEqual(collections.deque([b'data1', b'data2']),
+        self.assertEqual(list_to_buffer([b'data1', b'data2']),
                          transport._buffer)
 
     def test_write_partial(self):
@@ -809,7 +832,30 @@
         transport.write(data)
 
         self.loop.assert_writer(7, transport._write_ready)
-        self.assertEqual(collections.deque([b'ta']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
+
+    def test_write_partial_bytearray(self):
+        data = bytearray(b'data')
+        self.sock.send.return_value = 2
+
+        transport = _SelectorSocketTransport(
+            self.loop, self.sock, self.protocol)
+        transport.write(data)
+
+        self.loop.assert_writer(7, transport._write_ready)
+        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
+        self.assertEqual(data, bytearray(b'data'))  # Hasn't been mutated.
+
+    def test_write_partial_memoryview(self):
+        data = memoryview(b'data')
+        self.sock.send.return_value = 2
+
+        transport = _SelectorSocketTransport(
+            self.loop, self.sock, self.protocol)
+        transport.write(data)
+
+        self.loop.assert_writer(7, transport._write_ready)
+        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
 
     def test_write_partial_none(self):
         data = b'data'
@@ -821,7 +867,7 @@
         transport.write(data)
 
         self.loop.assert_writer(7, transport._write_ready)
-        self.assertEqual(collections.deque([b'data']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
     def test_write_tryagain(self):
         self.sock.send.side_effect = BlockingIOError
@@ -832,7 +878,7 @@
         transport.write(data)
 
         self.loop.assert_writer(7, transport._write_ready)
-        self.assertEqual(collections.deque([b'data']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
     @unittest.mock.patch('asyncio.selector_events.logger')
     def test_write_exception(self, m_log):
@@ -859,7 +905,7 @@
     def test_write_str(self):
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
-        self.assertRaises(AssertionError, transport.write, 'str')
+        self.assertRaises(TypeError, transport.write, 'str')
 
     def test_write_closing(self):
         transport = _SelectorSocketTransport(
@@ -875,11 +921,10 @@
 
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
-        transport._buffer.append(data)
+        transport._buffer.extend(data)
         self.loop.add_writer(7, transport._write_ready)
         transport._write_ready()
         self.assertTrue(self.sock.send.called)
-        self.assertEqual(self.sock.send.call_args[0], (data,))
         self.assertFalse(self.loop.writers)
 
     def test_write_ready_closing(self):
@@ -889,10 +934,10 @@
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
         transport._closing = True
-        transport._buffer.append(data)
+        transport._buffer.extend(data)
         self.loop.add_writer(7, transport._write_ready)
         transport._write_ready()
-        self.sock.send.assert_called_with(data)
+        self.assertTrue(self.sock.send.called)
         self.assertFalse(self.loop.writers)
         self.sock.close.assert_called_with()
         self.protocol.connection_lost.assert_called_with(None)
@@ -900,6 +945,7 @@
     def test_write_ready_no_data(self):
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
+        # This is an internal error.
         self.assertRaises(AssertionError, transport._write_ready)
 
     def test_write_ready_partial(self):
@@ -908,11 +954,11 @@
 
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
-        transport._buffer.append(data)
+        transport._buffer.extend(data)
         self.loop.add_writer(7, transport._write_ready)
         transport._write_ready()
         self.loop.assert_writer(7, transport._write_ready)
-        self.assertEqual(collections.deque([b'ta']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'ta']), transport._buffer)
 
     def test_write_ready_partial_none(self):
         data = b'data'
@@ -920,23 +966,23 @@
 
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
-        transport._buffer.append(data)
+        transport._buffer.extend(data)
         self.loop.add_writer(7, transport._write_ready)
         transport._write_ready()
         self.loop.assert_writer(7, transport._write_ready)
-        self.assertEqual(collections.deque([b'data']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
     def test_write_ready_tryagain(self):
         self.sock.send.side_effect = BlockingIOError
 
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
-        transport._buffer = collections.deque([b'data1', b'data2'])
+        transport._buffer = list_to_buffer([b'data1', b'data2'])
         self.loop.add_writer(7, transport._write_ready)
         transport._write_ready()
 
         self.loop.assert_writer(7, transport._write_ready)
-        self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
 
     def test_write_ready_exception(self):
         err = self.sock.send.side_effect = OSError()
@@ -944,7 +990,7 @@
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
         transport._fatal_error = unittest.mock.Mock()
-        transport._buffer.append(b'data')
+        transport._buffer.extend(b'data')
         transport._write_ready()
         transport._fatal_error.assert_called_with(err)
 
@@ -956,7 +1002,7 @@
         transport = _SelectorSocketTransport(
             self.loop, self.sock, self.protocol)
         transport.close()
-        transport._buffer.append(b'data')
+        transport._buffer.extend(b'data')
         transport._write_ready()
         remove_writer.assert_called_with(self.sock_fd)
 
@@ -976,12 +1022,12 @@
         self.sock.send.side_effect = BlockingIOError
         tr.write(b'data')
         tr.write_eof()
-        self.assertEqual(tr._buffer, collections.deque([b'data']))
+        self.assertEqual(tr._buffer, list_to_buffer([b'data']))
         self.assertTrue(tr._eof)
         self.assertFalse(self.sock.shutdown.called)
         self.sock.send.side_effect = lambda _: 4
         tr._write_ready()
-        self.sock.send.assert_called_with(b'data')
+        self.assertTrue(self.sock.send.called)
         self.sock.shutdown.assert_called_with(socket.SHUT_WR)
         tr.close()
 
@@ -1065,15 +1111,34 @@
         self.assertFalse(tr._paused)
         self.loop.assert_reader(1, tr._read_ready)
 
+    def test_write(self):
+        transport = self._make_one()
+        transport.write(b'data')
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
+    def test_write_bytearray(self):
+        transport = self._make_one()
+        data = bytearray(b'data')
+        transport.write(data)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+        self.assertEqual(data, bytearray(b'data'))  # Hasn't been mutated.
+        self.assertIsNot(data, transport._buffer)  # Hasn't been incorporated.
+
+    def test_write_memoryview(self):
+        transport = self._make_one()
+        data = memoryview(b'data')
+        transport.write(data)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
+
     def test_write_no_data(self):
         transport = self._make_one()
-        transport._buffer.append(b'data')
+        transport._buffer.extend(b'data')
         transport.write(b'')
-        self.assertEqual(collections.deque([b'data']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
     def test_write_str(self):
         transport = self._make_one()
-        self.assertRaises(AssertionError, transport.write, 'str')
+        self.assertRaises(TypeError, transport.write, 'str')
 
     def test_write_closing(self):
         transport = self._make_one()
@@ -1087,7 +1152,7 @@
         transport = self._make_one()
         transport._conn_lost = 1
         transport.write(b'data')
-        self.assertEqual(transport._buffer, collections.deque())
+        self.assertEqual(transport._buffer, list_to_buffer())
         transport.write(b'data')
         transport.write(b'data')
         transport.write(b'data')
@@ -1107,7 +1172,7 @@
         transport = self._make_one()
         transport._write_wants_read = True
         transport._write_ready = unittest.mock.Mock()
-        transport._buffer.append(b'data')
+        transport._buffer.extend(b'data')
         transport._read_ready()
 
         self.assertFalse(transport._write_wants_read)
@@ -1168,31 +1233,31 @@
     def test_write_ready_send(self):
         self.sslsock.send.return_value = 4
         transport = self._make_one()
-        transport._buffer = collections.deque([b'data'])
+        transport._buffer = list_to_buffer([b'data'])
         transport._write_ready()
-        self.assertEqual(collections.deque(), transport._buffer)
+        self.assertEqual(list_to_buffer(), transport._buffer)
         self.assertTrue(self.sslsock.send.called)
 
     def test_write_ready_send_none(self):
         self.sslsock.send.return_value = 0
         transport = self._make_one()
-        transport._buffer = collections.deque([b'data1', b'data2'])
+        transport._buffer = list_to_buffer([b'data1', b'data2'])
         transport._write_ready()
         self.assertTrue(self.sslsock.send.called)
-        self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer)
 
     def test_write_ready_send_partial(self):
         self.sslsock.send.return_value = 2
         transport = self._make_one()
-        transport._buffer = collections.deque([b'data1', b'data2'])
+        transport._buffer = list_to_buffer([b'data1', b'data2'])
         transport._write_ready()
         self.assertTrue(self.sslsock.send.called)
-        self.assertEqual(collections.deque([b'ta1data2']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'ta1data2']), transport._buffer)
 
     def test_write_ready_send_closing_partial(self):
         self.sslsock.send.return_value = 2
         transport = self._make_one()
-        transport._buffer = collections.deque([b'data1', b'data2'])
+        transport._buffer = list_to_buffer([b'data1', b'data2'])
         transport._write_ready()
         self.assertTrue(self.sslsock.send.called)
         self.assertFalse(self.sslsock.close.called)
@@ -1201,7 +1266,7 @@
         self.sslsock.send.return_value = 4
         transport = self._make_one()
         transport.close()
-        transport._buffer = collections.deque([b'data'])
+        transport._buffer = list_to_buffer([b'data'])
         transport._write_ready()
         self.assertFalse(self.loop.writers)
         self.protocol.connection_lost.assert_called_with(None)
@@ -1210,26 +1275,26 @@
         self.sslsock.send.return_value = 4
         transport = self._make_one()
         transport.close()
-        transport._buffer = collections.deque()
+        transport._buffer = list_to_buffer()
         transport._write_ready()
         self.assertFalse(self.loop.writers)
         self.protocol.connection_lost.assert_called_with(None)
 
     def test_write_ready_send_retry(self):
         transport = self._make_one()
-        transport._buffer = collections.deque([b'data'])
+        transport._buffer = list_to_buffer([b'data'])
 
         self.sslsock.send.side_effect = ssl.SSLWantWriteError
         transport._write_ready()
-        self.assertEqual(collections.deque([b'data']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
         self.sslsock.send.side_effect = BlockingIOError()
         transport._write_ready()
-        self.assertEqual(collections.deque([b'data']), transport._buffer)
+        self.assertEqual(list_to_buffer([b'data']), transport._buffer)
 
     def test_write_ready_send_read(self):
         transport = self._make_one()
-        transport._buffer = collections.deque([b'data'])
+        transport._buffer = list_to_buffer([b'data'])
 
         self.loop.remove_writer = unittest.mock.Mock()
         self.sslsock.send.side_effect = ssl.SSLWantReadError
@@ -1242,11 +1307,11 @@
         err = self.sslsock.send.side_effect = OSError()
 
         transport = self._make_one()
-        transport._buffer = collections.deque([b'data'])
+        transport._buffer = list_to_buffer([b'data'])
         transport._fatal_error = unittest.mock.Mock()
         transport._write_ready()
         transport._fatal_error.assert_called_with(err)
-        self.assertEqual(collections.deque(), transport._buffer)
+        self.assertEqual(list_to_buffer(), transport._buffer)
 
     def test_write_ready_read_wants_write(self):
         self.loop.add_reader = unittest.mock.Mock()
@@ -1355,6 +1420,24 @@
         self.assertEqual(
             self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
 
+    def test_sendto_bytearray(self):
+        data = bytearray(b'data')
+        transport = _SelectorDatagramTransport(
+            self.loop, self.sock, self.protocol)
+        transport.sendto(data, ('0.0.0.0', 1234))
+        self.assertTrue(self.sock.sendto.called)
+        self.assertEqual(
+            self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
+    def test_sendto_memoryview(self):
+        data = memoryview(b'data')
+        transport = _SelectorDatagramTransport(
+            self.loop, self.sock, self.protocol)
+        transport.sendto(data, ('0.0.0.0', 1234))
+        self.assertTrue(self.sock.sendto.called)
+        self.assertEqual(
+            self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
     def test_sendto_no_data(self):
         transport = _SelectorDatagramTransport(
             self.loop, self.sock, self.protocol)
@@ -1375,6 +1458,32 @@
              (b'data2', ('0.0.0.0', 12345))],
             list(transport._buffer))
 
+    def test_sendto_buffer_bytearray(self):
+        data2 = bytearray(b'data2')
+        transport = _SelectorDatagramTransport(
+            self.loop, self.sock, self.protocol)
+        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+        transport.sendto(data2, ('0.0.0.0', 12345))
+        self.assertFalse(self.sock.sendto.called)
+        self.assertEqual(
+            [(b'data1', ('0.0.0.0', 12345)),
+             (b'data2', ('0.0.0.0', 12345))],
+            list(transport._buffer))
+        self.assertIsInstance(transport._buffer[1][0], bytes)
+
+    def test_sendto_buffer_memoryview(self):
+        data2 = memoryview(b'data2')
+        transport = _SelectorDatagramTransport(
+            self.loop, self.sock, self.protocol)
+        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+        transport.sendto(data2, ('0.0.0.0', 12345))
+        self.assertFalse(self.sock.sendto.called)
+        self.assertEqual(
+            [(b'data1', ('0.0.0.0', 12345)),
+             (b'data2', ('0.0.0.0', 12345))],
+            list(transport._buffer))
+        self.assertIsInstance(transport._buffer[1][0], bytes)
+
     def test_sendto_tryagain(self):
         data = b'data'
 
@@ -1439,13 +1548,13 @@
     def test_sendto_str(self):
         transport = _SelectorDatagramTransport(
             self.loop, self.sock, self.protocol)
-        self.assertRaises(AssertionError, transport.sendto, 'str', ())
+        self.assertRaises(TypeError, transport.sendto, 'str', ())
 
     def test_sendto_connected_addr(self):
         transport = _SelectorDatagramTransport(
             self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
         self.assertRaises(
-            AssertionError, transport.sendto, b'str', ('0.0.0.0', 2))
+            ValueError, transport.sendto, b'str', ('0.0.0.0', 2))
 
     def test_sendto_closing(self):
         transport = _SelectorDatagramTransport(

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list