Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Lib/unittest/case.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test case implementation"""

import sys
import argparse
import functools
import difflib
import logging
Expand Down Expand Up @@ -429,7 +430,7 @@ class TestCase(object):

_class_cleanups = []

def __init__(self, methodName='runTest'):
def __init__(self, methodName='runTest', command_line_arguments=None):
"""Create an instance of the class that will use the named test
method when executed. Raises a ValueError if the instance does
not have a method with the specified name.
Expand All @@ -449,6 +450,10 @@ def __init__(self, methodName='runTest'):
self._testMethodDoc = testMethod.__doc__
self._cleanups = []
self._subtest = None
if command_line_arguments is None:
self.command_line_arguments = argparse.Namespace()
else:
self.command_line_arguments = command_line_arguments

# Map types to custom assertEqual functions that will compare
# instances of said type in more detail to generate a more useful
Expand Down
73 changes: 48 additions & 25 deletions Lib/unittest/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import traceback
import types
import itertools
import functools
import warnings

Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(self):
# avoid infinite re-entrancy.
self._loading_packages = set()

def loadTestsFromTestCase(self, testCaseClass):
def loadTestsFromTestCase(self, testCaseClass, command_line_arguments=None):
"""Return a suite of all test cases contained in testCaseClass"""
if issubclass(testCaseClass, suite.TestSuite):
raise TypeError("Test cases should not be derived from "
Expand All @@ -90,12 +91,21 @@ def loadTestsFromTestCase(self, testCaseClass):
testCaseNames = self.getTestCaseNames(testCaseClass)
if not testCaseNames and hasattr(testCaseClass, 'runTest'):
testCaseNames = ['runTest']
loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))

# keep backward compatibility for subclasses that override __init__
def instanciate_testcase(testCaseClass, testCaseName):
try:
return testCaseClass(testCaseName, command_line_arguments)
except TypeError:
return testCaseClass(testCaseName)
loaded_suite = self.suiteClass(
map(instanciate_testcase, itertools.repeat(testCaseClass), testCaseNames)
)
return loaded_suite

# XXX After Python 3.5, remove backward compatibility hacks for
# use_load_tests deprecation via *args and **kws. See issue 16662.
def loadTestsFromModule(self, module, *args, pattern=None, **kws):
def loadTestsFromModule(self, module, *args, pattern=None, command_line_arguments=None, **kws):
"""Return a suite of all test cases contained in the given module"""
# This method used to take an undocumented and unofficial
# use_load_tests argument. For backward compatibility, we still
Expand All @@ -121,7 +131,7 @@ def loadTestsFromModule(self, module, *args, pattern=None, **kws):
for name in dir(module):
obj = getattr(module, name)
if isinstance(obj, type) and issubclass(obj, case.TestCase):
tests.append(self.loadTestsFromTestCase(obj))
tests.append(self.loadTestsFromTestCase(obj, command_line_arguments))

load_tests = getattr(module, 'load_tests', None)
tests = self.suiteClass(tests)
Expand All @@ -135,7 +145,7 @@ def loadTestsFromModule(self, module, *args, pattern=None, **kws):
return error_case
return tests

def loadTestsFromName(self, name, module=None):
def loadTestsFromName(self, name, module=None, command_line_arguments=None):
"""Return a suite of all test cases given a string specifier.

The name may resolve either to a module, a test case class, a
Expand Down Expand Up @@ -188,9 +198,9 @@ def loadTestsFromName(self, name, module=None):
return error_case

if isinstance(obj, types.ModuleType):
return self.loadTestsFromModule(obj)
return self.loadTestsFromModule(obj, command_line_arguments)
elif isinstance(obj, type) and issubclass(obj, case.TestCase):
return self.loadTestsFromTestCase(obj)
return self.loadTestsFromTestCase(obj, command_line_arguments)
elif (isinstance(obj, types.FunctionType) and
isinstance(parent, type) and
issubclass(parent, case.TestCase)):
Expand All @@ -213,11 +223,11 @@ def loadTestsFromName(self, name, module=None):
else:
raise TypeError("don't know how to make test from: %s" % obj)

def loadTestsFromNames(self, names, module=None):
def loadTestsFromNames(self, names, module=None, command_line_arguments=None):
"""Return a suite of all test cases found using the given sequence
of string specifiers. See 'loadTestsFromName()'.
"""
suites = [self.loadTestsFromName(name, module) for name in names]
suites = [self.loadTestsFromName(name, module, command_line_arguments) for name in names]
return self.suiteClass(suites)

def getTestCaseNames(self, testCaseClass):
Expand All @@ -239,7 +249,7 @@ def shouldIncludeMethod(attrname):
testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
return testFnNames

def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
def discover(self, start_dir, pattern='test*.py', top_level_dir=None, command_line_arguments=None):
"""Find and return all test modules from the specified start
directory, recursing into subdirectories to find them and return all
tests found within them. Only test files that match the pattern will
Expand Down Expand Up @@ -322,9 +332,12 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
self._top_level_dir = \
(path.split(the_module.__name__
.replace(".", os.path.sep))[0])
tests.extend(self._find_tests(path,
pattern,
namespace=True))
tests.extend(self._find_tests(
path,
pattern,
namespace=True,
command_line_arguments=command_line_arguments
))
elif the_module.__name__ in sys.builtin_module_names:
# builtin module
raise TypeError('Can not use builtin modules '
Expand All @@ -346,7 +359,7 @@ def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
raise ImportError('Start directory is not importable: %r' % start_dir)

if not is_namespace:
tests = list(self._find_tests(start_dir, pattern))
tests = list(self._find_tests(start_dir, pattern, command_line_arguments=command_line_arguments))
return self.suiteClass(tests)

def _get_directory_containing_module(self, module_name):
Expand Down Expand Up @@ -381,7 +394,7 @@ def _match_path(self, path, full_path, pattern):
# override this method to use alternative matching strategy
return fnmatch(path, pattern)

def _find_tests(self, start_dir, pattern, namespace=False):
def _find_tests(self, start_dir, pattern, namespace=False, command_line_arguments=None):
"""Used by discovery. Yields test suites it loads."""
# Handle the __init__ in this package
name = self._get_name_from_path(start_dir)
Expand All @@ -391,7 +404,7 @@ def _find_tests(self, start_dir, pattern, namespace=False):
# name is in self._loading_packages while we have called into
# loadTestsFromModule with name.
tests, should_recurse = self._find_test_path(
start_dir, pattern, namespace)
start_dir, pattern, namespace, command_line_arguments)
if tests is not None:
yield tests
if not should_recurse:
Expand All @@ -403,19 +416,24 @@ def _find_tests(self, start_dir, pattern, namespace=False):
for path in paths:
full_path = os.path.join(start_dir, path)
tests, should_recurse = self._find_test_path(
full_path, pattern, namespace)
full_path, pattern, namespace, command_line_arguments)
if tests is not None:
yield tests
if should_recurse:
# we found a package that didn't use load_tests.
name = self._get_name_from_path(full_path)
self._loading_packages.add(name)
try:
yield from self._find_tests(full_path, pattern, namespace)
yield from self._find_tests(
full_path,
pattern,
namespace,
command_line_arguments=command_line_arguments
)
finally:
self._loading_packages.discard(name)

def _find_test_path(self, full_path, pattern, namespace=False):
def _find_test_path(self, full_path, pattern, namespace=False, command_line_arguments=None):
"""Used by discovery.

Loads tests from a single file, or a directories' __init__.py when
Expand Down Expand Up @@ -457,7 +475,8 @@ def _find_test_path(self, full_path, pattern, namespace=False):
"%r. Is this module globally installed?")
raise ImportError(
msg % (mod_name, module_dir, expected_dir))
return self.loadTestsFromModule(module, pattern=pattern), False
return self.loadTestsFromModule(module, pattern=pattern,
command_line_arguments=command_line_arguments), False
elif os.path.isdir(full_path):
if (not namespace and
not os.path.isfile(os.path.join(full_path, '__init__.py'))):
Expand All @@ -480,7 +499,11 @@ def _find_test_path(self, full_path, pattern, namespace=False):
# Mark this package as being in load_tests (possibly ;))
self._loading_packages.add(name)
try:
tests = self.loadTestsFromModule(package, pattern=pattern)
tests = self.loadTestsFromModule(
package,
pattern=pattern,
command_line_arguments=command_line_arguments
)
if load_tests is not None:
# loadTestsFromModule(package) has loaded tests for us.
return tests, False
Expand All @@ -507,11 +530,11 @@ def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNa
return _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns).getTestCaseNames(testCaseClass)

def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
suiteClass=suite.TestSuite):
suiteClass=suite.TestSuite, command_line_arguments=None):
return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(
testCaseClass)
testCaseClass, command_line_arguments)

def findTestCases(module, prefix='test', sortUsing=util.three_way_cmp,
suiteClass=suite.TestSuite):
suiteClass=suite.TestSuite, command_line_arguments=None):
return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(\
module)
module, command_line_arguments=command_line_arguments)
49 changes: 46 additions & 3 deletions Lib/unittest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def _convert_select_pattern(pattern):
pattern = '*%s*' % pattern
return pattern

_options = ('verbosity', 'tb_locals', 'failfast', 'catchbreak', 'buffer', 'tests',
'testNamePatterns', 'tests', 'start', 'pattern', 'top', 'exit')

class TestProgram(object):
"""A command-line program that runs a set of tests; this is primarily
Expand Down Expand Up @@ -100,6 +102,37 @@ def __init__(self, module='__main__', defaultTest=None, argv=None,
self.parseArgs(argv)
self.runTests()

def __setattr__(self, name, value):
if name in _options:
setattr(self.command_line_arguments, name, value)
else:
super().__setattr__(name, value)

def __getattribute__(self, name):
if name in _options:
try:
return getattr(self.command_line_arguments, name)
except AttributeError:
pass

try:
return super().__getattribute__(name)
except AttributeError:
if name == 'command_line_arguments':
namespace = argparse.Namespace()
# preload command_line_arguments with class arguments
# this is useful for subclasses of TestProgram that override __init__
for name in _options:
try:
value = super().__getattribute__(name)
setattr(namespace, name, value)
except AttributeError:
pass
self.command_line_arguments = namespace
return namespace
else:
raise

def usageExit(self, msg=None):
if msg:
print(msg)
Expand All @@ -123,14 +156,14 @@ def parseArgs(self, argv):
if len(argv) > 1 and argv[1].lower() == 'discover':
self._do_discovery(argv[2:])
return
self._main_parser.parse_args(argv[1:], self)
self._main_parser.parse_args(argv[1:], self.command_line_arguments)
if not self.tests:
# this allows "python -m unittest -v" to still work for
# test discovery.
self._do_discovery([])
return
else:
self._main_parser.parse_args(argv[1:], self)
self._main_parser.parse_args(argv[1:], self.command_line_arguments)

if self.tests:
self.testNames = _convert_names(self.tests)
Expand All @@ -151,7 +184,12 @@ def createTests(self, from_discovery=False, Loader=None):
self.testLoader.testNamePatterns = self.testNamePatterns
if from_discovery:
loader = self.testLoader if Loader is None else Loader()
self.test = loader.discover(self.start, self.pattern, self.top)
self.test = loader.discover(
self.start,
self.pattern,
self.top,
self.command_line_arguments
)
elif self.testNames is None:
self.test = self.testLoader.loadTestsFromModule(self.module)
else:
Expand Down Expand Up @@ -196,8 +234,13 @@ def _getParentArgParser(self):
help='Only run tests which match the given substring')
self.testNamePatterns = []

self.addCustomArguments(parser)

return parser

def addCustomArguments(self, parser):
pass

def _getMainArgParser(self, parent):
parser = argparse.ArgumentParser(parents=[parent])
parser.prog = self.progName
Expand Down
Loading