[Python-ideas] Add method unittest.TestProgram to add custom arguments

Rémi Lapeyre remi.lapeyre at henki.fr
Tue Mar 12 10:16:18 EDT 2019


Hi everybody,

I would like to add a new method to unittest.TestProgram that would let
users add their own arguments to TestProgram so its behavior can be
changed from the command line, and make them accessible from a TestCase.

This would allow for more customization when running the tests.

Currently there is no way to do this, so users either have to rely on
environment variables or parse sys.argv themselves, as proposed
in https://stackoverflow.com/a/46061210.

This approach has several disadvantages:
  - it's more work to check and parse the argument
  - it's not discoverable from the help
  - everyone has their own way of doing this

Here are two cases where this would be useful for me:
  - in one project, I have an option to run tests in "record mode". In
  this mode, all tests are executed but the results are saved to files
  instead of checking their validity. In subsequent runs, results are
  compared to those files to make sure they are correct:

    def _test_payload(obj, filename):
        """Record `obj` to, or check it against, a JSON file in test/results/.

        If the RECORD_PAYLOADS environment variable is set, `obj` is dumped
        as JSON to test/results/<filename>. Otherwise that file is loaded
        and asserted to be equal to `obj`.
        """
        import json, os

        filename = f"test/results/{filename}"
        if "RECORD_PAYLOADS" in os.environ:
            # Record mode: (over)write the stored result instead of checking it.
            with open(filename, "w") as f:
                json.dump(obj, f, indent=2)
        else:
            # Check mode: compare against the previously recorded result.
            with open(filename) as f:
                assert obj == json.load(f)

  This lets us update tests very quickly and check that they are
  correct just by looking at the diff. This is very useful for testing an
  API with many endpoints.

  As you can see, we use an environment variable to change TestCase
  behavior but it is awkward to run tests as
  `RECORD_PAYLOADS=1 python test_api.py -v` instead of a more
  straightforward `python test_api.py -v --record-payloads`.

  - in https://bugs.python.org/issue18765 (unittest needs a way to launch pdb.post_mortem or other debug hooks)
  Gregory P. Smith and Michael Foord propose to make TestCase call
  pdb.post_mortem() on failure. This also requires a way to change the
  behavior as this would not be wanted in CI.

Just subclassing TestProgram would not be sufficient as the result of the parsing
would not be accessible from the TestCase.

One argument against making such a change is that it would require adding
a new attribute to TestCase, which would not be backward compatible and
could break existing code, but I think this is manageable if we
choose an appropriate name for it.

I have attached an implementation of this feature (also available at https://github.com/remilapeyre/cpython/tree/unittest-custom-arguments)
and look forward to your comments.

What do you think about this?

---
 Lib/unittest/case.py                |   7 +-
 Lib/unittest/loader.py              |  73 +++++----
 Lib/unittest/main.py                |  49 +++++-
 Lib/unittest/test/test_discovery.py | 225 ++++++++++++++++++++++++----
 Lib/unittest/test/test_program.py   |  34 ++++-
 5 files changed, 331 insertions(+), 57 deletions(-)

diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index a157ae8a14..5698e3a640 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -1,6 +1,7 @@
 """Test case implementation"""

 import sys
+import argparse
 import functools
 import difflib
 import logging
@@ -416,7 +417,7 @@ class TestCase(object):

     _class_cleanups = []

-    def __init__(self, methodName='runTest'):
+    def __init__(self, methodName='runTest', command_line_arguments=None):
         """Create an instance of the class that will use the named test
            method when executed. Raises a ValueError if the instance does
            not have a method with the specified name.
@@ -436,6 +437,10 @@ class TestCase(object):
             self._testMethodDoc = testMethod.__doc__
         self._cleanups = []
         self._subtest = None
+        if command_line_arguments is None:
+            self.command_line_arguments = argparse.Namespace()
+        else:
+            self.command_line_arguments = command_line_arguments

         # Map types to custom assertEqual functions that will compare
         # instances of said type in more detail to generate a more useful
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index ba7105e1ad..d5f0a97213 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -5,6 +5,7 @@ import re
 import sys
 import traceback
 import types
+import itertools
 import functools
 import warnings

@@ -81,7 +82,7 @@ class TestLoader(object):
         # avoid infinite re-entrancy.
         self._loading_packages = set()

-    def loadTestsFromTestCase(self, testCaseClass):
+    def loadTestsFromTestCase(self, testCaseClass, command_line_arguments=None):
         """Return a suite of all test cases contained in testCaseClass"""
         if issubclass(testCaseClass, suite.TestSuite):
             raise TypeError("Test cases should not be derived from "
@@ -90,12 +91,21 @@ class TestLoader(object):
         testCaseNames = self.getTestCaseNames(testCaseClass)
         if not testCaseNames and hasattr(testCaseClass, 'runTest'):
             testCaseNames = ['runTest']
-        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
+
+        # keep backward compatibility for subclasses that override __init__
+        def instanciate_testcase(testCaseClass, testCaseName):
+            try:
+                return testCaseClass(testCaseName, command_line_arguments)
+            except TypeError:
+                return testCaseClass(testCaseName)
+        loaded_suite = self.suiteClass(
+            map(instanciate_testcase, itertools.repeat(testCaseClass), testCaseNames)
+        )
         return loaded_suite

     # XXX After Python 3.5, remove backward compatibility hacks for
     # use_load_tests deprecation via *args and **kws.  See issue 16662.
-    def loadTestsFromModule(self, module, *args, pattern=None, **kws):
+    def loadTestsFromModule(self, module, *args, pattern=None, command_line_arguments=None, **kws):
         """Return a suite of all test cases contained in the given module"""
         # This method used to take an undocumented and unofficial
         # use_load_tests argument.  For backward compatibility, we still
@@ -121,7 +131,7 @@ class TestLoader(object):
         for name in dir(module):
             obj = getattr(module, name)
             if isinstance(obj, type) and issubclass(obj, case.TestCase):
-                tests.append(self.loadTestsFromTestCase(obj))
+                tests.append(self.loadTestsFromTestCase(obj, command_line_arguments))

         load_tests = getattr(module, 'load_tests', None)
         tests = self.suiteClass(tests)
@@ -135,7 +145,7 @@ class TestLoader(object):
                 return error_case
         return tests

-    def loadTestsFromName(self, name, module=None):
+    def loadTestsFromName(self, name, module=None, command_line_arguments=None):
         """Return a suite of all test cases given a string specifier.

         The name may resolve either to a module, a test case class, a
@@ -188,9 +198,9 @@ class TestLoader(object):
                     return error_case

         if isinstance(obj, types.ModuleType):
-            return self.loadTestsFromModule(obj)
+            return self.loadTestsFromModule(obj, command_line_arguments)
         elif isinstance(obj, type) and issubclass(obj, case.TestCase):
-            return self.loadTestsFromTestCase(obj)
+            return self.loadTestsFromTestCase(obj, command_line_arguments)
         elif (isinstance(obj, types.FunctionType) and
               isinstance(parent, type) and
               issubclass(parent, case.TestCase)):
@@ -213,11 +223,11 @@ class TestLoader(object):
         else:
             raise TypeError("don't know how to make test from: %s" % obj)

-    def loadTestsFromNames(self, names, module=None):
+    def loadTestsFromNames(self, names, module=None, command_line_arguments=None):
         """Return a suite of all test cases found using the given sequence
         of string specifiers. See 'loadTestsFromName()'.
         """
-        suites = [self.loadTestsFromName(name, module) for name in names]
+        suites = [self.loadTestsFromName(name, module, command_line_arguments) for name in names]
         return self.suiteClass(suites)

     def getTestCaseNames(self, testCaseClass):
@@ -239,7 +249,7 @@ class TestLoader(object):
             testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
         return testFnNames

-    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
+    def discover(self, start_dir, pattern='test*.py', top_level_dir=None, command_line_arguments=None):
         """Find and return all test modules from the specified start
         directory, recursing into subdirectories to find them and return all
         tests found within them. Only test files that match the pattern will
@@ -322,9 +332,12 @@ class TestLoader(object):
                                 self._top_level_dir = \
                                     (path.split(the_module.__name__
                                          .replace(".", os.path.sep))[0])
-                                tests.extend(self._find_tests(path,
-                                                              pattern,
-                                                              namespace=True))
+                                tests.extend(self._find_tests(
+                                    path,
+                                    pattern,
+                                    namespace=True,
+                                    command_line_arguments=command_line_arguments
+                                ))
                     elif the_module.__name__ in sys.builtin_module_names:
                         # builtin module
                         raise TypeError('Can not use builtin modules '
@@ -346,7 +359,7 @@ class TestLoader(object):
             raise ImportError('Start directory is not importable: %r' % start_dir)

         if not is_namespace:
-            tests = list(self._find_tests(start_dir, pattern))
+            tests = list(self._find_tests(start_dir, pattern, command_line_arguments=command_line_arguments))
         return self.suiteClass(tests)

     def _get_directory_containing_module(self, module_name):
@@ -381,7 +394,7 @@ class TestLoader(object):
         # override this method to use alternative matching strategy
         return fnmatch(path, pattern)

-    def _find_tests(self, start_dir, pattern, namespace=False):
+    def _find_tests(self, start_dir, pattern, namespace=False, command_line_arguments=None):
         """Used by discovery. Yields test suites it loads."""
         # Handle the __init__ in this package
         name = self._get_name_from_path(start_dir)
@@ -391,7 +404,7 @@ class TestLoader(object):
             # name is in self._loading_packages while we have called into
             # loadTestsFromModule with name.
             tests, should_recurse = self._find_test_path(
-                start_dir, pattern, namespace)
+                start_dir, pattern, namespace, command_line_arguments)
             if tests is not None:
                 yield tests
             if not should_recurse:
@@ -403,7 +416,7 @@ class TestLoader(object):
         for path in paths:
             full_path = os.path.join(start_dir, path)
             tests, should_recurse = self._find_test_path(
-                full_path, pattern, namespace)
+                full_path, pattern, namespace, command_line_arguments)
             if tests is not None:
                 yield tests
             if should_recurse:
@@ -411,11 +424,16 @@ class TestLoader(object):
                 name = self._get_name_from_path(full_path)
                 self._loading_packages.add(name)
                 try:
-                    yield from self._find_tests(full_path, pattern, namespace)
+                    yield from self._find_tests(
+                        full_path,
+                        pattern,
+                        namespace,
+                        command_line_arguments=command_line_arguments
+                    )
                 finally:
                     self._loading_packages.discard(name)

-    def _find_test_path(self, full_path, pattern, namespace=False):
+    def _find_test_path(self, full_path, pattern, namespace=False, command_line_arguments=None):
         """Used by discovery.

         Loads tests from a single file, or a directories' __init__.py when
@@ -457,7 +475,8 @@ class TestLoader(object):
                            "%r. Is this module globally installed?")
                     raise ImportError(
                         msg % (mod_name, module_dir, expected_dir))
-                return self.loadTestsFromModule(module, pattern=pattern), False
+                return self.loadTestsFromModule(module, pattern=pattern,
+                         command_line_arguments=command_line_arguments), False
         elif os.path.isdir(full_path):
             if (not namespace and
                 not os.path.isfile(os.path.join(full_path, '__init__.py'))):
@@ -480,7 +499,11 @@ class TestLoader(object):
                 # Mark this package as being in load_tests (possibly ;))
                 self._loading_packages.add(name)
                 try:
-                    tests = self.loadTestsFromModule(package, pattern=pattern)
+                    tests = self.loadTestsFromModule(
+                        package,
+                        pattern=pattern,
+                        command_line_arguments=command_line_arguments
+                    )
                     if load_tests is not None:
                         # loadTestsFromModule(package) has loaded tests for us.
                         return tests, False
@@ -507,11 +530,11 @@ def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNa
     return _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns).getTestCaseNames(testCaseClass)

 def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
-              suiteClass=suite.TestSuite):
+              suiteClass=suite.TestSuite, command_line_arguments=None):
     return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(
-        testCaseClass)
+        testCaseClass, command_line_arguments)

 def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
-                  suiteClass=suite.TestSuite):
+                  suiteClass=suite.TestSuite, command_line_arguments=None):
     return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(\
-        module)
+        module, command_line_arguments=command_line_arguments)
diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py
index e62469aa2a..7adc558e5f 100644
--- a/Lib/unittest/main.py
+++ b/Lib/unittest/main.py
@@ -51,6 +51,8 @@ def _convert_select_pattern(pattern):
         pattern = '*%s*' % pattern
     return pattern

+_options = ('verbosity', 'tb_locals', 'failfast', 'catchbreak', 'buffer', 'tests',
+            'testNamePatterns', 'tests', 'start', 'pattern', 'top', 'exit')

 class TestProgram(object):
     """A command-line program that runs a set of tests; this is primarily
@@ -100,6 +102,37 @@ class TestProgram(object):
         self.parseArgs(argv)
         self.runTests()

+    def __setattr__(self, name, value):
+        if name in _options:
+            setattr(self.command_line_arguments, name, value)
+        else:
+            super().__setattr__(name, value)
+
+    def __getattribute__(self, name):
+        if name in _options:
+            try:
+                return getattr(self.command_line_arguments, name)
+            except AttributeError:
+                pass
+
+        try:
+            return super().__getattribute__(name)
+        except AttributeError:
+            if name == 'command_line_arguments':
+                namespace = argparse.Namespace()
+                # preload command_line_arguments with class arguments
+                # this is useful for subclasses of TestProgram that override __init__
+                for name in _options:
+                    try:
+                        value = super().__getattribute__(name)
+                        setattr(namespace, name, value)
+                    except AttributeError:
+                        pass
+                self.command_line_arguments = namespace
+                return namespace
+            else:
+                raise
+
     def usageExit(self, msg=None):
         if msg:
             print(msg)
@@ -123,14 +156,14 @@ class TestProgram(object):
             if len(argv) > 1 and argv[1].lower() == 'discover':
                 self._do_discovery(argv[2:])
                 return
-            self._main_parser.parse_args(argv[1:], self)
+            self._main_parser.parse_args(argv[1:], self.command_line_arguments)
             if not self.tests:
                 # this allows "python -m unittest -v" to still work for
                 # test discovery.
                 self._do_discovery([])
                 return
         else:
-            self._main_parser.parse_args(argv[1:], self)
+            self._main_parser.parse_args(argv[1:], self.command_line_arguments)

         if self.tests:
             self.testNames = _convert_names(self.tests)
@@ -151,7 +184,12 @@ class TestProgram(object):
             self.testLoader.testNamePatterns = self.testNamePatterns
         if from_discovery:
             loader = self.testLoader if Loader is None else Loader()
-            self.test = loader.discover(self.start, self.pattern, self.top)
+            self.test = loader.discover(
+                self.start,
+                self.pattern,
+                self.top,
+                self.command_line_arguments
+            )
         elif self.testNames is None:
             self.test = self.testLoader.loadTestsFromModule(self.module)
         else:
@@ -196,8 +234,13 @@ class TestProgram(object):
                                 help='Only run tests which match the given substring')
             self.testNamePatterns = []

+        self.addCustomArguments(parser)
+
         return parser

+    def addCustomArguments(self, parser):
+        pass
+
     def _getMainArgParser(self, parent):
         parser = argparse.ArgumentParser(parents=[parent])
         parser.prog = self.progName
diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py
index 204043b493..3a2112eb7e 100644
--- a/Lib/unittest/test/test_discovery.py
+++ b/Lib/unittest/test/test_discovery.py
@@ -6,6 +6,7 @@ import types
 import pickle
 from test import support
 import test.test_importlib.util
+from argparse import Namespace

 import unittest
 import unittest.mock
@@ -19,7 +20,6 @@ class TestableTestProgram(unittest.TestProgram):
     verbosity = 1
     progName = ''
     testRunner = testLoader = None
-
     def __init__(self):
         pass

@@ -72,9 +72,13 @@ class TestDiscovery(unittest.TestCase):

         loader._get_module_from_name = lambda path: path + ' module'
         orig_load_tests = loader.loadTestsFromModule
-        def loadTestsFromModule(module, pattern=None):
+        def loadTestsFromModule(module, pattern=None, command_line_arguments=None):
             # This is where load_tests is called.
-            base = orig_load_tests(module, pattern=pattern)
+            base = orig_load_tests(
+                module,
+                pattern=pattern,
+                command_line_arguments=command_line_arguments
+            )
             return base + [module + ' tests']
         loader.loadTestsFromModule = loadTestsFromModule
         loader.suiteClass = lambda thing: thing
@@ -118,9 +122,13 @@ class TestDiscovery(unittest.TestCase):

         loader._get_module_from_name = lambda path: path + ' module'
         orig_load_tests = loader.loadTestsFromModule
-        def loadTestsFromModule(module, pattern=None):
+        def loadTestsFromModule(module, pattern=None, command_line_arguments=None):
             # This is where load_tests is called.
-            base = orig_load_tests(module, pattern=pattern)
+            base = orig_load_tests(
+                module,
+                pattern=pattern,
+                command_line_arguments=command_line_arguments
+            )
             return base + [module + ' tests']
         loader.loadTestsFromModule = loadTestsFromModule
         loader.suiteClass = lambda thing: thing
@@ -173,9 +181,13 @@ class TestDiscovery(unittest.TestCase):

         loader._get_module_from_name = lambda name: Module(name)
         orig_load_tests = loader.loadTestsFromModule
-        def loadTestsFromModule(module, pattern=None):
+        def loadTestsFromModule(module, pattern=None, command_line_arguments=None):
             # This is where load_tests is called.
-            base = orig_load_tests(module, pattern=pattern)
+            base = orig_load_tests(
+                module,
+                pattern=pattern,
+                command_line_arguments=command_line_arguments
+            )
             return base + [module.path + ' module tests']
         loader.loadTestsFromModule = loadTestsFromModule
         loader.suiteClass = lambda thing: thing
@@ -247,9 +259,13 @@ class TestDiscovery(unittest.TestCase):

         loader._get_module_from_name = lambda name: Module(name)
         orig_load_tests = loader.loadTestsFromModule
-        def loadTestsFromModule(module, pattern=None):
+        def loadTestsFromModule(module, pattern=None, command_line_arguments=None):
             # This is where load_tests is called.
-            base = orig_load_tests(module, pattern=pattern)
+            base = orig_load_tests(
+                module,
+                pattern=pattern,
+                command_line_arguments=command_line_arguments
+            )
             return base + [module.path + ' module tests']
         loader.loadTestsFromModule = loadTestsFromModule
         loader.suiteClass = lambda thing: thing
@@ -395,7 +411,7 @@ class TestDiscovery(unittest.TestCase):
         self.addCleanup(restore_isdir)

         _find_tests_args = []
-        def _find_tests(start_dir, pattern, namespace=None):
+        def _find_tests(start_dir, pattern, namespace=None, command_line_arguments=None):
             _find_tests_args.append((start_dir, pattern))
             return ['tests']
         loader._find_tests = _find_tests
@@ -633,75 +649,214 @@ class TestDiscovery(unittest.TestCase):

         class Loader(object):
             args = []
-            def discover(self, start_dir, pattern, top_level_dir):
-                self.args.append((start_dir, pattern, top_level_dir))
+            def discover(self, start_dir, pattern, top_level_dir, command_line_arguments):
+                self.args.append((start_dir, pattern, top_level_dir, command_line_arguments))
                 return 'tests'

         program.testLoader = Loader()
         program._do_discovery(['-v'])
-        self.assertEqual(Loader.args, [('.', 'test*.py', None)])
+        self.assertEqual(Loader.args, [(
+            '.',
+            'test*.py',
+            None,
+            Namespace(
+                buffer=False,
+                catchbreak=False,
+                failfast=False,
+                pattern='test*.py',
+                start='.',
+                tb_locals=False,
+                testNamePatterns=[],
+                top=None,
+                verbosity=2
+            )
+        )])

     def test_command_line_handling_do_discovery_calls_loader(self):
         program = TestableTestProgram()

         class Loader(object):
             args = []
-            def discover(self, start_dir, pattern, top_level_dir):
-                self.args.append((start_dir, pattern, top_level_dir))
+            def discover(self, start_dir, pattern, top_level_dir, command_line_arguments):
+                self.args.append((start_dir, pattern, top_level_dir, command_line_arguments))
                 return 'tests'

         program._do_discovery(['-v'], Loader=Loader)
         self.assertEqual(program.verbosity, 2)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('.', 'test*.py', None)])
+        self.assertEqual(Loader.args, [('.', 'test*.py', None, Namespace(
+                buffer=False,
+                catchbreak=False,
+                exit=True,
+                failfast=False,
+                pattern='test*.py',
+                start='.',
+                tb_locals=False,
+                testNamePatterns=[],
+                top=None,
+                verbosity=2
+            ))])

         Loader.args = []
         program = TestableTestProgram()
         program._do_discovery(['--verbose'], Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('.', 'test*.py', None)])
+        self.assertEqual(Loader.args, [('.', 'test*.py', None, Namespace(
+            buffer=False,
+            catchbreak=False,
+            exit=True,
+            failfast=False,
+            pattern='test*.py',
+            start='.',
+            tb_locals=False,
+            testNamePatterns=[],
+            top=None,
+            verbosity=2
+        ))])

         Loader.args = []
         program = TestableTestProgram()
         program._do_discovery([], Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('.', 'test*.py', None)])
+        self.assertEqual(Loader.args, [('.', 'test*.py', None, Namespace(
+            buffer=False,
+            catchbreak=False,
+            exit=True,
+            failfast=False,
+            pattern='test*.py',
+            start='.',
+            tb_locals=False,
+            testNamePatterns=[],
+            top=None,
+            verbosity=1
+        ))])

         Loader.args = []
         program = TestableTestProgram()
         program._do_discovery(['fish'], Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('fish', 'test*.py', None)])
+        self.assertEqual(Loader.args, [('fish', 'test*.py', None, Namespace(
+            buffer=False,
+            catchbreak=False,
+            exit=True,
+            failfast=False,
+            pattern='test*.py',
+            start='fish',
+            tb_locals=False,
+            testNamePatterns=[],
+            top=None,
+            verbosity=1
+        ))])

         Loader.args = []
         program = TestableTestProgram()
         program._do_discovery(['fish', 'eggs'], Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('fish', 'eggs', None)])
+        self.assertEqual(Loader.args, [(
+            'fish',
+            'eggs',
+            None,
+            Namespace(
+                buffer=False,
+                catchbreak=False,
+                exit=True,
+                failfast=False,
+                pattern='eggs',
+                start='fish',
+                tb_locals=False,
+                testNamePatterns=[],
+                top=None,
+                verbosity=1
+            )
+        )])

         Loader.args = []
         program = TestableTestProgram()
         program._do_discovery(['fish', 'eggs', 'ham'], Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('fish', 'eggs', 'ham')])
+        self.assertEqual(Loader.args, [(
+            'fish',
+            'eggs',
+            'ham',
+            Namespace(
+                buffer=False,
+                catchbreak=False,
+                exit=True,
+                failfast=False,
+                pattern='eggs',
+                start='fish',
+                tb_locals=False,
+                testNamePatterns=[],
+                top='ham',
+                verbosity=1
+            )
+        )])

         Loader.args = []
         program = TestableTestProgram()
         program._do_discovery(['-s', 'fish'], Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('fish', 'test*.py', None)])
+        self.assertEqual(Loader.args, [(
+            'fish',
+            'test*.py',
+            None,
+            Namespace(
+                buffer=False,
+                catchbreak=False,
+                exit=True,
+                failfast=False,
+                pattern='test*.py',
+                start='fish',
+                tb_locals=False,
+                testNamePatterns=[],
+                top=None,
+                verbosity=1
+            )
+        )])

         Loader.args = []
         program = TestableTestProgram()
         program._do_discovery(['-t', 'fish'], Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('.', 'test*.py', 'fish')])
+        self.assertEqual(Loader.args, [(
+            '.',
+            'test*.py',
+            'fish',
+            Namespace(
+                buffer=False,
+                catchbreak=False,
+                exit=True,
+                failfast=False,
+                pattern='test*.py',
+                start='.',
+                tb_locals=False,
+                testNamePatterns=[],
+                top='fish',
+                verbosity=1
+            )
+        )])

         Loader.args = []
         program = TestableTestProgram()
         program._do_discovery(['-p', 'fish'], Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('.', 'fish', None)])
+        self.assertEqual(Loader.args, [(
+            '.',
+            'fish',
+            None,
+            Namespace(
+                buffer=False,
+                catchbreak=False,
+                exit=True,
+                failfast=False,
+                pattern='fish',
+                start='.',
+                tb_locals=False,
+                testNamePatterns=[],
+                top=None,
+                verbosity=1
+            )
+        )])
         self.assertFalse(program.failfast)
         self.assertFalse(program.catchbreak)

@@ -710,7 +865,23 @@ class TestDiscovery(unittest.TestCase):
         program._do_discovery(['-p', 'eggs', '-s', 'fish', '-v', '-f', '-c'],
                               Loader=Loader)
         self.assertEqual(program.test, 'tests')
-        self.assertEqual(Loader.args, [('fish', 'eggs', None)])
+        self.assertEqual(Loader.args, [(
+            'fish',
+            'eggs',
+            None,
+            Namespace(
+                buffer=False,
+                catchbreak=True,
+                exit=True,
+                failfast=True,
+                pattern='eggs',
+                start='fish',
+                tb_locals=False,
+                testNamePatterns=[],
+                top=None,
+                verbosity=2
+            )
+        )])
         self.assertEqual(program.verbosity, 2)
         self.assertTrue(program.failfast)
         self.assertTrue(program.catchbreak)
@@ -785,7 +956,7 @@ class TestDiscovery(unittest.TestCase):
         expectedPath = os.path.abspath(os.path.dirname(unittest.test.__file__))

         self.wasRun = False
-        def _find_tests(start_dir, pattern, namespace=None):
+        def _find_tests(start_dir, pattern, namespace=None, command_line_arguments=None):
             self.wasRun = True
             self.assertEqual(start_dir, expectedPath)
             return tests
@@ -833,7 +1004,7 @@ class TestDiscovery(unittest.TestCase):
             return package

         _find_tests_args = []
-        def _find_tests(start_dir, pattern, namespace=None):
+        def _find_tests(start_dir, pattern, namespace=None, command_line_arguments=None):
             _find_tests_args.append((start_dir, pattern))
             return ['%s/tests' % start_dir]

diff --git a/Lib/unittest/test/test_program.py b/Lib/unittest/test/test_program.py
index 4a62ae1b11..a6873089f2 100644
--- a/Lib/unittest/test/test_program.py
+++ b/Lib/unittest/test/test_program.py
@@ -6,6 +6,7 @@ import subprocess
 from test import support
 import unittest
 import unittest.test
+from argparse import Namespace


 class Test_TestProgram(unittest.TestCase):
@@ -17,7 +18,7 @@ class Test_TestProgram(unittest.TestCase):
         expectedPath = os.path.abspath(os.path.dirname(unittest.test.__file__))

         self.wasRun = False
-        def _find_tests(start_dir, pattern):
+        def _find_tests(start_dir, pattern, command_line_arguments=None):
             self.wasRun = True
             self.assertEqual(start_dir, expectedPath)
             return tests
@@ -437,6 +438,37 @@ class TestCommandLineArgs(unittest.TestCase):
         self.assertIn('Ran 7 tests', run_unittest(['-k', '*test_warnings.*Warning*', t]))
         self.assertIn('Ran 1 test', run_unittest(['-k', '*test_warnings.*warning*', t]))

+    def testCustomCommandLineArguments(self):
+        class FakeTP(unittest.TestProgram):
+            def addCustomArguments(self, parser):
+                parser.add_argument('--foo', action="store_true", help='foo help')
+            def runTests(self, *args, **kw): pass
+
+        fp = FakeTP(argv=["testprogram"])
+        self.assertEqual(fp.command_line_arguments, Namespace(
+            buffer=False,
+            catchbreak=False,
+            exit=True,
+            failfast=False,
+            foo=False,
+            tb_locals=False,
+            testNamePatterns=[],
+            tests=[],
+            verbosity=1
+        ))
+        fp = FakeTP(argv=["testprogram", "--foo"])
+        self.assertEqual(fp.command_line_arguments, Namespace(
+            buffer=False,
+            catchbreak=False,
+            exit=True,
+            failfast=False,
+            foo=True,
+            tb_locals=False,
+            testNamePatterns=[],
+            tests=[],
+            verbosity=1
+        ))
+

 if __name__ == '__main__':
     unittest.main()
--
2.20.1



More information about the Python-ideas mailing list