[Python-checkins] bpo-32622: Implement loop.sendfile() (#5271)

Andrew Svetlov webhook-mailer at python.org
Sat Jan 27 14:22:50 EST 2018


https://github.com/python/cpython/commit/7c684073f951dd891021676ecfd86ffc18b8895e
commit: 7c684073f951dd891021676ecfd86ffc18b8895e
branch: master
author: Andrew Svetlov <andrew.svetlov at gmail.com>
committer: GitHub <noreply at github.com>
date: 2018-01-27T21:22:47+02:00
summary:

bpo-32622: Implement loop.sendfile() (#5271)

files:
A Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst
M Doc/library/asyncio-eventloop.rst
M Lib/asyncio/base_events.py
M Lib/asyncio/constants.py
M Lib/asyncio/events.py
M Lib/asyncio/proactor_events.py
M Lib/asyncio/selector_events.py
M Lib/asyncio/sslproto.py
M Lib/asyncio/windows_events.py
M Lib/test/test_asyncio/test_base_events.py
M Lib/test/test_asyncio/test_events.py
M Modules/overlapped.c

diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst
index 834a4e85c2f..fe162236e01 100644
--- a/Doc/library/asyncio-eventloop.rst
+++ b/Doc/library/asyncio-eventloop.rst
@@ -543,6 +543,37 @@ Creating listening connections
    .. versionadded:: 3.5.3
 
 
+File Transferring
+-----------------
+
+.. coroutinemethod:: AbstractEventLoop.sendfile(transport, file, \
+                                                offset=0, count=None, \
+                                                *, fallback=True)
+
+   Send a *file* to *transport*, return the total number of bytes
+   which were sent.
+
+   The method uses high-performance :meth:`os.sendfile` if available.
+
+   *file* must be a regular file object opened in binary mode.
+
+   *offset* tells from where to start reading the file. If specified,
+   *count* is the total number of bytes to transmit as opposed to
+   sending the file until EOF is reached. File position is updated on
+   return, also in case of error, so :meth:`file.tell()
+   <io.IOBase.tell>` can be used to figure out the number of bytes
+   which were sent.
+
+   *fallback* set to ``True`` makes asyncio manually read and send
+   the file when the platform does not support the sendfile syscall
+   (e.g. Windows or SSL socket on Unix).
+
+   Raise :exc:`SendfileNotAvailableError` if the system does not support
+   *sendfile* syscall and *fallback* is ``False``.
+
+   .. versionadded:: 3.7
+
+
 TLS Upgrade
 -----------
 
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 94eb3089e93..f532dc42132 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -38,8 +38,10 @@
 from . import coroutines
 from . import events
 from . import futures
+from . import protocols
 from . import sslproto
 from . import tasks
+from . import transports
 from .log import logger
 
 
@@ -155,6 +157,75 @@ def _run_until_complete_cb(fut):
     futures._get_loop(fut).stop()
 
 
+
+class _SendfileFallbackProtocol(protocols.Protocol):
+    def __init__(self, transp):  # swaps itself in as transp's protocol
+        if not isinstance(transp, transports._FlowControlMixin):
+            raise TypeError("transport should be _FlowControlMixin instance")
+        self._transport = transp
+        self._proto = transp.get_protocol()  # put back by restore()
+        self._should_resume_reading = transp.is_reading()
+        self._should_resume_writing = transp._protocol_paused
+        transp.pause_reading()
+        transp.set_protocol(self)
+        if self._should_resume_writing:  # already paused: drain() must wait
+            self._write_ready_fut = self._transport._loop.create_future()
+        else:
+            self._write_ready_fut = None
+
+    async def drain(self):  # wait while writing is paused
+        if self._transport.is_closing():
+            raise ConnectionError("Connection closed by peer")
+        fut = self._write_ready_fut
+        if fut is None:
+            return
+        await fut
+
+    def connection_made(self, transport):  # only swapped in after connect
+        raise RuntimeError("Invalid state: "
+                           "connection should have been established already.")
+
+    def connection_lost(self, exc):  # fail drain(), then forward
+        if self._write_ready_fut is not None:
+            # Writing is paused mid-transfer: even a clean close (exc is None)
+            # means the file was not fully sent, so report an error.
+            if exc is None:
+                self._write_ready_fut.set_exception(
+                    ConnectionError("Connection is closed by peer"))
+            else:
+                self._write_ready_fut.set_exception(exc)
+        self._proto.connection_lost(exc)
+
+    def pause_writing(self):  # make subsequent drain() calls wait
+        if self._write_ready_fut is not None:
+            return
+        self._write_ready_fut = self._transport._loop.create_future()
+
+    def resume_writing(self):  # release a waiting drain()
+        if self._write_ready_fut is None:
+            return
+        self._write_ready_fut.set_result(False)
+        self._write_ready_fut = None
+
+    def data_received(self, data):  # reading is paused in __init__
+        raise RuntimeError("Invalid state: reading should be paused")
+
+    def eof_received(self):  # reading is paused in __init__
+        raise RuntimeError("Invalid state: reading should be paused")
+
+    async def restore(self):  # undo __init__: protocol and flow control
+        self._transport.set_protocol(self._proto)
+        if self._should_resume_reading:
+            self._transport.resume_reading()
+        if self._write_ready_fut is not None:
+            # Cancel the future.
+            # Basically it has no effect because protocol is switched back,
+            # no code should wait for it anymore.
+            self._write_ready_fut.cancel()
+        if self._should_resume_writing:
+            self._proto.resume_writing()
+
 class Server(events.AbstractServer):
 
     def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
@@ -926,6 +997,77 @@ def _check_sendfile_params(self, sock, file, offset, count):
 
         return transport, protocol
 
+    async def sendfile(self, transport, file, offset=0, count=None,
+                       *, fallback=True):
+        """Send a file to transport.
+
+        Return the total number of bytes which were sent.
+
+        The method uses high-performance os.sendfile if available.
+
+        file must be a regular file object opened in binary mode.
+
+        offset tells from where to start reading the file. If specified,
+        count is the total number of bytes to transmit as opposed to
+        sending the file until EOF is reached. File position is updated on
+        return, also in case of error, so file.tell() can be used to figure
+        out the number of bytes which were sent.
+
+        fallback set to True makes asyncio manually read and send
+        the file when the platform does not support the sendfile syscall
+        (e.g. Windows or SSL socket on Unix).
+
+        Raise SendfileNotAvailableError if the system does not support
+        sendfile syscall and fallback is False.
+        """
+        if transport.is_closing():
+            raise RuntimeError("Transport is closing")
+        mode = getattr(transport, '_sendfile_compatible',
+                       constants._SendfileMode.UNSUPPORTED)
+        if mode is constants._SendfileMode.UNSUPPORTED:
+            raise RuntimeError(
+                f"sendfile is not supported for transport {transport!r}")
+        if mode is constants._SendfileMode.TRY_NATIVE:
+            try:
+                return await self._sendfile_native(transport, file,
+                                                   offset, count)
+            except events.SendfileNotAvailableError:
+                if not fallback:
+                    raise
+        elif not fallback:  # fallback-only transport (e.g. SSL)
+            raise events.SendfileNotAvailableError(
+                f"sendfile syscall is not supported for {transport!r}")
+        return await self._sendfile_fallback(transport, file, offset, count)
+
+    async def _sendfile_native(self, transp, file, offset, count):
+        raise events.SendfileNotAvailableError(  # default: no native support
+            "sendfile syscall is not supported")
+
+    async def _sendfile_fallback(self, transp, file, offset, count):
+        if offset:
+            file.seek(offset)
+        blocksize = min(count, 16384) if count else 16384  # 16 KiB chunks
+        buf = bytearray(blocksize)
+        total_sent = 0
+        proto = _SendfileFallbackProtocol(transp)  # pauses reading
+        try:
+            while True:
+                if count:
+                    blocksize = min(count - total_sent, blocksize)
+                    if blocksize <= 0:
+                        return total_sent
+                view = memoryview(buf)[:blocksize]
+                read = file.readinto(view)
+                if not read:
+                    return total_sent  # EOF
+                await proto.drain()  # wait while writing is paused
+                transp.write(view[:read])  # never send stale buffer bytes
+                total_sent += read
+        finally:
+            if total_sent > 0 and hasattr(file, 'seek'):
+                file.seek(offset + total_sent)  # just past the last byte sent
+            await proto.restore()  # reinstall the original protocol
+
     async def start_tls(self, transport, protocol, sslcontext, *,
                         server_side=False,
                         server_hostname=None,
diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py
index 0ad974ff2fb..739b0a70c13 100644
--- a/Lib/asyncio/constants.py
+++ b/Lib/asyncio/constants.py
@@ -1,3 +1,5 @@
+import enum
+
 # After the connection is lost, log warnings after this many write()s.
 LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5
 
@@ -11,3 +13,10 @@
 
 # Number of seconds to wait for SSL handshake to complete
 SSL_HANDSHAKE_TIMEOUT = 10.0
+
+# Kept here rather than in base_events to break the circular import
+# between base_events and sslproto.
+class _SendfileMode(enum.Enum):
+    UNSUPPORTED = enum.auto()  # loop.sendfile() raises RuntimeError
+    TRY_NATIVE = enum.auto()  # try native sendfile, then maybe fall back
+    FALLBACK = enum.auto()  # only the read/write fallback is usable
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 7aa3de02c95..bdefcf62a05 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -354,6 +354,14 @@ def set_default_executor(self, executor):
         """
         raise NotImplementedError
 
+    async def sendfile(self, transport, file, offset=0, count=None,
+                       *, fallback=True):
+        """Send a file through a transport.
+
+        Return the total number of bytes which were sent.
+        """
+        raise NotImplementedError  # implemented by BaseEventLoop
+
     async def start_tls(self, transport, protocol, sslcontext, *,
                         server_side=False,
                         server_hostname=None,
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index ab1285b7999..6d27e532387 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -180,7 +180,12 @@ def _loop_reading(self, fut=None):
                 assert self._read_fut is fut or (self._read_fut is None and
                                                  self._closing)
                 self._read_fut = None
-                data = fut.result()  # deliver data later in "finally" clause
+                if fut.done():
+                    # deliver data later in "finally" clause
+                    data = fut.result()
+                else:
+                    # the future will be replaced by next proactor.recv call
+                    fut.cancel()
 
             if self._closing:
                 # since close() has been called we ignore any read data
@@ -345,6 +350,8 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
                                transports.Transport):
     """Transport for connected sockets."""
 
+    _sendfile_compatible = constants._SendfileMode.FALLBACK
+
     def _set_extra(self, sock):
         self._extra['socket'] = sock
 
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 9446ae6a3bc..5956f2d993e 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -540,6 +540,20 @@ def _sock_accept(self, fut, registered, sock):
         else:
             fut.set_result((conn, address))
 
+    async def _sendfile_native(self, transp, file, offset, count):
+        del self._transports[transp._sock_fd]
+        resume_reading = transp.is_reading()
+        transp.pause_reading()
+        try:
+            await transp._make_empty_waiter()  # flush buffered writes first
+            return await self.sock_sendfile(transp._sock, file, offset, count,
+                                            fallback=False)
+        finally:
+            transp._reset_empty_waiter()
+            if resume_reading:
+                transp.resume_reading()
+            self._transports[transp._sock_fd] = transp
+
     def _process_events(self, event_list):
         for key, mask in event_list:
             fileobj, (reader, writer) = key.fileobj, key.data
@@ -695,12 +709,14 @@ def get_write_buffer_size(self):
 class _SelectorSocketTransport(_SelectorTransport):
 
     _start_tls_compatible = True
+    _sendfile_compatible = constants._SendfileMode.TRY_NATIVE
 
     def __init__(self, loop, sock, protocol, waiter=None,
                  extra=None, server=None):
         super().__init__(loop, sock, protocol, extra, server)
         self._eof = False
         self._paused = False
+        self._empty_waiter = None
 
         # Disable the Nagle algorithm -- small writes will be
         # sent without waiting for the TCP ACK.  This generally
@@ -765,6 +781,8 @@ def write(self, data):
                             f'not {type(data).__name__!r}')
         if self._eof:
             raise RuntimeError('Cannot call write() after write_eof()')
+        if self._empty_waiter is not None:
+            raise RuntimeError('unable to write; sendfile is in progress')
         if not data:
             return
 
@@ -807,12 +825,16 @@ def _write_ready(self):
             self._loop._remove_writer(self._sock_fd)
             self._buffer.clear()
             self._fatal_error(exc, 'Fatal write error on socket transport')
+            if self._empty_waiter is not None:
+                self._empty_waiter.set_exception(exc)
         else:
             if n:
                 del self._buffer[:n]
             self._maybe_resume_protocol()  # May append to buffer.
             if not self._buffer:
                 self._loop._remove_writer(self._sock_fd)
+                if self._empty_waiter is not None:
+                    self._empty_waiter.set_result(None)
                 if self._closing:
                     self._call_connection_lost(None)
                 elif self._eof:
@@ -828,6 +850,23 @@ def write_eof(self):
     def can_write_eof(self):
         return True
 
+    def _call_connection_lost(self, exc):  # the waiter may already be done
+        super()._call_connection_lost(exc)
+        if self._empty_waiter is not None and not self._empty_waiter.done():
+            self._empty_waiter.set_exception(  # fail a pending sendfile
+                ConnectionError("Connection is closed by peer"))
+
+    def _make_empty_waiter(self):  # resolves once the write buffer is empty
+        if self._empty_waiter is not None:  # one waiter at a time
+            raise RuntimeError("Empty waiter is already set")
+        self._empty_waiter = self._loop.create_future()
+        if not self._buffer:  # nothing buffered: resolve immediately
+            self._empty_waiter.set_result(None)
+        return self._empty_waiter
+
+    def _reset_empty_waiter(self):
+        self._empty_waiter = None  # write() is allowed again
+
 
 class _SelectorDatagramTransport(_SelectorTransport):
 
diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index 1130bced8ae..863b54313cc 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -282,6 +282,8 @@ def feed_appdata(self, data, offset=0):
 class _SSLProtocolTransport(transports._FlowControlMixin,
                             transports.Transport):
 
+    _sendfile_compatible = constants._SendfileMode.FALLBACK
+
     def __init__(self, loop, ssl_protocol):
         self._loop = loop
         # SSLProtocol instance
@@ -365,6 +367,11 @@ def get_write_buffer_size(self):
         """Return the current size of the write buffer."""
         return self._ssl_protocol._transport.get_write_buffer_size()
 
+    @property
+    def _protocol_paused(self):
+        # Mirror the wrapped transport's flag; the sendfile fallback reads it
+        return self._ssl_protocol._transport._protocol_paused
+
     def write(self, data):
         """Write some data bytes to the transport.
 
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
index 890fce8b405..f91fcddb2aa 100644
--- a/Lib/asyncio/windows_events.py
+++ b/Lib/asyncio/windows_events.py
@@ -425,7 +425,8 @@ def finish_recv(trans, key, ov):
             try:
                 return ov.getresult()
             except OSError as exc:
-                if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
                     raise ConnectionResetError(*exc.args)
                 else:
                     raise
@@ -447,7 +448,8 @@ def finish_recv(trans, key, ov):
             try:
                 return ov.getresult()
             except OSError as exc:
-                if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
                     raise ConnectionResetError(*exc.args)
                 else:
                     raise
@@ -466,7 +468,8 @@ def finish_send(trans, key, ov):
             try:
                 return ov.getresult()
             except OSError as exc:
-                if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
                     raise ConnectionResetError(*exc.args)
                 else:
                     raise
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
index 6489f50f272..ab6560c70b9 100644
--- a/Lib/test/test_asyncio/test_base_events.py
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -1788,7 +1788,7 @@ def runner(loop):
             outer_loop.close()
 
 
-class BaseLoopSendfileTests(test_utils.TestCase):
+class BaseLoopSockSendfileTests(test_utils.TestCase):
 
     DATA = b"12345abcde" * 16 * 1024  # 160 KiB
 
@@ -1799,9 +1799,11 @@ def __init__(self, loop):
             self.closed = False
             self.data = bytearray()
             self.fut = loop.create_future()
+            self.transport = None
 
         def connection_made(self, transport):
             self.started = True
+            self.transport = transport
 
         def data_received(self, data):
             self.data.extend(data)
@@ -1809,6 +1811,7 @@ def data_received(self, data):
         def connection_lost(self, exc):
             self.closed = True
             self.fut.set_result(None)
+            self.transport = None
 
         async def wait_closed(self):
             await self.fut
@@ -1853,6 +1856,10 @@ def prepare(self):
         def cleanup():
             server.close()
             self.run_loop(server.wait_closed())
+            sock.close()
+            if proto.transport is not None:
+                proto.transport.close()
+                self.run_loop(proto.wait_closed())
 
         self.addCleanup(cleanup)
 
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index cf217538a06..0981bd6ac91 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -26,6 +26,7 @@
     import tty
 
 import asyncio
+from asyncio import base_events
 from asyncio import coroutines
 from asyncio import events
 from asyncio import proactor_events
@@ -2090,14 +2091,308 @@ def test_subprocess_shell_invalid_args(self):
             self.loop.run_until_complete(connect(shell=False))
 
 
+class MySendfileProto(MyBaseProto):
+    """Collect received data; close once close_after bytes arrived."""
+    def __init__(self, loop=None, close_after=0):
+        super().__init__(loop)
+        self.data = bytearray()
+        self.close_after = close_after
+
+    def data_received(self, data):
+        self.data.extend(data)
+        super().data_received(data)
+        if self.close_after and self.nbytes >= self.close_after:  # 0: never
+            self.transport.close()
+
+
+class SendfileMixin:
+    # loop.sendfile() tests; SSL transports always take the fallback path
+
+    DATA = b"12345abcde" * 160 * 1024  # 1600 KiB
+
+    @classmethod
+    def setUpClass(cls):
+        with open(support.TESTFN, 'wb') as fp:
+            fp.write(cls.DATA)
+        super().setUpClass()
+
+    @classmethod
+    def tearDownClass(cls):
+        support.unlink(support.TESTFN)
+        super().tearDownClass()
+
+    def setUp(self):
+        self.file = open(support.TESTFN, 'rb')
+        self.addCleanup(self.file.close)
+        super().setUp()
+
+    def run_loop(self, coro):
+        return self.loop.run_until_complete(coro)
+
+    def prepare(self, *, is_ssl=False, close_after=0):
+        port = support.find_unused_port()
+        srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
+        if is_ssl:
+            srv_ctx = test_utils.simple_server_sslcontext()
+            cli_ctx = test_utils.simple_client_sslcontext()
+        else:
+            srv_ctx = None
+            cli_ctx = None
+        srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        # reduce recv socket buffer size to test on relatively small data sets
+        srv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
+        srv_sock.bind((support.HOST, port))
+        server = self.run_loop(self.loop.create_server(
+            lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
+
+        if is_ssl:
+            server_hostname = support.HOST
+        else:
+            server_hostname = None
+        cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        # reduce send socket buffer size to test on relatively small data sets
+        cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
+        cli_sock.connect((support.HOST, port))
+        cli_proto = MySendfileProto(loop=self.loop)
+        tr, pr = self.run_loop(self.loop.create_connection(
+            lambda: cli_proto, sock=cli_sock,
+            ssl=cli_ctx, server_hostname=server_hostname))
+
+        def cleanup():
+            srv_proto.transport.close()
+            cli_proto.transport.close()
+            self.run_loop(srv_proto.done)
+            self.run_loop(cli_proto.done)
+
+            server.close()
+            self.run_loop(server.wait_closed())
+
+        self.addCleanup(cleanup)
+        return srv_proto, cli_proto
+
+    @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
+    def test_sendfile_not_supported(self):
+        tr, pr = self.run_loop(
+            self.loop.create_datagram_endpoint(
+                lambda: MyDatagramProto(loop=self.loop),
+                family=socket.AF_INET))
+        try:
+            with self.assertRaisesRegex(RuntimeError, "not supported"):
+                self.run_loop(
+                    self.loop.sendfile(tr, self.file))
+            self.assertEqual(0, self.file.tell())
+        finally:
+            # don't use self.addCleanup because it produces resource warning
+            tr.close()
+
+    def test_sendfile(self):
+        srv_proto, cli_proto = self.prepare()
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.nbytes, len(self.DATA))
+        self.assertEqual(srv_proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_force_fallback(self):
+        srv_proto, cli_proto = self.prepare()
+
+        def sendfile_native(transp, file, offset, count):
+            # to raise SendfileNotAvailableError
+            return base_events.BaseEventLoop._sendfile_native(
+                self.loop, transp, file, offset, count)
+
+        self.loop._sendfile_native = sendfile_native
+
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.nbytes, len(self.DATA))
+        self.assertEqual(srv_proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_force_unsupported_native(self):
+        if sys.platform == 'win32':
+            if isinstance(self.loop, asyncio.ProactorEventLoop):
+                self.skipTest("Fails on proactor event loop")
+        srv_proto, cli_proto = self.prepare()
+
+        def sendfile_native(transp, file, offset, count):
+            # to raise SendfileNotAvailableError
+            return base_events.BaseEventLoop._sendfile_native(
+                self.loop, transp, file, offset, count)
+
+        self.loop._sendfile_native = sendfile_native
+
+        with self.assertRaisesRegex(events.SendfileNotAvailableError,
+                                    "not supported"):
+            self.run_loop(
+                self.loop.sendfile(cli_proto.transport, self.file,
+                                   fallback=False))
+
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(srv_proto.nbytes, 0)
+        self.assertEqual(self.file.tell(), 0)
+
+    def test_sendfile_ssl(self):
+        srv_proto, cli_proto = self.prepare(is_ssl=True)
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.nbytes, len(self.DATA))
+        self.assertEqual(srv_proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_for_closing_transp(self):
+        srv_proto, cli_proto = self.prepare()
+        cli_proto.transport.close()
+        with self.assertRaisesRegex(RuntimeError, "is closing"):
+            self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
+        self.run_loop(srv_proto.done)
+        self.assertEqual(srv_proto.nbytes, 0)
+        self.assertEqual(self.file.tell(), 0)
+
+    def test_sendfile_pre_and_post_data(self):
+        srv_proto, cli_proto = self.prepare()
+        PREFIX = b'zxcvbnm' * 1024
+        SUFFIX = b'0987654321' * 1024
+        cli_proto.transport.write(PREFIX)
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.write(SUFFIX)
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_ssl_pre_and_post_data(self):
+        srv_proto, cli_proto = self.prepare(is_ssl=True)
+        PREFIX = b'zxcvbnm' * 1024
+        SUFFIX = b'0987654321' * 1024
+        cli_proto.transport.write(PREFIX)
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.write(SUFFIX)
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_partial(self):
+        srv_proto, cli_proto = self.prepare()
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, 100)
+        self.assertEqual(srv_proto.nbytes, 100)
+        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+        self.assertEqual(self.file.tell(), 1100)
+
+    def test_sendfile_ssl_partial(self):
+        srv_proto, cli_proto = self.prepare(is_ssl=True)
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, 100)
+        self.assertEqual(srv_proto.nbytes, 100)
+        self.assertEqual(srv_proto.data, self.DATA[1000:1100])
+        self.assertEqual(self.file.tell(), 1100)
+
+    def test_sendfile_close_peer_after_receiving(self):
+        srv_proto, cli_proto = self.prepare(close_after=len(self.DATA))
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        cli_proto.transport.close()
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.nbytes, len(self.DATA))
+        self.assertEqual(srv_proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_ssl_close_peer_after_receiving(self):
+        srv_proto, cli_proto = self.prepare(is_ssl=True,
+                                            close_after=len(self.DATA))
+        ret = self.run_loop(
+            self.loop.sendfile(cli_proto.transport, self.file))
+        self.run_loop(srv_proto.done)
+        self.assertEqual(ret, len(self.DATA))
+        self.assertEqual(srv_proto.nbytes, len(self.DATA))
+        self.assertEqual(srv_proto.data, self.DATA)
+        self.assertEqual(self.file.tell(), len(self.DATA))
+
+    def test_sendfile_close_peer_in_middle_of_receiving(self):
+        srv_proto, cli_proto = self.prepare(close_after=1024)
+        with self.assertRaises(ConnectionError):
+            self.run_loop(
+                self.loop.sendfile(cli_proto.transport, self.file))
+        self.run_loop(srv_proto.done)
+
+        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+                        srv_proto.nbytes)
+        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+                        self.file.tell())
+
+    def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
+
+        def sendfile_native(transp, file, offset, count):
+            # to raise SendfileNotAvailableError
+            return base_events.BaseEventLoop._sendfile_native(
+                self.loop, transp, file, offset, count)
+
+        self.loop._sendfile_native = sendfile_native
+
+        srv_proto, cli_proto = self.prepare(close_after=1024)
+        with self.assertRaises(ConnectionError):
+            self.run_loop(
+                self.loop.sendfile(cli_proto.transport, self.file))
+        self.run_loop(srv_proto.done)
+
+        self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
+                        srv_proto.nbytes)
+        self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
+                        self.file.tell())
+
+    @unittest.skipIf(not hasattr(os, 'sendfile'),
+                     "Don't have native sendfile support")
+    def test_sendfile_prevents_bare_write(self):
+        srv_proto, cli_proto = self.prepare()
+        fut = self.loop.create_future()
+
+        async def coro():
+            fut.set_result(None)
+            return await self.loop.sendfile(cli_proto.transport, self.file)
+
+        t = self.loop.create_task(coro())
+        self.run_loop(fut)
+        with self.assertRaisesRegex(RuntimeError,
+                                    "sendfile is in progress"):
+            cli_proto.transport.write(b'data')
+        ret = self.run_loop(t)
+        self.assertEqual(ret, len(self.DATA))
+
+
 if sys.platform == 'win32':
 
-    class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):
+    class SelectEventLoopTests(EventLoopTestsMixin,
+                               SendfileMixin,  # loop.sendfile() tests
+                               test_utils.TestCase):
 
         def create_event_loop(self):
             return asyncio.SelectorEventLoop()
 
     class ProactorEventLoopTests(EventLoopTestsMixin,
+                                 SendfileMixin,
                                  SubprocessTestsMixin,
                                  test_utils.TestCase):
 
@@ -2125,7 +2420,7 @@ def test_remove_fds_after_closing(self):
 else:
     import selectors
 
-    class UnixEventLoopTestsMixin(EventLoopTestsMixin):
+    class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin):
         def setUp(self):
             super().setUp()
             watcher = asyncio.SafeChildWatcher()
@@ -2556,7 +2851,9 @@ def test_not_implemented_async(self):
             with self.assertRaises(NotImplementedError):
                 await loop.sock_accept(f)
             with self.assertRaises(NotImplementedError):
-                await loop.sock_sendfile(f, mock.Mock())
+                await loop.sock_sendfile(f, f)
+            with self.assertRaises(NotImplementedError):
+                await loop.sendfile(f, f)
             with self.assertRaises(NotImplementedError):
                 await loop.connect_read_pipe(f, mock.sentinel.pipe)
             with self.assertRaises(NotImplementedError):
diff --git a/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst b/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst
new file mode 100644
index 00000000000..d7433fa3cb1
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst
@@ -0,0 +1 @@
+Add :meth:`asyncio.AbstractEventLoop.sendfile` method.
diff --git a/Modules/overlapped.c b/Modules/overlapped.c
index e66e8566840..447a337fdd1 100644
--- a/Modules/overlapped.c
+++ b/Modules/overlapped.c
@@ -1436,6 +1436,7 @@ PyInit__overlapped(void)
 
     WINAPI_CONSTANT(F_DWORD,  ERROR_IO_PENDING);
     WINAPI_CONSTANT(F_DWORD,  ERROR_NETNAME_DELETED);
+    WINAPI_CONSTANT(F_DWORD,  ERROR_OPERATION_ABORTED);
     WINAPI_CONSTANT(F_DWORD,  ERROR_SEM_TIMEOUT);
     WINAPI_CONSTANT(F_DWORD,  ERROR_PIPE_BUSY);
     WINAPI_CONSTANT(F_DWORD,  INFINITE);



More information about the Python-checkins mailing list