[Python-checkins] bpo-33654: Support protocol type switching in SSLTransport.set_protocol() (#7194)

Andrew Svetlov webhook-mailer at python.org
Tue May 29 05:02:52 EDT 2018


https://github.com/python/cpython/commit/2179022d94937d7b0600b0dc192ca6fa5f53d830
commit: 2179022d94937d7b0600b0dc192ca6fa5f53d830
branch: master
author: Yury Selivanov <yury at magic.io>
committer: Andrew Svetlov <andrew.svetlov at gmail.com>
date: 2018-05-29T12:02:40+03:00
summary:

bpo-33654: Support protocol type switching in SSLTransport.set_protocol() (#7194)

files:
A Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst
M Lib/asyncio/sslproto.py
M Lib/test/test_asyncio/test_sslproto.py

diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index ab43e93b28bc..a6d382ecd3de 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -295,7 +295,7 @@ def get_extra_info(self, name, default=None):
         return self._ssl_protocol._get_extra_info(name, default)
 
     def set_protocol(self, protocol):
-        self._ssl_protocol._app_protocol = protocol
+        self._ssl_protocol._set_app_protocol(protocol)
 
     def get_protocol(self):
         return self._ssl_protocol._app_protocol
@@ -440,9 +440,7 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
 
         self._waiter = waiter
         self._loop = loop
-        self._app_protocol = app_protocol
-        self._app_protocol_is_buffer = \
-            isinstance(app_protocol, protocols.BufferedProtocol)
+        self._set_app_protocol(app_protocol)
         self._app_transport = _SSLProtocolTransport(self._loop, self)
         # _SSLPipe instance (None until the connection is made)
         self._sslpipe = None
@@ -454,6 +452,11 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
         self._call_connection_made = call_connection_made
         self._ssl_handshake_timeout = ssl_handshake_timeout
 
+    def _set_app_protocol(self, app_protocol):
+        self._app_protocol = app_protocol
+        self._app_protocol_is_buffer = \
+            isinstance(app_protocol, protocols.BufferedProtocol)
+
     def _wakeup_waiter(self, exc=None):
         if self._waiter is None:
             return
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index 1b2f9d2a3a2a..fa9cbd56ed42 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -302,6 +302,7 @@ def test_start_tls_client_buf_proto_1(self):
 
         server_context = test_utils.simple_server_sslcontext()
         client_context = test_utils.simple_client_sslcontext()
+        client_con_made_calls = 0
 
         def serve(sock):
             sock.settimeout(self.TIMEOUT)
@@ -315,20 +316,21 @@ def serve(sock):
             data = sock.recv_all(len(HELLO_MSG))
             self.assertEqual(len(data), len(HELLO_MSG))
 
+            sock.sendall(b'2')
+            data = sock.recv_all(len(HELLO_MSG))
+            self.assertEqual(len(data), len(HELLO_MSG))
+
             sock.shutdown(socket.SHUT_RDWR)
             sock.close()
 
-        class ClientProto(asyncio.BufferedProtocol):
-            def __init__(self, on_data, on_eof):
+        class ClientProtoFirst(asyncio.BufferedProtocol):
+            def __init__(self, on_data):
                 self.on_data = on_data
-                self.on_eof = on_eof
-                self.con_made_cnt = 0
                 self.buf = bytearray(1)
 
-            def connection_made(proto, tr):
-                proto.con_made_cnt += 1
-                # Ensure connection_made gets called only once.
-                self.assertEqual(proto.con_made_cnt, 1)
+            def connection_made(self, tr):
+                nonlocal client_con_made_calls
+                client_con_made_calls += 1
 
             def get_buffer(self, sizehint):
                 return self.buf
@@ -337,27 +339,50 @@ def buffer_updated(self, nsize):
                 assert nsize == 1
                 self.on_data.set_result(bytes(self.buf[:nsize]))
 
+        class ClientProtoSecond(asyncio.Protocol):
+            def __init__(self, on_data, on_eof):
+                self.on_data = on_data
+                self.on_eof = on_eof
+                self.con_made_cnt = 0
+
+            def connection_made(self, tr):
+                nonlocal client_con_made_calls
+                client_con_made_calls += 1
+
+            def data_received(self, data):
+                self.on_data.set_result(data)
+
             def eof_received(self):
                 self.on_eof.set_result(True)
 
         async def client(addr):
             await asyncio.sleep(0.5, loop=self.loop)
 
-            on_data = self.loop.create_future()
+            on_data1 = self.loop.create_future()
+            on_data2 = self.loop.create_future()
             on_eof = self.loop.create_future()
 
             tr, proto = await self.loop.create_connection(
-                lambda: ClientProto(on_data, on_eof), *addr)
+                lambda: ClientProtoFirst(on_data1), *addr)
 
             tr.write(HELLO_MSG)
             new_tr = await self.loop.start_tls(tr, proto, client_context)
 
-            self.assertEqual(await on_data, b'O')
+            self.assertEqual(await on_data1, b'O')
+            new_tr.write(HELLO_MSG)
+
+            new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
+            self.assertEqual(await on_data2, b'2')
             new_tr.write(HELLO_MSG)
             await on_eof
 
             new_tr.close()
 
+            # connection_made() should be called only once -- when
+            # we establish connection for the first time. Start TLS
+            # doesn't call connection_made() on application protocols.
+            self.assertEqual(client_con_made_calls, 1)
+
         with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
             self.loop.run_until_complete(
                 asyncio.wait_for(client(srv.addr),
diff --git a/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst b/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst
new file mode 100644
index 000000000000..39e8e615d8c4
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2018-05-29-01-13-39.bpo-33654.sa81Si.rst
@@ -0,0 +1 @@
+Support protocol type switching in SSLTransport.set_protocol().



More information about the Python-checkins mailing list