[Python-checkins] Extract sendfile tests into a separate test file (#9757)

Andrew Svetlov webhook-mailer at python.org
Tue Oct 9 00:53:01 EDT 2018


https://github.com/python/cpython/commit/2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7
commit: 2b2758d0b30f4ed7d37319d6c18552eccbc8e7b7
branch: master
author: Andrew Svetlov <andrew.svetlov at gmail.com>
committer: GitHub <noreply at github.com>
date: 2018-10-09T07:52:57+03:00
summary:

Extract sendfile tests into a separate test file (#9757)

files:
A Lib/test/test_asyncio/test_sendfile.py
M Lib/test/test_asyncio/test_events.py

diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index 607c1955ac58..b76cfb75cce2 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -15,7 +15,6 @@
     ssl = None
 import subprocess
 import sys
-import tempfile
 import threading
 import time
 import errno
@@ -1987,461 +1986,15 @@ def test_subprocess_shell_invalid_args(self):
             self.loop.run_until_complete(connect(shell=False))
 
 
-class SendfileBase:
-
-    DATA = b"SendfileBaseData" * (1024 * 8)  # 128 KiB
-
-    # Reduce socket buffer size to test on relative small data sets.
-    BUF_SIZE = 4 * 1024   # 4 KiB
-
-    @classmethod
-    def setUpClass(cls):
-        with open(support.TESTFN, 'wb') as fp:
-            fp.write(cls.DATA)
-        super().setUpClass()
-
-    @classmethod
-    def tearDownClass(cls):
-        support.unlink(support.TESTFN)
-        super().tearDownClass()
-
-    def setUp(self):
-        self.file = open(support.TESTFN, 'rb')
-        self.addCleanup(self.file.close)
-        super().setUp()
-
-    def run_loop(self, coro):
-        return self.loop.run_until_complete(coro)
-
-
-class SockSendfileMixin(SendfileBase):
-
-    class MyProto(asyncio.Protocol):
-
-        def __init__(self, loop):
-            self.started = False
-            self.closed = False
-            self.data = bytearray()
-            self.fut = loop.create_future()
-            self.transport = None
-
-        def connection_made(self, transport):
-            self.started = True
-            self.transport = transport
-
-        def data_received(self, data):
-            self.data.extend(data)
-
-        def connection_lost(self, exc):
-            self.closed = True
-            self.fut.set_result(None)
-
-        async def wait_closed(self):
-            await self.fut
-
-    @classmethod
-    def setUpClass(cls):
-        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
-        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
-        super().setUpClass()
-
-    @classmethod
-    def tearDownClass(cls):
-        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
-        super().tearDownClass()
-
-    def make_socket(self, cleanup=True):
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.setblocking(False)
-        if cleanup:
-            self.addCleanup(sock.close)
-        return sock
-
-    def reduce_receive_buffer_size(self, sock):
-        # Reduce receive socket buffer size to test on relative
-        # small data sets.
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)
-
-    def reduce_send_buffer_size(self, sock, transport=None):
-        # Reduce send socket buffer size to test on relative small data sets.
-
-        # On macOS, SO_SNDBUF is reset by connect(). So this method
-        # should be called after the socket is connected.
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)
-
-        if transport is not None:
-            transport.set_write_buffer_limits(high=self.BUF_SIZE)
-
-    def prepare_socksendfile(self):
-        proto = self.MyProto(self.loop)
-        port = support.find_unused_port()
-        srv_sock = self.make_socket(cleanup=False)
-        srv_sock.bind((support.HOST, port))
-        server = self.run_loop(self.loop.create_server(
-            lambda: proto, sock=srv_sock))
-        self.reduce_receive_buffer_size(srv_sock)
-
-        sock = self.make_socket()
-        self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
-        self.reduce_send_buffer_size(sock)
-
-        def cleanup():
-            if proto.transport is not None:
-                # can be None if the task was cancelled before
-                # connection_made callback
-                proto.transport.close()
-                self.run_loop(proto.wait_closed())
-
-            server.close()
-            self.run_loop(server.wait_closed())
-
-        self.addCleanup(cleanup)
-
-        return sock, proto
-
-    def test_sock_sendfile_success(self):
-        sock, proto = self.prepare_socksendfile()
-        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
-        sock.close()
-        self.run_loop(proto.wait_closed())
-
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(proto.data, self.DATA)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sock_sendfile_with_offset_and_count(self):
-        sock, proto = self.prepare_socksendfile()
-        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
-                                                    1000, 2000))
-        sock.close()
-        self.run_loop(proto.wait_closed())
-
-        self.assertEqual(proto.data, self.DATA[1000:3000])
-        self.assertEqual(self.file.tell(), 3000)
-        self.assertEqual(ret, 2000)
-
-    def test_sock_sendfile_zero_size(self):
-        sock, proto = self.prepare_socksendfile()
-        with tempfile.TemporaryFile() as f:
-            ret = self.run_loop(self.loop.sock_sendfile(sock, f,
-                                                        0, None))
-        sock.close()
-        self.run_loop(proto.wait_closed())
-
-        self.assertEqual(ret, 0)
-        self.assertEqual(self.file.tell(), 0)
-
-    def test_sock_sendfile_mix_with_regular_send(self):
-        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
-        sock, proto = self.prepare_socksendfile()
-        self.run_loop(self.loop.sock_sendall(sock, buf))
-        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
-        self.run_loop(self.loop.sock_sendall(sock, buf))
-        sock.close()
-        self.run_loop(proto.wait_closed())
-
-        self.assertEqual(ret, len(self.DATA))
-        expected = buf + self.DATA + buf
-        self.assertEqual(proto.data, expected)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-
-class SendfileMixin(SendfileBase):
-
-    class MySendfileProto(MyBaseProto):
-
-        def __init__(self, loop=None, close_after=0):
-            super().__init__(loop)
-            self.data = bytearray()
-            self.close_after = close_after
-
-        def data_received(self, data):
-            self.data.extend(data)
-            super().data_received(data)
-            if self.close_after and self.nbytes >= self.close_after:
-                self.transport.close()
-
-
-    # Note: sendfile via SSL transport is equal to sendfile fallback
-
-    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
-        port = support.find_unused_port()
-        srv_proto = self.MySendfileProto(loop=self.loop,
-                                         close_after=close_after)
-        if is_ssl:
-            if not ssl:
-                self.skipTest("No ssl module")
-            srv_ctx = test_utils.simple_server_sslcontext()
-            cli_ctx = test_utils.simple_client_sslcontext()
-        else:
-            srv_ctx = None
-            cli_ctx = None
-        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        srv_sock.bind((support.HOST, port))
-        server = self.run_loop(self.loop.create_server(
-            lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
-        self.reduce_receive_buffer_size(srv_sock)
-
-        if is_ssl:
-            server_hostname = support.HOST
-        else:
-            server_hostname = None
-        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        cli_sock.connect((support.HOST, port))
-
-        cli_proto = self.MySendfileProto(loop=self.loop)
-        tr, pr = self.run_loop(self.loop.create_connection(
-            lambda: cli_proto, sock=cli_sock,
-            ssl=cli_ctx, server_hostname=server_hostname))
-        self.reduce_send_buffer_size(cli_sock, transport=tr)
-
-        def cleanup():
-            srv_proto.transport.close()
-            cli_proto.transport.close()
-            self.run_loop(srv_proto.done)
-            self.run_loop(cli_proto.done)
-
-            server.close()
-            self.run_loop(server.wait_closed())
-
-        self.addCleanup(cleanup)
-        return srv_proto, cli_proto
-
-    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
-    def test_sendfile_not_supported(self):
-        tr, pr = self.run_loop(
-            self.loop.create_datagram_endpoint(
-                lambda: MyDatagramProto(loop=self.loop),
-                family=socket.AF_INET))
-        try:
-            with self.assertRaisesRegex(RuntimeError, "not supported"):
-                self.run_loop(
-                    self.loop.sendfile(tr, self.file))
-            self.assertEqual(0, self.file.tell())
-        finally:
-            # don't use self.addCleanup because it produces resource warning
-            tr.close()
-
-    def test_sendfile(self):
-        srv_proto, cli_proto = self.prepare_sendfile()
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file))
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(srv_proto.nbytes, len(self.DATA))
-        self.assertEqual(srv_proto.data, self.DATA)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sendfile_force_fallback(self):
-        srv_proto, cli_proto = self.prepare_sendfile()
-
-        def sendfile_native(transp, file, offset, count):
-            # to raise SendfileNotAvailableError
-            return base_events.BaseEventLoop._sendfile_native(
-                self.loop, transp, file, offset, count)
-
-        self.loop._sendfile_native = sendfile_native
-
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file))
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(srv_proto.nbytes, len(self.DATA))
-        self.assertEqual(srv_proto.data, self.DATA)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sendfile_force_unsupported_native(self):
-        if sys.platform == 'win32':
-            if isinstance(self.loop, asyncio.ProactorEventLoop):
-                self.skipTest("Fails on proactor event loop")
-        srv_proto, cli_proto = self.prepare_sendfile()
-
-        def sendfile_native(transp, file, offset, count):
-            # to raise SendfileNotAvailableError
-            return base_events.BaseEventLoop._sendfile_native(
-                self.loop, transp, file, offset, count)
-
-        self.loop._sendfile_native = sendfile_native
-
-        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
-                                    "not supported"):
-            self.run_loop(
-                self.loop.sendfile(cli_proto.transport, self.file,
-                                   fallback=False))
-
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(srv_proto.nbytes, 0)
-        self.assertEqual(self.file.tell(), 0)
-
-    def test_sendfile_ssl(self):
-        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file))
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(srv_proto.nbytes, len(self.DATA))
-        self.assertEqual(srv_proto.data, self.DATA)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sendfile_for_closing_transp(self):
-        srv_proto, cli_proto = self.prepare_sendfile()
-        cli_proto.transport.close()
-        with self.assertRaisesRegex(RuntimeError, "is closing"):
-            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
-        self.run_loop(srv_proto.done)
-        self.assertEqual(srv_proto.nbytes, 0)
-        self.assertEqual(self.file.tell(), 0)
-
-    def test_sendfile_pre_and_post_data(self):
-        srv_proto, cli_proto = self.prepare_sendfile()
-        PREFIX = b'PREFIX__' * 1024  # 8 KiB
-        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
-        cli_proto.transport.write(PREFIX)
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file))
-        cli_proto.transport.write(SUFFIX)
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sendfile_ssl_pre_and_post_data(self):
-        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
-        PREFIX = b'zxcvbnm' * 1024
-        SUFFIX = b'0987654321' * 1024
-        cli_proto.transport.write(PREFIX)
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file))
-        cli_proto.transport.write(SUFFIX)
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sendfile_partial(self):
-        srv_proto, cli_proto = self.prepare_sendfile()
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, 100)
-        self.assertEqual(srv_proto.nbytes, 100)
-        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
-        self.assertEqual(self.file.tell(), 1100)
-
-    def test_sendfile_ssl_partial(self):
-        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, 100)
-        self.assertEqual(srv_proto.nbytes, 100)
-        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
-        self.assertEqual(self.file.tell(), 1100)
-
-    def test_sendfile_close_peer_after_receiving(self):
-        srv_proto, cli_proto = self.prepare_sendfile(
-            close_after=len(self.DATA))
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file))
-        cli_proto.transport.close()
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(srv_proto.nbytes, len(self.DATA))
-        self.assertEqual(srv_proto.data, self.DATA)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sendfile_ssl_close_peer_after_receiving(self):
-        srv_proto, cli_proto = self.prepare_sendfile(
-            is_ssl=True, close_after=len(self.DATA))
-        ret = self.run_loop(
-            self.loop.sendfile(cli_proto.transport, self.file))
-        self.run_loop(srv_proto.done)
-        self.assertEqual(ret, len(self.DATA))
-        self.assertEqual(srv_proto.nbytes, len(self.DATA))
-        self.assertEqual(srv_proto.data, self.DATA)
-        self.assertEqual(self.file.tell(), len(self.DATA))
-
-    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
-        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
-        with self.assertRaises(ConnectionError):
-            self.run_loop(
-                self.loop.sendfile(cli_proto.transport, self.file))
-        self.run_loop(srv_proto.done)
-
-        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
-                        srv_proto.nbytes)
-        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
-                        self.file.tell())
-        self.assertTrue(cli_proto.transport.is_closing())
-
-    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
-
-        def sendfile_native(transp, file, offset, count):
-            # to raise SendfileNotAvailableError
-            return base_events.BaseEventLoop._sendfile_native(
-                self.loop, transp, file, offset, count)
-
-        self.loop._sendfile_native = sendfile_native
-
-        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
-        with self.assertRaises(ConnectionError):
-            self.run_loop(
-                self.loop.sendfile(cli_proto.transport, self.file))
-        self.run_loop(srv_proto.done)
-
-        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
-                        srv_proto.nbytes)
-        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
-                        self.file.tell())
-
-    @unittest.skipIf(not hasattr(os, 'sendfile'),
-                     "Don't have native sendfile support")
-    def test_sendfile_prevents_bare_write(self):
-        srv_proto, cli_proto = self.prepare_sendfile()
-        fut = self.loop.create_future()
-
-        async def coro():
-            fut.set_result(None)
-            return await self.loop.sendfile(cli_proto.transport, self.file)
-
-        t = self.loop.create_task(coro())
-        self.run_loop(fut)
-        with self.assertRaisesRegex(RuntimeError,
-                                    "sendfile is in progress"):
-            cli_proto.transport.write(b'data')
-        ret = self.run_loop(t)
-        self.assertEqual(ret, len(self.DATA))
-
-    def test_sendfile_no_fallback_for_fallback_transport(self):
-        transport = mock.Mock()
-        transport.is_closing.side_effect = lambda: False
-        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
-        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
-            self.loop.run_until_complete(
-                self.loop.sendfile(transport, None, fallback=False))
-
-
 if sys.platform == 'win32':
 
     class SelectEventLoopTests(EventLoopTestsMixin,
-                               SendfileMixin,
-                               SockSendfileMixin,
                                test_utils.TestCase):
 
         def create_event_loop(self):
             return asyncio.SelectorEventLoop()
 
     class ProactorEventLoopTests(EventLoopTestsMixin,
-                                 SendfileMixin,
-                                 SockSendfileMixin,
                                  SubprocessTestsMixin,
                                  test_utils.TestCase):
 
@@ -2469,9 +2022,7 @@ def test_remove_fds_after_closing(self):
 else:
     import selectors
 
-    class UnixEventLoopTestsMixin(EventLoopTestsMixin,
-                                  SendfileMixin,
-                                  SockSendfileMixin):
+    class UnixEventLoopTestsMixin(EventLoopTestsMixin):
         def setUp(self):
             super().setUp()
             watcher = asyncio.SafeChildWatcher()
diff --git a/Lib/test/test_asyncio/test_sendfile.py b/Lib/test/test_asyncio/test_sendfile.py
new file mode 100644
index 000000000000..26e44a3348a5
--- /dev/null
+++ b/Lib/test/test_asyncio/test_sendfile.py
@@ -0,0 +1,550 @@
+"""Tests for sendfile functionality."""
+
+import asyncio
+import os
+import socket
+import sys
+import tempfile
+import unittest
+from asyncio import base_events
+from asyncio import constants
+from unittest import mock
+from test import support
+from test.test_asyncio import utils as test_utils
+
+try:
+    import ssl
+except ImportError:
+    ssl = None
+
+
+class MySendfileProto(asyncio.Protocol):
+    """Track connection state and received bytes; optionally close early."""
+
+    def __init__(self, loop=None, close_after=0):
+        self.transport = None
+        self.state = 'INITIAL'
+        self.nbytes = 0
+        # No loop means no futures; None keeps the callbacks below safe.
+        self.connected = None if loop is None else loop.create_future()
+        self.done = None if loop is None else loop.create_future()
+        self.data = bytearray()
+        self.close_after = close_after
+
+    def connection_made(self, transport):
+        self.transport = transport
+        assert self.state == 'INITIAL', self.state
+        self.state = 'CONNECTED'
+        if self.connected is not None:
+            self.connected.set_result(None)
+
+    def eof_received(self):
+        assert self.state == 'CONNECTED', self.state
+        self.state = 'EOF'
+
+    def connection_lost(self, exc):
+        assert self.state in ('CONNECTED', 'EOF'), self.state
+        self.state = 'CLOSED'
+        if self.done is not None:
+            self.done.set_result(None)
+
+    def data_received(self, data):
+        assert self.state == 'CONNECTED', self.state
+        self.nbytes += len(data)
+        self.data.extend(data)
+        if self.close_after and self.nbytes >= self.close_after:
+            self.transport.close()
+
+
+class MyProto(asyncio.Protocol):
+    """Collect received bytes and expose a future resolved on close."""
+
+    def __init__(self, loop):
+        self.transport = None
+        self.started = self.closed = False
+        self.data = bytearray()
+        self.fut = loop.create_future()  # resolved by connection_lost()
+
+    def connection_made(self, transport):
+        self.transport, self.started = transport, True
+
+    def data_received(self, data):
+        self.data.extend(data)
+
+    def connection_lost(self, exc):
+        self.closed = True
+        self.fut.set_result(None)
+
+    async def wait_closed(self):
+        """Wait until connection_lost() has been called."""
+        await self.fut
+
+
+class SendfileBase:
+    """Fixture mixin: DATA written to TESTFN plus a fresh loop per test."""
+    DATA = b"SendfileBaseData" * (1024 * 8)  # 128 KiB
+    # Reduce socket buffer size to test on relatively small data sets.
+    BUF_SIZE = 4 * 1024   # 4 KiB
+
+    def create_event_loop(self):
+        raise NotImplementedError  # concrete test classes must override
+
+    @classmethod
+    def setUpClass(cls):
+        with open(support.TESTFN, 'wb') as fp:  # shared by every test
+            fp.write(cls.DATA)
+        super().setUpClass()
+
+    @classmethod
+    def tearDownClass(cls):
+        support.unlink(support.TESTFN)
+        super().tearDownClass()
+
+    def setUp(self):
+        self.file = open(support.TESTFN, 'rb')  # fresh read position per test
+        self.addCleanup(self.file.close)
+        self.loop = self.create_event_loop()
+        self.set_event_loop(self.loop)
+        super().setUp()
+
+    def tearDown(self):
+        # Give pending callbacks (e.g. transport close callbacks) a chance.
+        if not self.loop.is_closed():
+            test_utils.run_briefly(self.loop)
+        # Run cleanups before super().tearDown(); several call self.run_loop().
+        self.doCleanups()
+        support.gc_collect()
+        super().tearDown()
+
+    def run_loop(self, coro):
+        """Run *coro* on self.loop until it completes; return its result."""
+        return self.loop.run_until_complete(coro)
+
+
+class SockSendfileMixin(SendfileBase):
+    """Tests for loop.sock_sendfile() over a loopback TCP connection."""
+
+    @classmethod
+    def setUpClass(cls):
+        # Use a 16 KiB fallback read buffer; tearDownClass() restores it.
+        cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
+        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
+        super().setUpClass()
+
+    @classmethod
+    def tearDownClass(cls):
+        constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
+        super().tearDownClass()
+
+    def make_socket(self, cleanup=True):
+        """Make a non-blocking IPv4 TCP socket; closed on cleanup if asked."""
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setblocking(False)
+        if cleanup:
+            self.addCleanup(sock.close)
+        return sock
+
+    def reduce_receive_buffer_size(self, sock):
+        # Reduce receive socket buffer size to test on relatively
+        # small data sets.
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)
+
+    def reduce_send_buffer_size(self, sock, transport=None):
+        # Reduce send socket buffer size to test on relatively small data sets.
+        # On macOS, connect() resets SO_SNDBUF: call this after connecting.
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)
+        if transport is not None:
+            transport.set_write_buffer_limits(high=self.BUF_SIZE)
+
+    def prepare_socksendfile(self):
+        """Connect a client socket to a fresh server; return (sock, proto)."""
+        proto = MyProto(self.loop)
+        port = support.find_unused_port()
+        # Registered for cleanup so it cannot leak if bind()/create_server()
+        # fails; once the server owns it, a second close() is harmless.
+        srv_sock = self.make_socket()
+        srv_sock.bind((support.HOST, port))
+        server = self.run_loop(self.loop.create_server(
+            lambda: proto, sock=srv_sock))
+        self.reduce_receive_buffer_size(srv_sock)
+        sock = self.make_socket()
+        self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
+        self.reduce_send_buffer_size(sock)
+
+        def cleanup():
+            if proto.transport is not None:
+                # can be None if the task was cancelled before
+                # connection_made callback
+                proto.transport.close()
+                self.run_loop(proto.wait_closed())
+            server.close()
+            self.run_loop(server.wait_closed())
+
+        self.addCleanup(cleanup)
+        return sock, proto
+
+    def test_sock_sendfile_success(self):
+        sock, proto = self.prepare_socksendfile()
+        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
+        sock.close()
+        self.run_loop(proto.wait_closed())
+
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sock_sendfile_with_offset_and_count(self):
+        sock, proto = self.prepare_socksendfile()
+        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
+                                                    1000, 2000))
+        sock.close()
+        self.run_loop(proto.wait_closed())
+
+        self.assertEqual(proto.data, self.DATA[1000:3000])
+        self.assertEqual(self.file.tell(), 3000)
+        self.assertEqual(ret, 2000)
+
+    def test_sock_sendfile_zero_size(self):
+        sock, proto = self.prepare_socksendfile()
+        with tempfile.TemporaryFile() as f:
+            ret = self.run_loop(self.loop.sock_sendfile(sock, f, 0, None))
+            # Check the file actually sent (self.file is never touched here).
+            self.assertEqual(f.tell(), 0)
+        sock.close()
+        self.run_loop(proto.wait_closed())
+        self.assertEqual(ret, 0)
+        self.assertEqual(proto.data, b'')
+
+    def test_sock_sendfile_mix_with_regular_send(self):
+        buf = b"mix_regular_send" * (4 * 1024)  # 64 KiB
+        sock, proto = self.prepare_socksendfile()
+        self.run_loop(self.loop.sock_sendall(sock, buf))
+        ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
+        self.run_loop(self.loop.sock_sendall(sock, buf))
+        sock.close()
+        self.run_loop(proto.wait_closed())
+
+        self.assertEqual(ret, len(self.DATA))
+        expected = buf + self.DATA + buf
+        self.assertEqual(proto.data, expected)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+
+class SendfileMixin(SendfileBase):
+
+    # Note: sendfile over an SSL transport behaves like the sendfile fallback.
+
+    def prepare_sendfile(self, *, is_ssl=False, close_after=0):
+        """Start a server, connect a client; return (srv_proto, cli_proto)."""
+        port = support.find_unused_port()
+        srv_proto = MySendfileProto(loop=self.loop,
+                                    close_after=close_after)
+        if is_ssl:
+            if not ssl:
+                self.skipTest("No ssl module")
+            srv_ctx = test_utils.simple_server_sslcontext()
+            cli_ctx = test_utils.simple_client_sslcontext()
+        else:
+            srv_ctx = None
+            cli_ctx = None
+        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        srv_sock.bind((support.HOST, port))
+        server = self.run_loop(self.loop.create_server(
+            lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
+        # NOTE(review): reduce_*_buffer_size() come from SockSendfileMixin,
+        # not SendfileBase; assumes the test class mixes both in.
+        self.reduce_receive_buffer_size(srv_sock)
+        if is_ssl:  # server_hostname is only passed for TLS connections
+            server_hostname = support.HOST
+        else:
+            server_hostname = None
+        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        cli_sock.connect((support.HOST, port))
+        cli_proto = MySendfileProto(loop=self.loop)
+        tr, pr = self.run_loop(self.loop.create_connection(
+            lambda: cli_proto, sock=cli_sock,
+            ssl=cli_ctx, server_hostname=server_hostname))
+        self.reduce_send_buffer_size(cli_sock, transport=tr)
+
+        def cleanup():
+            srv_proto.transport.close()
+            cli_proto.transport.close()
+            self.run_loop(srv_proto.done)
+            self.run_loop(cli_proto.done)
+            server.close()
+            self.run_loop(server.wait_closed())
+
+        self.addCleanup(cleanup)
+        return srv_proto, cli_proto
+
+    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
+    def test_sendfile_not_supported(self):
+        """sendfile() on a UDP transport raises without reading the file."""
+        endpoint = self.loop.create_datagram_endpoint(
+            asyncio.DatagramProtocol, family=socket.AF_INET)
+        transport, _ = self.run_loop(endpoint)
+        try:
+            with self.assertRaisesRegex(RuntimeError, "not supported"):
+                self.run_loop(self.loop.sendfile(transport, self.file))
+            self.assertEqual(0, self.file.tell())
+        finally:
+            # Close here: closing via self.addCleanup() produces a
+            # ResourceWarning.
+            transport.close()
+
+    def test_sendfile(self):
+        """loop.sendfile() transfers the entire file to the peer."""
+        server, client = self.prepare_sendfile()
+        sent = self.run_loop(self.loop.sendfile(client.transport, self.file))
+        client.transport.close()
+        self.run_loop(server.done)
+        self.assertEqual(sent, len(self.DATA))
+        self.assertEqual(server.nbytes, len(self.DATA))
+        self.assertEqual(server.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_force_fallback(self):
+        """If native sendfile is unavailable, the fallback sends the file."""
+        server, client = self.prepare_sendfile()
+
+        def no_native_sendfile(transp, file, offset, count):
+            # The base implementation raises SendfileNotAvailableError.
+            return base_events.BaseEventLoop._sendfile_native(
+                self.loop, transp, file, offset, count)
+
+        self.loop._sendfile_native = no_native_sendfile
+
+        sent = self.run_loop(self.loop.sendfile(client.transport, self.file))
+        client.transport.close()
+        self.run_loop(server.done)
+        self.assertEqual(sent, len(self.DATA))
+        self.assertEqual(server.nbytes, len(self.DATA))
+        self.assertEqual(server.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_force_unsupported_native(self):
+        """With fallback=False, an unavailable native sendfile is an error."""
+        if (sys.platform == 'win32'
+                and isinstance(self.loop, asyncio.ProactorEventLoop)):
+            self.skipTest("Fails on proactor event loop")
+        server, client = self.prepare_sendfile()
+
+        def no_native_sendfile(transp, file, offset, count):
+            # The base implementation raises SendfileNotAvailableError.
+            return base_events.BaseEventLoop._sendfile_native(
+                self.loop, transp, file, offset, count)
+
+        self.loop._sendfile_native = no_native_sendfile
+        coro = self.loop.sendfile(client.transport, self.file, fallback=False)
+        with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
+                                    "not supported"):
+            self.run_loop(coro)
+
+        client.transport.close()
+        self.run_loop(server.done)
+        # Nothing was sent and the file position is unchanged.
+        self.assertEqual(server.nbytes, 0)
+        self.assertEqual(self.file.tell(), 0)
+
+    def test_sendfile_ssl(self):
+        """sendfile over a TLS transport delivers the entire file."""
+        server, client = self.prepare_sendfile(is_ssl=True)
+        sent = self.run_loop(self.loop.sendfile(client.transport, self.file))
+        client.transport.close()
+        self.run_loop(server.done)
+        self.assertEqual(sent, len(self.DATA))
+        self.assertEqual(server.nbytes, len(self.DATA))
+        self.assertEqual(server.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_for_closing_transp(self):
+        srv_proto, cli_proto = self.prepare_sendfile()
+        cli_proto.transport.close()
+        with self.assertRaisesRegex(RuntimeError, "is closing"):
+            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
+        self.run_loop(srv_proto.done)
+        self.assertEqual(srv_proto.nbytes, 0)
+        self.assertEqual(self.file.tell(), 0)
+
+    def test_sendfile_pre_and_post_data(self):
+        srv_proto, cli_proto = self.prepare_sendfile()
+        PREFIX = b'PREFIX__' * 1024  # 8 KiB
+        SUFFIX = b'--SUFFIX' * 1024  # 8 KiB
+        cli_proto.transport.write(PREFIX)
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.write(SUFFIX)
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_ssl_pre_and_post_data(self):
+        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
+        PREFIX = b'zxcvbnm' * 1024
+        SUFFIX = b'0987654321' * 1024
+        cli_proto.transport.write(PREFIX)
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.write(SUFFIX)
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_partial(self):
+        srv_proto, cli_proto = self.prepare_sendfile()
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, 100)
+        self.assertEqual(srv_proto.nbytes, 100)
+        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+        self.assertEqual(self.file.tell(), 1100)
+
+    def test_sendfile_ssl_partial(self):
+        srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, 100)
+        self.assertEqual(srv_proto.nbytes, 100)
+        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+        self.assertEqual(self.file.tell(), 1100)
+
+    def test_sendfile_close_peer_after_receiving(self):
+        srv_proto, cli_proto = self.prepare_sendfile(
+            close_after=len(self.DATA))
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.nbytes, len(self.DATA))
+        self.assertEqual(srv_proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_ssl_close_peer_after_receiving(self):
+        srv_proto, cli_proto = self.prepare_sendfile(
+            is_ssl=True, close_after=len(self.DATA))
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.nbytes, len(self.DATA))
+        self.assertEqual(srv_proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
+        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
+        with self.assertRaises(ConnectionError):
+            self.run_loop(
+                self.loop.sendfile(cli_proto.transport, self.file))
+        self.run_loop(srv_proto.done)
+
+        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+                        srv_proto.nbytes)
+        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+                        self.file.tell())
+        self.assertTrue(cli_proto.transport.is_closing())
+
+    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
+
+        def sendfile_native(transp, file, offset, count):
+            # to raise SendfileNotAvailableError
+            return base_events.BaseEventLoop._sendfile_native(
+                self.loop, transp, file, offset, count)
+
+        self.loop._sendfile_native = sendfile_native
+
+        srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
+        with self.assertRaises(ConnectionError):
+            self.run_loop(
+                self.loop.sendfile(cli_proto.transport, self.file))
+        self.run_loop(srv_proto.done)
+
+        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+                        srv_proto.nbytes)
+        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+                        self.file.tell())
+
+    @unittest.skipIf(not hasattr(os, 'sendfile'),
+                     "Don't have native sendfile support")
+    def test_sendfile_prevents_bare_write(self):
+        srv_proto, cli_proto = self.prepare_sendfile()
+        fut = self.loop.create_future()
+
+        async def coro():
+            fut.set_result(None)
+            return await self.loop.sendfile(cli_proto.transport, self.file)
+
+        t = self.loop.create_task(coro())
+        self.run_loop(fut)
+        with self.assertRaisesRegex(RuntimeError,
+                                    "sendfile is in progress"):
+            cli_proto.transport.write(b'data')
+        ret = self.run_loop(t)
+        self.assertEqual(ret, len(self.DATA))
+
+    def test_sendfile_no_fallback_for_fallback_transport(self):
+        transport = mock.Mock()
+        transport.is_closing.side_effect = lambda: False
+        transport._sendfile_compatible = constants._SendfileMode.FALLBACK
+        with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
+            self.loop.run_until_complete(
+                self.loop.sendfile(transport, None, fallback=False))
+
+
+class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
+    pass
+
+
+# Instantiate the sendfile test suite once per event loop implementation
+# available on this platform. Each class overrides create_event_loop() to
+# pick the loop under test.
+if sys.platform == 'win32':
+
+    class SelectEventLoopTests(SendfileTestsBase,
+                               test_utils.TestCase):
+
+        def create_event_loop(self):
+            return asyncio.SelectorEventLoop()
+
+    class ProactorEventLoopTests(SendfileTestsBase,
+                                 test_utils.TestCase):
+
+        def create_event_loop(self):
+            return asyncio.ProactorEventLoop()
+
+else:
+    import selectors
+
+    # The hasattr() guards skip selector classes that this platform's
+    # selectors module does not provide.
+    if hasattr(selectors, 'KqueueSelector'):
+        class KqueueEventLoopTests(SendfileTestsBase,
+                                   test_utils.TestCase):
+
+            def create_event_loop(self):
+                return asyncio.SelectorEventLoop(
+                    selectors.KqueueSelector())
+
+    if hasattr(selectors, 'EpollSelector'):
+        class EPollEventLoopTests(SendfileTestsBase,
+                                  test_utils.TestCase):
+
+            def create_event_loop(self):
+                return asyncio.SelectorEventLoop(selectors.EpollSelector())
+
+    if hasattr(selectors, 'PollSelector'):
+        class PollEventLoopTests(SendfileTestsBase,
+                                 test_utils.TestCase):
+
+            def create_event_loop(self):
+                return asyncio.SelectorEventLoop(selectors.PollSelector())
+
+    # Should always exist.
+    class SelectEventLoopTests(SendfileTestsBase,
+                               test_utils.TestCase):
+
+        def create_event_loop(self):
+            return asyncio.SelectorEventLoop(selectors.SelectSelector())



More information about the Python-checkins mailing list