Skip to content
Prev Previous commit
Next Next commit
Refactor
  • Loading branch information
HyukjinKwon committed Jul 13, 2020
commit 64241512d0a1c848bdcd66556d2522b2cf824ed0
40 changes: 31 additions & 9 deletions dev/run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,14 @@ def setup_test_environ(environ):
os.environ[k] = v


def determine_modules_to_test(changed_modules):
def determine_modules_to_test(changed_modules, deduplicated=True):
"""
Given a set of modules that have changed, compute the transitive closure of those modules'
dependent modules in order to determine the set of modules that should be tested.

Returns a topologically-sorted list of modules (ties are broken by sorting on module names).
If ``deduplicated`` is disabled, the modules are returned without tacking the deduplication
by dependencies into account.

>>> [x.name for x in determine_modules_to_test([modules.root])]
['root']
Expand All @@ -122,11 +124,32 @@ def determine_modules_to_test(changed_modules):
... # doctest: +NORMALIZE_WHITESPACE
['sql', 'avro', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver',
'pyspark-sql', 'repl', 'sparkr', 'pyspark-mllib', 'pyspark-ml']
>>> sorted([x.name for x in determine_modules_to_test(
... [modules.sparkr, modules.pyspark_sql], deduplicated=False)])
... # doctest: +NORMALIZE_WHITESPACE
['avro', 'examples', 'hive', 'hive-thriftserver', 'mllib', 'pyspark-ml',
'pyspark-mllib', 'pyspark-sql', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10']
>>> sorted([x.name for x in determine_modules_to_test(
... [modules.sql, modules.core], deduplicated=False)])
... # doctest: +NORMALIZE_WHITESPACE
['avro', 'catalyst', 'core', 'examples', 'graphx', 'hive', 'hive-thriftserver',
'mllib', 'mllib-local', 'pyspark-core', 'pyspark-ml', 'pyspark-mllib',
'pyspark-resource', 'pyspark-sql', 'pyspark-streaming', 'repl', 'root',
'sparkr', 'sql', 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10',
'streaming-kinesis-asl']
"""
modules_to_test = set()
for module in changed_modules:
modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules))
modules_to_test = modules_to_test.union(
determine_modules_to_test(module.dependent_modules, deduplicated))
modules_to_test = modules_to_test.union(set(changed_modules))

if not deduplicated:
return modules_to_test

# If we need to run all of the tests, then we should short-circuit and return 'root'
if modules.root in modules_to_test:
return [modules.root]
return toposort_flatten(
{m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True)

Expand Down Expand Up @@ -640,11 +663,14 @@ def main():
os.environ["GITHUB_SHA"], target_ref=os.environ["GITHUB_PREV_SHA"])
print("changed_files : %s" % changed_files)
new_test_modules = list(set(determine_modules_to_test(
determine_modules_for_files(changed_files))).intersection(test_modules))
determine_modules_for_files(changed_files)), deduplicated=False
).intersection(test_modules))

if modules.root not in new_test_modules:
# If there is root, we should also test the modules.
# If there is no root, then just use the modules as passed initially.
# If root module does not exist, only test the intersected modules.
# If root module is found, just run the modules as specified initially.
test_modules = new_test_modules

print("test_modules : %s" % test_modules)

changed_modules = test_modules
Expand All @@ -667,10 +693,6 @@ def main():
if not test_modules:
test_modules = determine_modules_to_test(changed_modules)

# If we need to run all of the tests, then we should short-circuit and return 'root'
if modules.root in test_modules:
test_modules = [modules.root]

str_excluded_tags = opts.excluded_tags
str_included_tags = opts.included_tags
if str_excluded_tags:
Expand Down