[Python-checkins] bpo-43669: More test_ssl cleanups (GH-25470)

tiran webhook-mailer at python.org
Mon Apr 19 02:31:37 EDT 2021


https://github.com/python/cpython/commit/d37b74f341c5a215e2fdd5eb4f8c0182f327635c
commit: d37b74f341c5a215e2fdd5eb4f8c0182f327635c
branch: master
author: Christian Heimes <christian at python.org>
committer: tiran <christian at python.org>
date: 2021-04-19T08:31:29+02:00
summary:

bpo-43669: More test_ssl cleanups (GH-25470)

Signed-off-by: Christian Heimes <christian at python.org>

files:
M Lib/test/test_ssl.py

diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index a2c79ff5927f1..ae66c3e7d4a56 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -12,7 +12,6 @@
 import socket
 import select
 import time
-import datetime
 import gc
 import os
 import errno
@@ -39,9 +38,7 @@
 
 PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
 HOST = socket_helper.HOST
-IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
-IS_OPENSSL_1_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 1)
-IS_OPENSSL_3_0_0 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (3, 0, 0)
+IS_OPENSSL_3_0_0 = ssl.OPENSSL_VERSION_INFO >= (3, 0, 0)
 PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')
 
 PROTOCOL_TO_TLS_VERSION = {}
@@ -258,31 +255,11 @@ def wrapper(*args, **kw):
     return decorator
 
 
-requires_minimum_version = unittest.skipUnless(
-    hasattr(ssl.SSLContext, 'minimum_version'),
-    "required OpenSSL >= 1.1.0g"
-)
-
-
 def handle_error(prefix):
     exc_format = ' '.join(traceback.format_exception(*sys.exc_info()))
     if support.verbose:
         sys.stdout.write(prefix + exc_format)
 
-def _have_secp_curves():
-    if not ssl.HAS_ECDH:
-        return False
-    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
-    try:
-        ctx.set_ecdh_curve("secp384r1")
-    except ValueError:
-        return False
-    else:
-        return True
-
-
-HAVE_SECP_CURVES = _have_secp_curves()
-
 
 def utc_offset(): #NOTE: ignore issues like #1647654
     # local time = utc time + utc offset
@@ -290,21 +267,6 @@ def utc_offset(): #NOTE: ignore issues like #1647654
         return -time.altzone  # seconds
     return -time.timezone
 
-def asn1time(cert_time):
-    # Some versions of OpenSSL ignore seconds, see #18207
-    # 0.9.8.i
-    if ssl._OPENSSL_API_VERSION == (0, 9, 8, 9, 15):
-        fmt = "%b %d %H:%M:%S %Y GMT"
-        dt = datetime.datetime.strptime(cert_time, fmt)
-        dt = dt.replace(second=0)
-        cert_time = dt.strftime(fmt)
-        # %d adds leading zero but ASN1_TIME_print() uses leading space
-        if cert_time[4] == "0":
-            cert_time = cert_time[:4] + " " + cert_time[5:]
-
-    return cert_time
-
-needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")
 
 ignore_deprecation = warnings_helper.ignore_warnings(
     category=DeprecationWarning
@@ -365,11 +327,12 @@ def test_constants(self):
         ssl.CERT_REQUIRED
         ssl.OP_CIPHER_SERVER_PREFERENCE
         ssl.OP_SINGLE_DH_USE
-        if ssl.HAS_ECDH:
-            ssl.OP_SINGLE_ECDH_USE
+        ssl.OP_SINGLE_ECDH_USE
         ssl.OP_NO_COMPRESSION
-        self.assertIn(ssl.HAS_SNI, {True, False})
-        self.assertIn(ssl.HAS_ECDH, {True, False})
+        self.assertEqual(ssl.HAS_SNI, True)
+        self.assertEqual(ssl.HAS_ECDH, True)
+        self.assertEqual(ssl.HAS_TLSv1_2, True)
+        self.assertEqual(ssl.HAS_TLSv1_3, True)
         ssl.OP_NO_SSLv2
         ssl.OP_NO_SSLv3
         ssl.OP_NO_TLSv1
@@ -537,8 +500,8 @@ def test_openssl_version(self):
         self.assertIsInstance(t, tuple)
         self.assertIsInstance(s, str)
         # Some sanity checks follow
-        # >= 0.9
-        self.assertGreaterEqual(n, 0x900000)
+        # >= 1.1.1
+        self.assertGreaterEqual(n, 0x10101000)
         # < 4.0
         self.assertLess(n, 0x40000000)
         major, minor, fix, patch, status = t
@@ -552,13 +515,13 @@ def test_openssl_version(self):
         self.assertLessEqual(patch, 63)
         self.assertGreaterEqual(status, 0)
         self.assertLessEqual(status, 15)
-        # Version string as returned by {Open,Libre}SSL, the format might change
-        if IS_LIBRESSL:
-            self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
-                            (s, t, hex(n)))
-        else:
-            self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
-                            (s, t, hex(n)))
+
+        libressl_ver = f"LibreSSL {major:d}"
+        openssl_ver = f"OpenSSL {major:d}.{minor:d}.{fix:d}"
+        self.assertTrue(
+            s.startswith((openssl_ver, libressl_ver)),
+            (s, t, hex(n))
+        )
 
     @support.cpython_only
     def test_refcycle(self):
@@ -1196,8 +1159,6 @@ def test_hostname_checks_common_name(self):
             with self.assertRaises(AttributeError):
                 ctx.hostname_checks_common_name = True
 
-    @requires_minimum_version
-    @unittest.skipIf(IS_LIBRESSL, "see bpo-34001")
     @ignore_deprecation
     def test_min_max_version(self):
         ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@@ -1523,7 +1484,6 @@ def test_set_ecdh_curve(self):
         self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
         self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
 
-    @needs_sni
     def test_sni_callback(self):
         ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
 
@@ -1538,7 +1498,6 @@ def dummycallback(sock, servername, ctx):
         ctx.set_servername_callback(None)
         ctx.set_servername_callback(dummycallback)
 
-    @needs_sni
     def test_sni_callback_refcycle(self):
         # Reference cycles through the servername callback are detected
         # and cleared.
@@ -1578,8 +1537,8 @@ def test_get_ca_certs(self):
                          (('organizationalUnitName', 'http://www.cacert.org'),),
                          (('commonName', 'CA Cert Signing Authority'),),
                          (('emailAddress', 'support@cacert.org'),)),
-              'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
-              'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
+              'notAfter': 'Mar 29 12:29:49 2033 GMT',
+              'notBefore': 'Mar 30 12:29:49 2003 GMT',
               'serialNumber': '00',
               'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
               'subject': ((('organizationName', 'Root CA'),),
@@ -1609,7 +1568,6 @@ def test_load_default_certs(self):
         self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
 
     @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
-    @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
     def test_load_default_certs_env(self):
         ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         with os_helper.EnvironmentVarGuard() as env:
@@ -2145,7 +2103,6 @@ def test_non_blocking_handshake(self):
     def test_get_server_certificate(self):
         _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)
 
-    @needs_sni
     def test_get_server_certificate_sni(self):
         host, port = self.server_addr
         server_names = []
@@ -2198,7 +2155,6 @@ def test_get_ca_certs_capath(self):
             self.assertTrue(cert)
         self.assertEqual(len(ctx.get_ca_certs()), 1)
 
-    @needs_sni
     def test_context_setget(self):
         # Check that the context of a connected socket can be replaced.
         ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
@@ -3863,7 +3819,6 @@ def test_tls1_3(self):
                 })
                 self.assertEqual(s.version(), 'TLSv1.3')
 
-    @requires_minimum_version
     @requires_tls_version('TLSv1_2')
     @requires_tls_version('TLSv1')
     @ignore_deprecation
@@ -3882,7 +3837,6 @@ def test_min_max_version_tlsv1_2(self):
                 s.connect((HOST, server.port))
                 self.assertEqual(s.version(), 'TLSv1.2')
 
-    @requires_minimum_version
     @requires_tls_version('TLSv1_1')
     @ignore_deprecation
     def test_min_max_version_tlsv1_1(self):
@@ -3900,7 +3854,6 @@ def test_min_max_version_tlsv1_1(self):
                 s.connect((HOST, server.port))
                 self.assertEqual(s.version(), 'TLSv1.1')
 
-    @requires_minimum_version
     @requires_tls_version('TLSv1_2')
     @requires_tls_version('TLSv1')
     @ignore_deprecation
@@ -3920,7 +3873,6 @@ def test_min_max_version_mismatch(self):
                     s.connect((HOST, server.port))
                 self.assertIn("alert", str(e.exception))
 
-    @requires_minimum_version
     @requires_tls_version('SSLv3')
     def test_min_max_version_sslv3(self):
         client_context, server_context, hostname = testing_context()
@@ -3935,7 +3887,6 @@ def test_min_max_version_sslv3(self):
                 s.connect((HOST, server.port))
                 self.assertEqual(s.version(), 'SSLv3')
 
-    @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
     def test_default_ecdh_curve(self):
         # Issue #21015: elliptic curve-based Diffie Hellman key exchange
         # should be enabled by default on SSL contexts.
@@ -4050,15 +4001,13 @@ def test_dh_params(self):
         if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
             self.fail("Non-DH cipher: " + cipher[0])
 
-    @unittest.skipUnless(HAVE_SECP_CURVES, "needs secp384r1 curve support")
-    @unittest.skipIf(IS_OPENSSL_1_1_1, "TODO: Test doesn't work on 1.1.1")
     def test_ecdh_curve(self):
         # server secp384r1, client auto
         client_context, server_context, hostname = testing_context()
 
         server_context.set_ecdh_curve("secp384r1")
         server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
-        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
+        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
         stats = server_params_test(client_context, server_context,
                                    chatty=True, connectionchatty=True,
                                    sni_name=hostname)
@@ -4067,7 +4016,7 @@ def test_ecdh_curve(self):
         client_context, server_context, hostname = testing_context()
         client_context.set_ecdh_curve("secp384r1")
         server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
-        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
+        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
         stats = server_params_test(client_context, server_context,
                                    chatty=True, connectionchatty=True,
                                    sni_name=hostname)
@@ -4077,13 +4026,11 @@ def test_ecdh_curve(self):
         client_context.set_ecdh_curve("prime256v1")
         server_context.set_ecdh_curve("secp384r1")
         server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
-        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
-        try:
+        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
+        with self.assertRaises(ssl.SSLError):
             server_params_test(client_context, server_context,
                                chatty=True, connectionchatty=True,
                                sni_name=hostname)
-        except ssl.SSLError:
-            self.fail("mismatch curve did not fail")
 
     def test_selected_alpn_protocol(self):
         # selected_alpn_protocol() is None unless ALPN is used.
@@ -4152,7 +4099,6 @@ def check_common_name(self, stats, name):
         cert = stats['peercert']
         self.assertIn((('commonName', name),), cert['subject'])
 
-    @needs_sni
     def test_sni_callback(self):
         calls = []
         server_context, other_context, client_context = self.sni_contexts()
@@ -4193,7 +4139,6 @@ def servername_cb(ssl_sock, server_name, initial_context):
         self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
         self.assertEqual(calls, [])
 
-    @needs_sni
     def test_sni_callback_alert(self):
         # Returning a TLS alert is reflected to the connecting client
         server_context, other_context, client_context = self.sni_contexts()
@@ -4207,7 +4152,6 @@ def cb_returning_alert(ssl_sock, server_name, initial_context):
                                        sni_name='supermessage')
         self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
 
-    @needs_sni
     def test_sni_callback_raising(self):
         # Raising fails the connection with a TLS handshake failure alert.
         server_context, other_context, client_context = self.sni_contexts()
@@ -4226,7 +4170,6 @@ def cb_raising(ssl_sock, server_name, initial_context):
                              'SSLV3_ALERT_HANDSHAKE_FAILURE')
             self.assertEqual(catch.unraisable.exc_type, ZeroDivisionError)
 
-    @needs_sni
     def test_sni_callback_wrong_return_type(self):
         # Returning the wrong return type terminates the TLS connection
         # with an internal error alert.



More information about the Python-checkins mailing list