Add a method to unittest.TestProgram to add custom arguments

Hi everybody, I would like to add a new method to unittest.TestProgram that would let users add their own arguments to TestProgram, so that its behavior can be changed from the command line, and make the parsed values accessible from a TestCase. This would allow for more customization when running the tests. Currently there is no way to do this, so users either have to rely on environment variables or parse sys.argv themselves, as proposed in https://stackoverflow.com/a/46061210. This approach has several disadvantages:

- it's more work to check and parse the arguments
- the options are not discoverable from the help
- everyone has their own way of doing this

Here are two cases where this would be useful for me:

- In one project, I have an option to run tests in "record mode". In this mode, all tests are executed, but the results are saved to files instead of being checked for validity. In subsequent runs, results are compared to those files to make sure they are correct:

      def _test_payload(obj, filename):
          import json, os
          filename = f"test/results/{filename}"
          if "RECORD_PAYLOADS" in os.environ:
              with open(filename, "w") as f:
                  json.dump(obj, f, indent=2)
          else:
              with open(filename) as f:
                  assert obj == json.load(f)

  This lets us update tests very quickly and check that they are correct just by looking at the diff. This is very useful for testing an API with many endpoints. As you can see, we use an environment variable to change TestCase behavior, but it is awkward to run tests as `RECORD_PAYLOADS=1 python test_api.py -v` instead of the more straightforward `python test_api.py -v --record-payloads`.

- In https://bugs.python.org/issue18765 (unittest needs a way to launch pdb.post_mortem or other debug hooks), Gregory P. Smith and Michael Foord propose making TestCase call pdb.post_mortem() on failure. This also requires a way to change the behavior, as this would not be wanted in CI.
Just subclassing TestProgram would not be sufficient, as the result of the parsing would not be accessible from the TestCase. One argument against making such a change is that it would require adding a new attribute to TestCase, which would not be backward compatible and could break existing code, but I think this is manageable if we choose an appropriate name for it. I have attached an implementation of this feature (also available at https://github.com/remilapeyre/cpython/tree/unittest-custom-arguments) and look forward to your comments. What do you think about this?
--- Lib/unittest/case.py | 7 +- Lib/unittest/loader.py | 73 +++++---- Lib/unittest/main.py | 49 +++++- Lib/unittest/test/test_discovery.py | 225 ++++++++++++++++++++++++---- Lib/unittest/test/test_program.py | 34 ++++- 5 files changed, 331 insertions(+), 57 deletions(-) diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index a157ae8a14..5698e3a640 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -1,6 +1,7 @@ """Test case implementation""" import sys +import argparse import functools import difflib import logging @@ -416,7 +417,7 @@ class TestCase(object): _class_cleanups = [] - def __init__(self, methodName='runTest'): + def __init__(self, methodName='runTest', command_line_arguments=None): """Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
@@ -436,6 +437,10 @@ class TestCase(object): self._testMethodDoc = testMethod.__doc__ self._cleanups = [] self._subtest = None + if command_line_arguments is None: + self.command_line_arguments = argparse.Namespace() + else: + self.command_line_arguments = command_line_arguments # Map types to custom assertEqual functions that will compare # instances of said type in more detail to generate a more useful diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index ba7105e1ad..d5f0a97213 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -5,6 +5,7 @@ import re import sys import traceback import types +import itertools import functools import warnings @@ -81,7 +82,7 @@ class TestLoader(object): # avoid infinite re-entrancy. self._loading_packages = set() - def loadTestsFromTestCase(self, testCaseClass): + def loadTestsFromTestCase(self, testCaseClass, command_line_arguments=None): """Return a suite of all test cases contained in testCaseClass""" if issubclass(testCaseClass, suite.TestSuite): raise TypeError("Test cases should not be derived from " @@ -90,12 +91,21 @@ class TestLoader(object): testCaseNames = self.getTestCaseNames(testCaseClass) if not testCaseNames and hasattr(testCaseClass, 'runTest'): testCaseNames = ['runTest'] - loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames)) + + # keep backward compatibility for subclasses that override __init__ + def instanciate_testcase(testCaseClass, testCaseName): + try: + return testCaseClass(testCaseName, command_line_arguments) + except TypeError: + return testCaseClass(testCaseName) + loaded_suite = self.suiteClass( + map(instanciate_testcase, itertools.repeat(testCaseClass), testCaseNames) + ) return loaded_suite # XXX After Python 3.5, remove backward compatibility hacks for # use_load_tests deprecation via *args and **kws. See issue 16662. 
- def loadTestsFromModule(self, module, *args, pattern=None, **kws): + def loadTestsFromModule(self, module, *args, pattern=None, command_line_arguments=None, **kws): """Return a suite of all test cases contained in the given module""" # This method used to take an undocumented and unofficial # use_load_tests argument. For backward compatibility, we still @@ -121,7 +131,7 @@ class TestLoader(object): for name in dir(module): obj = getattr(module, name) if isinstance(obj, type) and issubclass(obj, case.TestCase): - tests.append(self.loadTestsFromTestCase(obj)) + tests.append(self.loadTestsFromTestCase(obj, command_line_arguments)) load_tests = getattr(module, 'load_tests', None) tests = self.suiteClass(tests) @@ -135,7 +145,7 @@ class TestLoader(object): return error_case return tests - def loadTestsFromName(self, name, module=None): + def loadTestsFromName(self, name, module=None, command_line_arguments=None): """Return a suite of all test cases given a string specifier. The name may resolve either to a module, a test case class, a @@ -188,9 +198,9 @@ class TestLoader(object): return error_case if isinstance(obj, types.ModuleType): - return self.loadTestsFromModule(obj) + return self.loadTestsFromModule(obj, command_line_arguments) elif isinstance(obj, type) and issubclass(obj, case.TestCase): - return self.loadTestsFromTestCase(obj) + return self.loadTestsFromTestCase(obj, command_line_arguments) elif (isinstance(obj, types.FunctionType) and isinstance(parent, type) and issubclass(parent, case.TestCase)): @@ -213,11 +223,11 @@ class TestLoader(object): else: raise TypeError("don't know how to make test from: %s" % obj) - def loadTestsFromNames(self, names, module=None): + def loadTestsFromNames(self, names, module=None, command_line_arguments=None): """Return a suite of all test cases found using the given sequence of string specifiers. See 'loadTestsFromName()'. 
""" - suites = [self.loadTestsFromName(name, module) for name in names] + suites = [self.loadTestsFromName(name, module, command_line_arguments) for name in names] return self.suiteClass(suites) def getTestCaseNames(self, testCaseClass): @@ -239,7 +249,7 @@ class TestLoader(object): testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing)) return testFnNames - def discover(self, start_dir, pattern='test*.py', top_level_dir=None): + def discover(self, start_dir, pattern='test*.py', top_level_dir=None, command_line_arguments=None): """Find and return all test modules from the specified start directory, recursing into subdirectories to find them and return all tests found within them. Only test files that match the pattern will @@ -322,9 +332,12 @@ class TestLoader(object): self._top_level_dir = \ (path.split(the_module.__name__ .replace(".", os.path.sep))[0]) - tests.extend(self._find_tests(path, - pattern, - namespace=True)) + tests.extend(self._find_tests( + path, + pattern, + namespace=True, + command_line_arguments=command_line_arguments + )) elif the_module.__name__ in sys.builtin_module_names: # builtin module raise TypeError('Can not use builtin modules ' @@ -346,7 +359,7 @@ class TestLoader(object): raise ImportError('Start directory is not importable: %r' % start_dir) if not is_namespace: - tests = list(self._find_tests(start_dir, pattern)) + tests = list(self._find_tests(start_dir, pattern, command_line_arguments=command_line_arguments)) return self.suiteClass(tests) def _get_directory_containing_module(self, module_name): @@ -381,7 +394,7 @@ class TestLoader(object): # override this method to use alternative matching strategy return fnmatch(path, pattern) - def _find_tests(self, start_dir, pattern, namespace=False): + def _find_tests(self, start_dir, pattern, namespace=False, command_line_arguments=None): """Used by discovery. 
Yields test suites it loads.""" # Handle the __init__ in this package name = self._get_name_from_path(start_dir) @@ -391,7 +404,7 @@ class TestLoader(object): # name is in self._loading_packages while we have called into # loadTestsFromModule with name. tests, should_recurse = self._find_test_path( - start_dir, pattern, namespace) + start_dir, pattern, namespace, command_line_arguments) if tests is not None: yield tests if not should_recurse: @@ -403,7 +416,7 @@ class TestLoader(object): for path in paths: full_path = os.path.join(start_dir, path) tests, should_recurse = self._find_test_path( - full_path, pattern, namespace) + full_path, pattern, namespace, command_line_arguments) if tests is not None: yield tests if should_recurse: @@ -411,11 +424,16 @@ class TestLoader(object): name = self._get_name_from_path(full_path) self._loading_packages.add(name) try: - yield from self._find_tests(full_path, pattern, namespace) + yield from self._find_tests( + full_path, + pattern, + namespace, + command_line_arguments=command_line_arguments + ) finally: self._loading_packages.discard(name) - def _find_test_path(self, full_path, pattern, namespace=False): + def _find_test_path(self, full_path, pattern, namespace=False, command_line_arguments=None): """Used by discovery. Loads tests from a single file, or a directories' __init__.py when @@ -457,7 +475,8 @@ class TestLoader(object): "%r. 
Is this module globally installed?") raise ImportError( msg % (mod_name, module_dir, expected_dir)) - return self.loadTestsFromModule(module, pattern=pattern), False + return self.loadTestsFromModule(module, pattern=pattern, + command_line_arguments=command_line_arguments), False elif os.path.isdir(full_path): if (not namespace and not os.path.isfile(os.path.join(full_path, '__init__.py'))): @@ -480,7 +499,11 @@ class TestLoader(object): # Mark this package as being in load_tests (possibly ;)) self._loading_packages.add(name) try: - tests = self.loadTestsFromModule(package, pattern=pattern) + tests = self.loadTestsFromModule( + package, + pattern=pattern, + command_line_arguments=command_line_arguments + ) if load_tests is not None: # loadTestsFromModule(package) has loaded tests for us. return tests, False @@ -507,11 +530,11 @@ def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNa return _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns).getTestCaseNames(testCaseClass) def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp, - suiteClass=suite.TestSuite): + suiteClass=suite.TestSuite, command_line_arguments=None): return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase( - testCaseClass) + testCaseClass, command_line_arguments) def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp, - suiteClass=suite.TestSuite): + suiteClass=suite.TestSuite, command_line_arguments=None): return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(\ - module) + module, command_line_arguments=command_line_arguments) diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index e62469aa2a..7adc558e5f 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -51,6 +51,8 @@ def _convert_select_pattern(pattern): pattern = '*%s*' % pattern return pattern +_options = ('verbosity', 'tb_locals', 'failfast', 'catchbreak', 'buffer', 'tests', + 'testNamePatterns', 'tests', 'start', 
'pattern', 'top', 'exit') class TestProgram(object): """A command-line program that runs a set of tests; this is primarily @@ -100,6 +102,37 @@ class TestProgram(object): self.parseArgs(argv) self.runTests() + def __setattr__(self, name, value): + if name in _options: + setattr(self.command_line_arguments, name, value) + else: + super().__setattr__(name, value) + + def __getattribute__(self, name): + if name in _options: + try: + return getattr(self.command_line_arguments, name) + except AttributeError: + pass + + try: + return super().__getattribute__(name) + except AttributeError: + if name == 'command_line_arguments': + namespace = argparse.Namespace() + # preload command_line_arguments with class arguments + # this is useful for subclasses of TestProgram that override __init__ + for name in _options: + try: + value = super().__getattribute__(name) + setattr(namespace, name, value) + except AttributeError: + pass + self.command_line_arguments = namespace + return namespace + else: + raise + def usageExit(self, msg=None): if msg: print(msg) @@ -123,14 +156,14 @@ class TestProgram(object): if len(argv) > 1 and argv[1].lower() == 'discover': self._do_discovery(argv[2:]) return - self._main_parser.parse_args(argv[1:], self) + self._main_parser.parse_args(argv[1:], self.command_line_arguments) if not self.tests: # this allows "python -m unittest -v" to still work for # test discovery. 
self._do_discovery([]) return else: - self._main_parser.parse_args(argv[1:], self) + self._main_parser.parse_args(argv[1:], self.command_line_arguments) if self.tests: self.testNames = _convert_names(self.tests) @@ -151,7 +184,12 @@ class TestProgram(object): self.testLoader.testNamePatterns = self.testNamePatterns if from_discovery: loader = self.testLoader if Loader is None else Loader() - self.test = loader.discover(self.start, self.pattern, self.top) + self.test = loader.discover( + self.start, + self.pattern, + self.top, + self.command_line_arguments + ) elif self.testNames is None: self.test = self.testLoader.loadTestsFromModule(self.module) else: @@ -196,8 +234,13 @@ class TestProgram(object): help='Only run tests which match the given substring') self.testNamePatterns = [] + self.addCustomArguments(parser) + return parser + def addCustomArguments(self, parser): + pass + def _getMainArgParser(self, parent): parser = argparse.ArgumentParser(parents=[parent]) parser.prog = self.progName diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py index 204043b493..3a2112eb7e 100644 --- a/Lib/unittest/test/test_discovery.py +++ b/Lib/unittest/test/test_discovery.py @@ -6,6 +6,7 @@ import types import pickle from test import support import test.test_importlib.util +from argparse import Namespace import unittest import unittest.mock @@ -19,7 +20,6 @@ class TestableTestProgram(unittest.TestProgram): verbosity = 1 progName = '' testRunner = testLoader = None - def __init__(self): pass @@ -72,9 +72,13 @@ class TestDiscovery(unittest.TestCase): loader._get_module_from_name = lambda path: path + ' module' orig_load_tests = loader.loadTestsFromModule - def loadTestsFromModule(module, pattern=None): + def loadTestsFromModule(module, pattern=None, command_line_arguments=None): # This is where load_tests is called. 
- base = orig_load_tests(module, pattern=pattern) + base = orig_load_tests( + module, + pattern=pattern, + command_line_arguments=command_line_arguments + ) return base + [module + ' tests'] loader.loadTestsFromModule = loadTestsFromModule loader.suiteClass = lambda thing: thing @@ -118,9 +122,13 @@ class TestDiscovery(unittest.TestCase): loader._get_module_from_name = lambda path: path + ' module' orig_load_tests = loader.loadTestsFromModule - def loadTestsFromModule(module, pattern=None): + def loadTestsFromModule(module, pattern=None, command_line_arguments=None): # This is where load_tests is called. - base = orig_load_tests(module, pattern=pattern) + base = orig_load_tests( + module, + pattern=pattern, + command_line_arguments=command_line_arguments + ) return base + [module + ' tests'] loader.loadTestsFromModule = loadTestsFromModule loader.suiteClass = lambda thing: thing @@ -173,9 +181,13 @@ class TestDiscovery(unittest.TestCase): loader._get_module_from_name = lambda name: Module(name) orig_load_tests = loader.loadTestsFromModule - def loadTestsFromModule(module, pattern=None): + def loadTestsFromModule(module, pattern=None, command_line_arguments=None): # This is where load_tests is called. - base = orig_load_tests(module, pattern=pattern) + base = orig_load_tests( + module, + pattern=pattern, + command_line_arguments=command_line_arguments + ) return base + [module.path + ' module tests'] loader.loadTestsFromModule = loadTestsFromModule loader.suiteClass = lambda thing: thing @@ -247,9 +259,13 @@ class TestDiscovery(unittest.TestCase): loader._get_module_from_name = lambda name: Module(name) orig_load_tests = loader.loadTestsFromModule - def loadTestsFromModule(module, pattern=None): + def loadTestsFromModule(module, pattern=None, command_line_arguments=None): # This is where load_tests is called. 
- base = orig_load_tests(module, pattern=pattern) + base = orig_load_tests( + module, + pattern=pattern, + command_line_arguments=command_line_arguments + ) return base + [module.path + ' module tests'] loader.loadTestsFromModule = loadTestsFromModule loader.suiteClass = lambda thing: thing @@ -395,7 +411,7 @@ class TestDiscovery(unittest.TestCase): self.addCleanup(restore_isdir) _find_tests_args = [] - def _find_tests(start_dir, pattern, namespace=None): + def _find_tests(start_dir, pattern, namespace=None, command_line_arguments=None): _find_tests_args.append((start_dir, pattern)) return ['tests'] loader._find_tests = _find_tests @@ -633,75 +649,214 @@ class TestDiscovery(unittest.TestCase): class Loader(object): args = [] - def discover(self, start_dir, pattern, top_level_dir): - self.args.append((start_dir, pattern, top_level_dir)) + def discover(self, start_dir, pattern, top_level_dir, command_line_arguments): + self.args.append((start_dir, pattern, top_level_dir, command_line_arguments)) return 'tests' program.testLoader = Loader() program._do_discovery(['-v']) - self.assertEqual(Loader.args, [('.', 'test*.py', None)]) + self.assertEqual(Loader.args, [( + '.', + 'test*.py', + None, + Namespace( + buffer=False, + catchbreak=False, + failfast=False, + pattern='test*.py', + start='.', + tb_locals=False, + testNamePatterns=[], + top=None, + verbosity=2 + ) + )]) def test_command_line_handling_do_discovery_calls_loader(self): program = TestableTestProgram() class Loader(object): args = [] - def discover(self, start_dir, pattern, top_level_dir): - self.args.append((start_dir, pattern, top_level_dir)) + def discover(self, start_dir, pattern, top_level_dir, command_line_arguments): + self.args.append((start_dir, pattern, top_level_dir, command_line_arguments)) return 'tests' program._do_discovery(['-v'], Loader=Loader) self.assertEqual(program.verbosity, 2) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('.', 'test*.py', None)]) + 
self.assertEqual(Loader.args, [('.', 'test*.py', None, Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='test*.py', + start='.', + tb_locals=False, + testNamePatterns=[], + top=None, + verbosity=2 + ))]) Loader.args = [] program = TestableTestProgram() program._do_discovery(['--verbose'], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('.', 'test*.py', None)]) + self.assertEqual(Loader.args, [('.', 'test*.py', None, Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='test*.py', + start='.', + tb_locals=False, + testNamePatterns=[], + top=None, + verbosity=2 + ))]) Loader.args = [] program = TestableTestProgram() program._do_discovery([], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('.', 'test*.py', None)]) + self.assertEqual(Loader.args, [('.', 'test*.py', None, Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='test*.py', + start='.', + tb_locals=False, + testNamePatterns=[], + top=None, + verbosity=1 + ))]) Loader.args = [] program = TestableTestProgram() program._do_discovery(['fish'], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('fish', 'test*.py', None)]) + self.assertEqual(Loader.args, [('fish', 'test*.py', None, Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='test*.py', + start='fish', + tb_locals=False, + testNamePatterns=[], + top=None, + verbosity=1 + ))]) Loader.args = [] program = TestableTestProgram() program._do_discovery(['fish', 'eggs'], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('fish', 'eggs', None)]) + self.assertEqual(Loader.args, [( + 'fish', + 'eggs', + None, + Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='eggs', + start='fish', + tb_locals=False, + testNamePatterns=[], + 
top=None, + verbosity=1 + ) + )]) Loader.args = [] program = TestableTestProgram() program._do_discovery(['fish', 'eggs', 'ham'], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('fish', 'eggs', 'ham')]) + self.assertEqual(Loader.args, [( + 'fish', + 'eggs', + 'ham', + Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='eggs', + start='fish', + tb_locals=False, + testNamePatterns=[], + top='ham', + verbosity=1 + ) + )]) Loader.args = [] program = TestableTestProgram() program._do_discovery(['-s', 'fish'], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('fish', 'test*.py', None)]) + self.assertEqual(Loader.args, [( + 'fish', + 'test*.py', + None, + Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='test*.py', + start='fish', + tb_locals=False, + testNamePatterns=[], + top=None, + verbosity=1 + ) + )]) Loader.args = [] program = TestableTestProgram() program._do_discovery(['-t', 'fish'], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('.', 'test*.py', 'fish')]) + self.assertEqual(Loader.args, [( + '.', + 'test*.py', + 'fish', + Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='test*.py', + start='.', + tb_locals=False, + testNamePatterns=[], + top='fish', + verbosity=1 + ) + )]) Loader.args = [] program = TestableTestProgram() program._do_discovery(['-p', 'fish'], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('.', 'fish', None)]) + self.assertEqual(Loader.args, [( + '.', + 'fish', + None, + Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + pattern='fish', + start='.', + tb_locals=False, + testNamePatterns=[], + top=None, + verbosity=1 + ) + )]) self.assertFalse(program.failfast) self.assertFalse(program.catchbreak) @@ -710,7 +865,23 @@ class 
TestDiscovery(unittest.TestCase): program._do_discovery(['-p', 'eggs', '-s', 'fish', '-v', '-f', '-c'], Loader=Loader) self.assertEqual(program.test, 'tests') - self.assertEqual(Loader.args, [('fish', 'eggs', None)]) + self.assertEqual(Loader.args, [( + 'fish', + 'eggs', + None, + Namespace( + buffer=False, + catchbreak=True, + exit=True, + failfast=True, + pattern='eggs', + start='fish', + tb_locals=False, + testNamePatterns=[], + top=None, + verbosity=2 + ) + )]) self.assertEqual(program.verbosity, 2) self.assertTrue(program.failfast) self.assertTrue(program.catchbreak) @@ -785,7 +956,7 @@ class TestDiscovery(unittest.TestCase): expectedPath = os.path.abspath(os.path.dirname(unittest.test.__file__)) self.wasRun = False - def _find_tests(start_dir, pattern, namespace=None): + def _find_tests(start_dir, pattern, namespace=None, command_line_arguments=None): self.wasRun = True self.assertEqual(start_dir, expectedPath) return tests @@ -833,7 +1004,7 @@ class TestDiscovery(unittest.TestCase): return package _find_tests_args = [] - def _find_tests(start_dir, pattern, namespace=None): + def _find_tests(start_dir, pattern, namespace=None, command_line_arguments=None): _find_tests_args.append((start_dir, pattern)) return ['%s/tests' % start_dir] diff --git a/Lib/unittest/test/test_program.py b/Lib/unittest/test/test_program.py index 4a62ae1b11..a6873089f2 100644 --- a/Lib/unittest/test/test_program.py +++ b/Lib/unittest/test/test_program.py @@ -6,6 +6,7 @@ import subprocess from test import support import unittest import unittest.test +from argparse import Namespace class Test_TestProgram(unittest.TestCase): @@ -17,7 +18,7 @@ class Test_TestProgram(unittest.TestCase): expectedPath = os.path.abspath(os.path.dirname(unittest.test.__file__)) self.wasRun = False - def _find_tests(start_dir, pattern): + def _find_tests(start_dir, pattern, command_line_arguments=None): self.wasRun = True self.assertEqual(start_dir, expectedPath) return tests @@ -437,6 +438,37 @@ class 
TestCommandLineArgs(unittest.TestCase): self.assertIn('Ran 7 tests', run_unittest(['-k', '*test_warnings.*Warning*', t])) self.assertIn('Ran 1 test', run_unittest(['-k', '*test_warnings.*warning*', t])) + def testCustomCommandLineArguments(self): + class FakeTP(unittest.TestProgram): + def addCustomArguments(self, parser): + parser.add_argument('--foo', action="store_true", help='foo help') + def runTests(self, *args, **kw): pass + + fp = FakeTP(argv=["testprogram"]) + self.assertEqual(fp.command_line_arguments, Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + foo=False, + tb_locals=False, + testNamePatterns=[], + tests=[], + verbosity=1 + )) + fp = FakeTP(argv=["testprogram", "--foo"]) + self.assertEqual(fp.command_line_arguments, Namespace( + buffer=False, + catchbreak=False, + exit=True, + failfast=False, + foo=True, + tb_locals=False, + testNamePatterns=[], + tests=[], + verbosity=1 + )) + if __name__ == '__main__': unittest.main() -- 2.20.1
participants (1)
-
Rémi Lapeyre