[Python-checkins] cpython: Make parameterized tests in email less hackish.

r.david.murray python-checkins at python.org
Thu May 31 03:54:12 CEST 2012


http://hg.python.org/cpython/rev/e6a33938b03f
changeset:   77251:e6a33938b03f
user:        R David Murray <rdmurray at bitdance.com>
date:        Wed May 30 21:53:40 2012 -0400
summary:
  Make parameterized tests in email less hackish.

Or perhaps more hackish, depending on your perspective.  But at least this
way it is now possible to run the individual tests using the unittest CLI.

files:
  Lib/test/test_email/__init__.py            |  79 ++++++++++
  Lib/test/test_email/test_generator.py      |  41 +----
  Lib/test/test_email/test_headerregistry.py |  26 +--
  Lib/test/test_email/test_pickleable.py     |  69 +++-----
  4 files changed, 122 insertions(+), 93 deletions(-)


diff --git a/Lib/test/test_email/__init__.py b/Lib/test/test_email/__init__.py
--- a/Lib/test/test_email/__init__.py
+++ b/Lib/test/test_email/__init__.py
@@ -71,3 +71,82 @@
         for i in range(len(actual)):
             self.assertIsInstance(actual[i], expected[i],
                                     'item {}'.format(i))
+
+
+# Metaclass to allow for parameterized tests
+class Parameterized(type):
+
+    """Provide a test method parameterization facility.
+
+    Parameters are specified as the value of a class attribute that ends with
+    the string '_params'.  Call the portion before '_params' the prefix.  Then
+    a method to be parameterized must have the same prefix, the string
+    '_as_', and an arbitrary suffix.
+
+    The value of the _params attribute may be either a dictionary or a list.
+    Each element of a list may be a single value or a sequence; a single value
+    (including a string) is turned into a single element tuple.  Dictionary
+    values are used unchanged, so they must already be sequences.  However
+    derived, the sequence is passed via *args to the parameterized test method.
+
+    In a _params dictionary, the keys become part of the name of the generated
+    tests.  In a _params list, the values in the list are converted into a
+    string by joining the string values of the elements of the tuple by '_' and
+    converting any blanks into '_'s, and this becomes part of the name.  The
+    full name of a generated test is 'test_', plus the portion of the method
+    name after '_as_', plus an '_', plus the name derived as explained above.
+
+    For example, if we have:
+
+        count_params = range(3)
+
+        def count_as_foo_arg(self, foo):
+            self.assertEqual(foo+1, myfunc(foo))
+
+    we will get parameterized test methods named:
+        test_foo_arg_0
+        test_foo_arg_1
+        test_foo_arg_2
+
+    Or we could have:
+
+        example_params = {'foo': ('bar', 1), 'bing': ('bang', 2)}
+
+        def example_as_myfunc_input(self, name, count):
+            self.assertEqual(name+str(count), myfunc(name, count))
+
+    and get:
+        test_myfunc_input_foo
+        test_myfunc_input_bing
+
+    Note: if and only if the generated test name is a valid identifier can it
+    be used to select the test individually from the unittest command line.
+
+    """
+
+    def __new__(meta, classname, bases, classdict):
+        paramdicts = {}
+        for name, attr in classdict.items():
+            if name.endswith('_params'):
+                if not hasattr(attr, 'keys'):
+                    d = {}
+                    for x in attr:
+                        if isinstance(x, str) or not hasattr(x, '__iter__'):
+                            x = (x,)
+                        n = '_'.join(str(v) for v in x).replace(' ', '_')
+                        d[n] = x
+                    attr = d
+                paramdicts[name[:-7] + '_as_'] = attr
+        testfuncs = {}
+        for name, attr in classdict.items():
+            for paramsname, paramsdict in paramdicts.items():
+                if name.startswith(paramsname):
+                    testnameroot = 'test_' + name[len(paramsname):]
+                    for paramname, params in paramsdict.items():
+                        test = (lambda self, name=name, params=params:
+                                        getattr(self, name)(*params))
+                        testname = testnameroot + '_' + paramname
+                        test.__name__ = testname
+                        testfuncs[testname] = test
+        classdict.update(testfuncs)
+        return super().__new__(meta, classname, bases, classdict)
diff --git a/Lib/test/test_email/test_generator.py b/Lib/test/test_email/test_generator.py
--- a/Lib/test/test_email/test_generator.py
+++ b/Lib/test/test_email/test_generator.py
@@ -4,10 +4,10 @@
 from email import message_from_string, message_from_bytes
 from email.generator import Generator, BytesGenerator
 from email import policy
-from test.test_email import TestEmailBase
+from test.test_email import TestEmailBase, Parameterized
 
 
-class TestGeneratorBase:
+class TestGeneratorBase(metaclass=Parameterized):
 
     policy = policy.default
 
@@ -80,31 +80,23 @@
               "\n"
               "None\n")
 
-    def _test_maxheaderlen_parameter(self, n):
+    length_params = [n for n in refold_long_expected]
+
+    def length_as_maxheaderlen_parameter(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, maxheaderlen=n, policy=self.policy)
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
 
-    for n in refold_long_expected:
-        locals()['test_maxheaderlen_parameter_' + str(n)] = (
-            lambda self, n=n:
-                self._test_maxheaderlen_parameter(n))
-
-    def _test_max_line_length_policy(self, n):
+    def length_as_max_line_length_policy(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, policy=self.policy.clone(max_line_length=n))
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
 
-    for n in refold_long_expected:
-        locals()['test_max_line_length_policy' + str(n)] = (
-            lambda self, n=n:
-                self._test_max_line_length_policy(n))
-
-    def _test_maxheaderlen_parm_overrides_policy(self, n):
+    def length_as_maxheaderlen_parm_overrides_policy(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, maxheaderlen=n,
@@ -112,12 +104,7 @@
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
 
-    for n in refold_long_expected:
-        locals()['test_maxheaderlen_parm_overrides_policy' + str(n)] = (
-            lambda self, n=n:
-                self._test_maxheaderlen_parm_overrides_policy(n))
-
-    def _test_refold_none_does_not_fold(self, n):
+    def length_as_max_line_length_with_refold_none_does_not_fold(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, policy=self.policy.clone(refold_source='none',
@@ -125,12 +112,7 @@
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[0]))
 
-    for n in refold_long_expected:
-        locals()['test_refold_none_does_not_fold' + str(n)] = (
-            lambda self, n=n:
-                self._test_refold_none_does_not_fold(n))
-
-    def _test_refold_all(self, n):
+    def length_as_max_line_length_with_refold_all_folds(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, policy=self.policy.clone(refold_source='all',
@@ -138,11 +120,6 @@
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_all_expected[n]))
 
-    for n in refold_long_expected:
-        locals()['test_refold_all' + str(n)] = (
-            lambda self, n=n:
-                self._test_refold_all(n))
-
     def test_crlf_control_via_policy(self):
         source = "Subject: test\r\n\r\ntest body\r\n"
         expected = source
diff --git a/Lib/test/test_email/test_headerregistry.py b/Lib/test/test_email/test_headerregistry.py
--- a/Lib/test/test_email/test_headerregistry.py
+++ b/Lib/test/test_email/test_headerregistry.py
@@ -4,7 +4,7 @@
 from email import errors
 from email import policy
 from email.message import Message
-from test.test_email import TestEmailBase
+from test.test_email import TestEmailBase, Parameterized
 from email import headerregistry
 from email.headerregistry import Address, Group
 
@@ -175,9 +175,9 @@
         self.assertEqual(m['Date'].datetime, self.dt)
 
 
-class TestAddressHeader(TestHeaderBase):
+class TestAddressHeader(TestHeaderBase, metaclass=Parameterized):
 
-    examples = {
+    example_params = {
 
         'empty':
             ('<>',
@@ -305,8 +305,8 @@
         # trailing comments, which aren't currently handled.  comments in
         # general are not handled yet.
 
-    def _test_single_addr(self, source, defects, decoded, display_name,
-                          addr_spec, username, domain, comment):
+    def example_as_address(self, source, defects, decoded, display_name,
+                           addr_spec, username, domain, comment):
         h = self.make_header('sender', source)
         self.assertEqual(h, decoded)
         self.assertDefectsEqual(h.defects, defects)
@@ -322,13 +322,8 @@
         # XXX: we have no comment support yet.
         #self.assertEqual(a.comment, comment)
 
-    for name in examples:
-        locals()['test_'+name] = (
-            lambda self, name=name:
-                self._test_single_addr(*self.examples[name]))
-
-    def _test_group_single_addr(self, source, defects, decoded, display_name,
-                                addr_spec, username, domain, comment):
+    def example_as_group(self, source, defects, decoded, display_name,
+                         addr_spec, username, domain, comment):
         source = 'foo: {};'.format(source)
         gdecoded = 'foo: {};'.format(decoded) if decoded else 'foo:;'
         h = self.make_header('to', source)
@@ -344,11 +339,6 @@
         self.assertEqual(a.username, username)
         self.assertEqual(a.domain, domain)
 
-    for name in examples:
-        locals()['test_group_'+name] = (
-            lambda self, name=name:
-                self._test_group_single_addr(*self.examples[name]))
-
     def test_simple_address_list(self):
         value = ('Fred <dinsdale at python.org>, foo at example.com, '
                     '"Harry W. Hastings" <hasty at example.com>')
@@ -366,7 +356,7 @@
             'Harry W. Hastings')
 
     def test_complex_address_list(self):
-        examples = list(self.examples.values())
+        examples = list(self.example_params.values())
         source = ('dummy list:;, another: (empty);,' +
                  ', '.join([x[0] for x in examples[:4]]) + ', ' +
                  r'"A \"list\"": ' +
diff --git a/Lib/test/test_email/test_pickleable.py b/Lib/test/test_email/test_pickleable.py
--- a/Lib/test/test_email/test_pickleable.py
+++ b/Lib/test/test_email/test_pickleable.py
@@ -6,83 +6,66 @@
 import email.message
 from email import policy
 from email.headerregistry import HeaderRegistry
-from test.test_email import TestEmailBase
+from test.test_email import TestEmailBase, Parameterized
 
-class TestPickleCopyHeader(TestEmailBase):
+class TestPickleCopyHeader(TestEmailBase, metaclass=Parameterized):
 
     header_factory = HeaderRegistry()
 
     unstructured = header_factory('subject', 'this is a test')
 
-    def _test_deepcopy(self, name, value):
+    header_params = {
+        'subject': ('subject', 'this is a test'),
+        'from':    ('from',    'frodo at mordor.net'),
+        'to':      ('to',      'a: k at b.com, y at z.com;, j at f.com'),
+        'date':    ('date',    'Tue, 29 May 2012 09:24:26 +1000'),
+        }
+
+    def header_as_deepcopy(self, name, value):
         header = self.header_factory(name, value)
         h = copy.deepcopy(header)
         self.assertEqual(str(h), str(header))
 
-    def _test_pickle(self, name, value):
+    def header_as_pickle(self, name, value):
         header = self.header_factory(name, value)
         p = pickle.dumps(header)
         h = pickle.loads(p)
         self.assertEqual(str(h), str(header))
 
-    headers = (
-        ('subject', 'this is a test'),
-        ('from',    'frodo at mordor.net'),
-        ('to',      'a: k at b.com, y at z.com;, j at f.com'),
-        ('date',    'Tue, 29 May 2012 09:24:26 +1000'),
-        )
 
-    for header in headers:
-        locals()['test_deepcopy_'+header[0]] = (
-            lambda self, header=header:
-                self._test_deepcopy(*header))
+class TestPickleCopyMessage(TestEmailBase, metaclass=Parameterized):
 
-    for header in headers:
-        locals()['test_pickle_'+header[0]] = (
-            lambda self, header=header:
-                self._test_pickle(*header))
-
-
-class TestPickleCopyMessage(TestEmailBase):
-
-    msgs = {}
+    # Message objects are iterable, and _params dict values are passed to the
+    # test method via *args unchanged, so each one is wrapped in a one-tuple;
+    # otherwise the test would receive the message's headers as its arguments.
+    msg_params = {}
 
     # Note: there will be no custom header objects in the parsed message.
-    msgs['parsed'] = email.message_from_string(textwrap.dedent("""\
+    msg_params['parsed'] = (email.message_from_string(textwrap.dedent("""\
         Date: Tue, 29 May 2012 09:24:26 +1000
         From: frodo at mordor.net
         To: bilbo at underhill.org
         Subject: help
 
         I think I forgot the ring.
-        """), policy=policy.default)
+        """), policy=policy.default),)
 
-    msgs['created'] = email.message.Message(policy=policy.default)
-    msgs['created']['Date'] = 'Tue, 29 May 2012 09:24:26 +1000'
-    msgs['created']['From'] = 'frodo at mordor.net'
-    msgs['created']['To'] = 'bilbo at underhill.org'
-    msgs['created']['Subject'] = 'help'
-    msgs['created'].set_payload('I think I forgot the ring.')
+    msg_params['created'] = (email.message.Message(policy=policy.default),)
+    msg_params['created'][0]['Date'] = 'Tue, 29 May 2012 09:24:26 +1000'
+    msg_params['created'][0]['From'] = 'frodo at mordor.net'
+    msg_params['created'][0]['To'] = 'bilbo at underhill.org'
+    msg_params['created'][0]['Subject'] = 'help'
+    msg_params['created'][0].set_payload('I think I forgot the ring.')
 
-    def _test_deepcopy(self, msg):
+    def msg_as_deepcopy(self, msg):
         msg2 = copy.deepcopy(msg)
         self.assertEqual(msg2.as_string(), msg.as_string())
 
-    def _test_pickle(self, msg):
+    def msg_as_pickle(self, msg):
         p = pickle.dumps(msg)
         msg2 = pickle.loads(p)
         self.assertEqual(msg2.as_string(), msg.as_string())
 
-    for name, msg in msgs.items():
-        locals()['test_deepcopy_'+name] = (
-            lambda self, msg=msg:
-                self._test_deepcopy(msg))
-
-    for name, msg in msgs.items():
-        locals()['test_pickle_'+name] = (
-            lambda self, msg=msg:
-                self._test_pickle(msg))
-
 
 if __name__ == '__main__':
     unittest.main()

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list