[Python-checkins] r80598 - in python/branches/py3k: Lib/test/test_ssl.py

antoine.pitrou python-checkins at python.org
Wed Apr 28 23:37:09 CEST 2010


Author: antoine.pitrou
Date: Wed Apr 28 23:37:09 2010
New Revision: 80598

Log:
Merged revisions 80596 via svnmerge from 
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r80596 | antoine.pitrou | 2010-04-28 23:11:01 +0200 (mer., 28 avril 2010) | 3 lines
  
  Fix style issues in test_ssl
........


Modified:
   python/branches/py3k/   (props changed)
   python/branches/py3k/Lib/test/test_ssl.py

Modified: python/branches/py3k/Lib/test/test_ssl.py
==============================================================================
--- python/branches/py3k/Lib/test/test_ssl.py	(original)
+++ python/branches/py3k/Lib/test/test_ssl.py	Wed Apr 28 23:37:09 2010
@@ -36,7 +36,7 @@
 
 class BasicTests(unittest.TestCase):
 
-    def testSSLconnect(self):
+    def test_connect(self):
         if not support.is_resource_enabled('network'):
             return
         s = ssl.wrap_socket(socket.socket(socket.AF_INET),
@@ -57,7 +57,7 @@
         finally:
             s.close()
 
-    def testCrucialConstants(self):
+    def test_constants(self):
         ssl.PROTOCOL_SSLv2
         ssl.PROTOCOL_SSLv23
         ssl.PROTOCOL_SSLv3
@@ -66,7 +66,7 @@
         ssl.CERT_OPTIONAL
         ssl.CERT_REQUIRED
 
-    def testRAND(self):
+    def test_random(self):
         v = ssl.RAND_status()
         if support.verbose:
             sys.stdout.write("\n RAND_status is %d (%s)\n"
@@ -80,7 +80,7 @@
             print("didn't raise TypeError")
         ssl.RAND_add("this is a random string", 75.0)
 
-    def testParseCert(self):
+    def test_parse_cert(self):
         # note that this uses an 'unofficial' function in _ssl.c,
         # provided solely for this test, to exercise the certificate
         # parsing code
@@ -88,9 +88,9 @@
         if support.verbose:
             sys.stdout.write("\n" + pprint.pformat(p) + "\n")
 
-    def testDERtoPEM(self):
-
-        pem = open(SVN_PYTHON_ORG_ROOT_CERT, 'r').read()
+    def test_DER_to_PEM(self):
+        with open(SVN_PYTHON_ORG_ROOT_CERT, 'r') as f:
+            pem = f.read()
         d1 = ssl.PEM_cert_to_DER_cert(pem)
         p2 = ssl.DER_cert_to_PEM_cert(d1)
         d2 = ssl.PEM_cert_to_DER_cert(p2)
@@ -166,7 +166,7 @@
 
 class NetworkedTests(unittest.TestCase):
 
-    def testConnect(self):
+    def test_connect(self):
         s = ssl.wrap_socket(socket.socket(socket.AF_INET),
                             cert_reqs=ssl.CERT_NONE)
         s.connect(("svn.python.org", 443))
@@ -213,7 +213,7 @@
             os.read(fd, 0)
         self.assertEqual(e.exception.errno, errno.EBADF)
 
-    def testNonBlockingHandshake(self):
+    def test_non_blocking_handshake(self):
         s = socket.socket(socket.AF_INET)
         s.connect(("svn.python.org", 443))
         s.setblocking(False)
@@ -237,8 +237,7 @@
         if support.verbose:
             sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
 
-    def testFetchServerCert(self):
-
+    def test_get_server_certificate(self):
         pem = ssl.get_server_certificate(("svn.python.org", 443))
         if not pem:
             self.fail("No server certificate on svn.python.org:443!")
@@ -289,7 +288,6 @@
 except ImportError:
     _have_threads = False
 else:
-
     _have_threads = True
 
     class ThreadedEchoServer(threading.Thread):
@@ -310,7 +308,7 @@
                 threading.Thread.__init__(self)
                 self.daemon = True
 
-            def wrap_conn (self):
+            def wrap_conn(self):
                 try:
                     self.sslconn = ssl.wrap_socket(self.sock, server_side=True,
                                                    certfile=self.server.certificate,
@@ -359,7 +357,7 @@
                 else:
                     self.sock.close()
 
-            def run (self):
+            def run(self):
                 self.running = True
                 if not self.server.starttls_server:
                     if not self.wrap_conn():
@@ -367,28 +365,28 @@
                 while self.running:
                     try:
                         msg = self.read()
-                        amsg = (msg and str(msg, 'ASCII', 'strict')) or ''
-                        if not msg:
+                        stripped = msg.strip()
+                        if not stripped:
                             # eof, so quit this handler
                             self.running = False
                             self.close()
-                        elif amsg.strip() == 'over':
+                        elif stripped == b'over':
                             if support.verbose and self.server.connectionchatty:
                                 sys.stdout.write(" server: client closed connection\n")
                             self.close()
                             return
                         elif (self.server.starttls_server and
-                              amsg.strip() == 'STARTTLS'):
+                              stripped == 'STARTTLS'):
                             if support.verbose and self.server.connectionchatty:
                                 sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
-                            self.write("OK\n".encode("ASCII", "strict"))
+                            self.write(b"OK\n")
                             if not self.wrap_conn():
                                 return
                         elif (self.server.starttls_server and self.sslconn
-                              and amsg.strip() == 'ENDTLS'):
+                              and stripped == 'ENDTLS'):
                             if support.verbose and self.server.connectionchatty:
                                 sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
-                            self.write("OK\n".encode("ASCII", "strict"))
+                            self.write(b"OK\n")
                             self.sock = self.sslconn.unwrap()
                             self.sslconn = None
                             if support.verbose and self.server.connectionchatty:
@@ -397,9 +395,9 @@
                             if (support.verbose and
                                 self.server.connectionchatty):
                                 ctype = (self.sslconn and "encrypted") or "unencrypted"
-                                sys.stdout.write(" server: read %s (%s), sending back %s (%s)...\n"
-                                                 % (repr(msg), ctype, repr(msg.lower()), ctype))
-                            self.write(amsg.lower().encode('ASCII', 'strict'))
+                                sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
+                                                 % (msg, ctype, msg.lower(), ctype))
+                            self.write(msg.lower())
                     except socket.error:
                         if self.server.chatty:
                             handle_error("Test server failure:\n")
@@ -432,11 +430,11 @@
             threading.Thread.__init__(self)
             self.daemon = True
 
-        def start (self, flag=None):
+        def start(self, flag=None):
             self.flag = flag
             threading.Thread.start(self)
 
-        def run (self):
+        def run(self):
             self.sock.settimeout(0.05)
             self.sock.listen(5)
             self.active = True
@@ -457,7 +455,7 @@
                     self.stop()
             self.sock.close()
 
-        def stop (self):
+        def stop(self):
             self.active = False
 
     class OurHTTPSServer(threading.Thread):
@@ -467,12 +465,9 @@
         class HTTPSServer(HTTPServer):
 
             def __init__(self, server_address, RequestHandlerClass, certfile):
-
                 HTTPServer.__init__(self, server_address, RequestHandlerClass)
                 # we assume the certfile contains both private key and certificate
                 self.certfile = certfile
-                self.active = False
-                self.active_lock = threading.Lock()
                 self.allow_reuse_address = True
 
             def __str__(self):
@@ -481,7 +476,7 @@
                          self.server_name,
                          self.server_port))
 
-            def get_request (self):
+            def get_request(self):
                 # override this to wrap socket with SSL
                 sock, addr = self.socket.accept()
                 sslconn = ssl.wrap_socket(sock, server_side=True,
@@ -489,7 +484,6 @@
                 return sslconn, addr
 
         class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
-
             # need to override translate_path to get a known root,
             # instead of using os.curdir, since the test could be
             # run from anywhere
@@ -520,7 +514,6 @@
                 return path
 
             def log_message(self, format, *args):
-
                 # we override this to suppress logging unless "verbose"
 
                 if support.verbose:
@@ -534,7 +527,6 @@
 
         def __init__(self, certfile):
             self.flag = None
-            self.active = False
             self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0]
             self.server = self.HTTPSServer(
                 (HOST, 0), self.RootedHTTPRequestHandler, certfile)
@@ -545,19 +537,16 @@
         def __str__(self):
             return "<%s %s>" % (self.__class__.__name__, self.server)
 
-        def start (self, flag=None):
+        def start(self, flag=None):
             self.flag = flag
             threading.Thread.start(self)
 
-        def run (self):
-            self.active = True
+        def run(self):
             if self.flag:
                 self.flag.set()
             self.server.serve_forever(0.05)
-            self.active = False
 
-        def stop (self):
-            self.active = False
+        def stop(self):
             self.server.shutdown()
 
 
@@ -609,7 +598,7 @@
                         if not data:
                             self.close()
                         else:
-                            self.send(str(data, 'ASCII', 'strict').lower().encode('ASCII', 'strict'))
+                            self.send(data.lower())
 
                 def handle_close(self):
                     self.close()
@@ -650,7 +639,7 @@
             self.flag = flag
             threading.Thread.start(self)
 
-        def run (self):
+        def run(self):
             self.active = True
             if self.flag:
                 self.flag.set()
@@ -660,11 +649,15 @@
                 except:
                     pass
 
-        def stop (self):
+        def stop(self):
             self.active = False
             self.server.close()
 
-    def badCertTest (certfile):
+    def bad_cert_test(certfile):
+        """
+        Launch a server with CERT_REQUIRED, and check that trying to
+        connect to it with the given client certificate fails.
+        """
         server = ThreadedEchoServer(CERTFILE,
                                     certreqs=ssl.CERT_REQUIRED,
                                     cacerts=CERTFILE, chatty=False,
@@ -684,7 +677,7 @@
                 if support.verbose:
                     sys.stdout.write("\nSSLError is %s\n" % x.args[1])
             except socket.error as x:
-                if test_support.verbose:
+                if support.verbose:
                     sys.stdout.write("\nsocket.error is %s\n" % x[1])
             else:
                 self.fail("Use of invalid cert should have failed!")
@@ -692,11 +685,13 @@
             server.stop()
             server.join()
 
-    def serverParamsTest (certfile, protocol, certreqs, cacertsfile,
-                          client_certfile, client_protocol=None,
-                          indata="FOO\n",
-                          ciphers=None, chatty=False, connectionchatty=False):
-
+    def server_params_test(certfile, protocol, certreqs, cacertsfile,
+                           client_certfile, client_protocol=None, indata=b"FOO\n",
+                           ciphers=None, chatty=True, connectionchatty=False):
+        """
+        Launch a server, connect a client to it and try various reads
+        and writes.
+        """
         server = ThreadedEchoServer(certfile,
                                     certreqs=certreqs,
                                     ssl_version=protocol,
@@ -719,24 +714,22 @@
                                 cert_reqs=certreqs,
                                 ssl_version=client_protocol)
             s.connect((HOST, server.port))
-            bindata = indata.encode('ASCII', 'strict')
-            for arg in [bindata, bytearray(bindata), memoryview(bindata)]:
+            for arg in [indata, bytearray(indata), memoryview(indata)]:
                 if connectionchatty:
                     if support.verbose:
                         sys.stdout.write(
-                            " client:  sending %s...\n" % (repr(indata)))
+                            " client:  sending %r...\n" % indata)
                 s.write(arg)
                 outdata = s.read()
                 if connectionchatty:
                     if support.verbose:
-                        sys.stdout.write(" client:  read %s\n" % repr(outdata))
-                outdata = str(outdata, 'ASCII', 'strict')
+                        sys.stdout.write(" client:  read %r\n" % outdata)
                 if outdata != indata.lower():
                     self.fail(
-                        "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
-                        % (repr(outdata[:min(len(outdata),20)]), len(outdata),
-                           repr(indata[:min(len(indata),20)].lower()), len(indata)))
-            s.write("over\n".encode("ASCII", "strict"))
+                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
+                        % (outdata[:20], len(outdata),
+                           indata[:20].lower(), len(indata)))
+            s.write(b"over\n")
             if connectionchatty:
                 if support.verbose:
                     sys.stdout.write(" client:  closing connection.\n")
@@ -745,22 +738,19 @@
             server.stop()
             server.join()
 
-    def tryProtocolCombo (server_protocol,
-                          client_protocol,
-                          expectedToWork,
-                          certsreqs=None):
-
+    def try_protocol_combo(server_protocol,
+                           client_protocol,
+                           expect_success,
+                           certsreqs=None):
         if certsreqs is None:
             certsreqs = ssl.CERT_NONE
-
-        if certsreqs == ssl.CERT_NONE:
-            certtype = "CERT_NONE"
-        elif certsreqs == ssl.CERT_OPTIONAL:
-            certtype = "CERT_OPTIONAL"
-        elif certsreqs == ssl.CERT_REQUIRED:
-            certtype = "CERT_REQUIRED"
+        certtype = {
+            ssl.CERT_NONE: "CERT_NONE",
+            ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
+            ssl.CERT_REQUIRED: "CERT_REQUIRED",
+        }[certsreqs]
         if support.verbose:
-            formatstr = (expectedToWork and " %s->%s %s\n") or " {%s->%s} %s\n"
+            formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
             sys.stdout.write(formatstr %
                              (ssl.get_protocol_name(client_protocol),
                               ssl.get_protocol_name(server_protocol),
@@ -769,20 +759,20 @@
             # NOTE: we must enable "ALL" ciphers, otherwise an SSLv23 client
             # will send an SSLv3 hello (rather than SSLv2) starting from
             # OpenSSL 1.0.0 (see issue #8322).
-            serverParamsTest(CERTFILE, server_protocol, certsreqs,
-                             CERTFILE, CERTFILE, client_protocol,
-                             ciphers="ALL",
-                             chatty=False, connectionchatty=False)
+            server_params_test(CERTFILE, server_protocol, certsreqs,
+                               CERTFILE, CERTFILE, client_protocol,
+                               ciphers="ALL", chatty=False,
+                               connectionchatty=False)
         # Protocol mismatch can result in either an SSLError, or a
         # "Connection reset by peer" error.
         except ssl.SSLError:
-            if expectedToWork:
+            if expect_success:
                 raise
         except socket.error as e:
-            if expectedToWork or e.errno != errno.ECONNRESET:
+            if expect_success or e.errno != errno.ECONNRESET:
                 raise
         else:
-            if not expectedToWork:
+            if not expect_success:
                 self.fail(
                     "Client protocol %s succeeded with server protocol %s!"
                     % (ssl.get_protocol_name(client_protocol),
@@ -791,16 +781,15 @@
 
     class ThreadedTests(unittest.TestCase):
 
-        def testEcho (self):
-
+        def test_echo(self):
+            """Basic test of an SSL client connecting to a server"""
             if support.verbose:
                 sys.stdout.write("\n")
-            serverParamsTest(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE,
-                             CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1,
-                             chatty=True, connectionchatty=True)
-
-        def testReadCert(self):
+            server_params_test(CERTFILE, ssl.PROTOCOL_TLSv1, ssl.CERT_NONE,
+                               CERTFILE, CERTFILE, ssl.PROTOCOL_TLSv1,
+                               chatty=True, connectionchatty=True)
 
+        def test_getpeercert(self):
             if support.verbose:
                 sys.stdout.write("\n")
             s2 = socket.socket()
@@ -840,23 +829,30 @@
                 server.stop()
                 server.join()
 
-        def testNULLcert(self):
-            badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
-                                     "nullcert.pem"))
-        def testMalformedCert(self):
-            badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
-                                     "badcert.pem"))
-        def testWrongCert(self):
-            badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
-                                     "wrongcert.pem"))
-        def testMalformedKey(self):
-            badCertTest(os.path.join(os.path.dirname(__file__) or os.curdir,
-                                     "badkey.pem"))
-
-        def testRudeShutdown(self):
-
+        def test_empty_cert(self):
+            """Connecting with an empty cert file"""
+            bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
+                                      "nullcert.pem"))
+        def test_malformed_cert(self):
+            """Connecting with a badly formatted certificate (syntax error)"""
+            bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
+                                       "badcert.pem"))
+        def test_nonexisting_cert(self):
+            """Connecting with a non-existing cert file"""
+            bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
+                                       "wrongcert.pem"))
+        def test_malformed_key(self):
+            """Connecting with a badly formatted key (syntax error)"""
+            bad_cert_test(os.path.join(os.path.dirname(__file__) or os.curdir,
+                                       "badkey.pem"))
+
+        def test_rude_shutdown(self):
+            """A brutal shutdown of an SSL server should raise an IOError
+            in the client when attempting handshake.
+            """
             listener_ready = threading.Event()
             listener_gone = threading.Event()
+
             s = socket.socket()
             port = support.bind_port(s, HOST)
 
@@ -890,62 +886,66 @@
             finally:
                 t.join()
 
-        def testProtocolSSL2(self):
+        def test_protocol_sslv2(self):
+            """Connecting to an SSLv2 server with various client options"""
             if support.verbose:
                 sys.stdout.write("\n")
-            tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
+            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
+            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
+            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
+            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, True)
+            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
+            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
 
-        def testProtocolSSL23(self):
+        def test_protocol_sslv23(self):
+            """Connecting to an SSLv23 server with various client options"""
             if support.verbose:
                 sys.stdout.write("\n")
             try:
-                tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
+                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
             except (ssl.SSLError, socket.error) as x:
                 # this fails on some older versions of OpenSSL (0.9.7l, for instance)
                 if support.verbose:
                     sys.stdout.write(
                         " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
                         % str(x))
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True)
-
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
-
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True)
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True)
+
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
+
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
+            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
 
-        def testProtocolSSL3(self):
+        def test_protocol_sslv3(self):
+            """Connecting to an SSLv3 server with various client options"""
             if support.verbose:
                 sys.stdout.write("\n")
-            tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False)
-            tryProtocolCombo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
+            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True)
+            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL)
+            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED)
+            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
+            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False)
+            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
 
-        def testProtocolTLS1(self):
+        def test_protocol_tlsv1(self):
+            """Connecting to a TLSv1 server with various client options"""
             if support.verbose:
                 sys.stdout.write("\n")
-            tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True)
-            tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
-            tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
-            tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
-            tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
-            tryProtocolCombo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False)
-
-        def testSTARTTLS (self):
-
-            msgs = ("msg 1", "MSG 2", "STARTTLS", "MSG 3", "msg 4", "ENDTLS", "msg 5", "msg 6")
+            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True)
+            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL)
+            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED)
+            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
+            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
+            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False)
+
+        def test_starttls(self):
+            """Switching from clear text to encrypted and back again."""
+            msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
 
             server = ThreadedEchoServer(CERTFILE,
                                         ssl_version=ssl.PROTOCOL_TLSv1,
@@ -965,45 +965,42 @@
                 if support.verbose:
                     sys.stdout.write("\n")
                 for indata in msgs:
-                    msg = indata.encode('ASCII', 'replace')
                     if support.verbose:
                         sys.stdout.write(
-                            " client:  sending %s...\n" % repr(msg))
+                            " client:  sending %r...\n" % indata)
                     if wrapped:
-                        conn.write(msg)
+                        conn.write(indata)
                         outdata = conn.read()
                     else:
-                        s.send(msg)
+                        s.send(indata)
                         outdata = s.recv(1024)
-                    if (indata == "STARTTLS" and
-                        str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")):
+                    msg = outdata.strip().lower()
+                    if indata == b"STARTTLS" and msg.startswith(b"ok"):
+                        # STARTTLS ok, switch to secure mode
                         if support.verbose:
-                            msg = str(outdata, 'ASCII', 'replace')
                             sys.stdout.write(
-                                " client:  read %s from server, starting TLS...\n"
-                                % repr(msg))
+                                " client:  read %r from server, starting TLS...\n"
+                                % msg)
                         conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
                         wrapped = True
-                    elif (indata == "ENDTLS" and
-                          str(outdata, 'ASCII', 'replace').strip().lower().startswith("ok")):
+                    elif indata == b"ENDTLS" and msg.startswith(b"ok"):
+                        # ENDTLS ok, switch back to clear text
                         if support.verbose:
-                            msg = str(outdata, 'ASCII', 'replace')
                             sys.stdout.write(
-                                " client:  read %s from server, ending TLS...\n"
-                                % repr(msg))
+                                " client:  read %r from server, ending TLS...\n"
+                                % msg)
                         s = conn.unwrap()
                         wrapped = False
                     else:
                         if support.verbose:
-                            msg = str(outdata, 'ASCII', 'replace')
                             sys.stdout.write(
-                                " client:  read %s from server\n" % repr(msg))
+                                " client:  read %r from server\n" % msg)
                 if support.verbose:
                     sys.stdout.write(" client:  closing connection.\n")
                 if wrapped:
-                    conn.write("over\n".encode("ASCII", "strict"))
+                    conn.write(b"over\n")
                 else:
-                    s.send("over\n".encode("ASCII", "strict"))
+                    s.send(b"over\n")
                 if wrapped:
                     conn.close()
                 else:
@@ -1012,8 +1009,8 @@
                 server.stop()
                 server.join()
 
-        def testSocketServer(self):
-
+        def test_socketserver(self):
+            """Using a SocketServer to create and manage SSL connections."""
             server = OurHTTPSServer(CERTFILE)
             flag = threading.Event()
             server.start(flag)
@@ -1023,7 +1020,8 @@
             try:
                 if support.verbose:
                     sys.stdout.write('\n')
-                d1 = open(CERTFILE, 'rb').read()
+                with open(CERTFILE, 'rb') as f:
+                    d1 = f.read()
                 d2 = ''
                 # now fetch the same data from the HTTPS server
                 url = 'https://%s:%d/%s' % (
@@ -1046,12 +1044,14 @@
                     sys.stdout.write('joining thread\n')
                 server.join()
 
-        def testAsyncoreServer(self):
+        def test_asyncore_server(self):
+            """Check the example asyncore integration."""
+            indata = "TEST MESSAGE of mixed case\n"
 
             if support.verbose:
                 sys.stdout.write("\n")
 
-            indata="FOO\n"
+            indata = b"FOO\n"
             server = AsyncoreEchoServer(CERTFILE)
             flag = threading.Event()
             server.start(flag)
@@ -1063,18 +1063,17 @@
                 s.connect(('127.0.0.1', server.port))
                 if support.verbose:
                     sys.stdout.write(
-                        " client:  sending %s...\n" % (repr(indata)))
-                s.write(indata.encode('ASCII', 'strict'))
+                        " client:  sending %r...\n" % indata)
+                s.write(indata)
                 outdata = s.read()
                 if support.verbose:
-                    sys.stdout.write(" client:  read %s\n" % repr(outdata))
-                outdata = str(outdata, 'ASCII', 'strict')
+                    sys.stdout.write(" client:  read %r\n" % outdata)
                 if outdata != indata.lower():
                     self.fail(
-                        "bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
-                        % (outdata[:min(len(outdata),20)], len(outdata),
-                           indata[:min(len(indata),20)].lower(), len(indata)))
-                s.write("over\n".encode("ASCII", "strict"))
+                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
+                        % (outdata[:20], len(outdata),
+                           indata[:20].lower(), len(indata)))
+                s.write(b"over\n")
                 if support.verbose:
                     sys.stdout.write(" client:  closing connection.\n")
                 s.close()
@@ -1082,8 +1081,8 @@
                 server.stop()
                 server.join()
 
-        def testAllRecvAndSendMethods(self):
-
+        def test_recv_send(self):
+            """Test recv(), send() and friends."""
             if support.verbose:
                 sys.stdout.write("\n")
 
@@ -1132,19 +1131,18 @@
                 data_prefix = "PREFIX_"
 
                 for meth_name, send_meth, expect_success, args in send_methods:
-                    indata = data_prefix + meth_name
+                    indata = (data_prefix + meth_name).encode('ascii')
                     try:
-                        send_meth(indata.encode('ASCII', 'strict'), *args)
+                        send_meth(indata, *args)
                         outdata = s.read()
-                        outdata = str(outdata, 'ASCII', 'strict')
                         if outdata != indata.lower():
                             self.fail(
                                 "While sending with <<{name:s}>> bad data "
-                                "<<{outdata:s}>> ({nout:d}) received; "
-                                "expected <<{indata:s}>> ({nin:d})\n".format(
-                                    name=meth_name, outdata=repr(outdata[:20]),
+                                "<<{outdata:r}>> ({nout:d}) received; "
+                                "expected <<{indata:r}>> ({nin:d})\n".format(
+                                    name=meth_name, outdata=outdata[:20],
                                     nout=len(outdata),
-                                    indata=repr(indata[:20]), nin=len(indata)
+                                    indata=indata[:20], nin=len(indata)
                                 )
                             )
                     except ValueError as e:
@@ -1162,19 +1160,18 @@
                             )
 
                 for meth_name, recv_meth, expect_success, args in recv_methods:
-                    indata = data_prefix + meth_name
+                    indata = (data_prefix + meth_name).encode('ascii')
                     try:
-                        s.send(indata.encode('ASCII', 'strict'))
+                        s.send(indata)
                         outdata = recv_meth(*args)
-                        outdata = str(outdata, 'ASCII', 'strict')
                         if outdata != indata.lower():
                             self.fail(
                                 "While receiving with <<{name:s}>> bad data "
-                                "<<{outdata:s}>> ({nout:d}) received; "
-                                "expected <<{indata:s}>> ({nin:d})\n".format(
-                                    name=meth_name, outdata=repr(outdata[:20]),
+                                "<<{outdata:r}>> ({nout:d}) received; "
+                                "expected <<{indata:r}>> ({nin:d})\n".format(
+                                    name=meth_name, outdata=outdata[:20],
                                     nout=len(outdata),
-                                    indata=repr(indata[:20]), nin=len(indata)
+                                    indata=indata[:20], nin=len(indata)
                                 )
                             )
                     except ValueError as e:
@@ -1193,7 +1190,7 @@
                         # consume data
                         s.read()
 
-                s.write("over\n".encode("ASCII", "strict"))
+                s.write(b"over\n")
                 s.close()
             finally:
                 server.stop()
@@ -1272,10 +1269,11 @@
         if thread_info and support.is_resource_enabled('network'):
             tests.append(ThreadedTests)
 
-    support.run_unittest(*tests)
-
-    if _have_threads:
-        support.threading_cleanup(*thread_info)
+    try:
+        support.run_unittest(*tests)
+    finally:
+        if _have_threads:
+            support.threading_cleanup(*thread_info)
 
 if __name__ == "__main__":
     test_main()


More information about the Python-checkins mailing list