Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next commit
Initial attempt at parallelizing Python test execution
  • Loading branch information
JoshRosen committed Jun 28, 2015
commit af4cef40fa48feee32389e2a7fa8f96c1f82fd91
85 changes: 60 additions & 25 deletions python/run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@
#

from __future__ import print_function
import logging
from optparse import OptionParser
import os
import re
import subprocess
import sys
import tempfile
from threading import Thread, Lock
import time
# Pick the queue module under a single name on both Python 2 and Python 3.
# Compare the numeric major version rather than the version string: string comparison is
# lexicographic and would misclassify a two-digit major version.
if sys.version_info[0] < 3:
    import Queue
else:
    import queue as Queue


# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
Expand All @@ -43,34 +50,43 @@ def print_red(text):


# Combined log file that receives the captured output of every test run.
LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
# Held while appending a finished test's output to LOG_FILE, so that tests running in
# different worker threads do not write to the log at the same time.
LOG_FILE_LOCK = Lock()
# Root logger; main() configures it to emit bare messages on stdout.
LOGGER = logging.getLogger()


def run_individual_python_test(test_name, pyspark_python):
    """Run one PySpark test goal under the given interpreter and record its output.

    The test is launched through ``bin/pyspark`` under ``SPARK_HOME``. Its stdout and
    stderr are captured in a private temporary file and appended to ``LOG_FILE`` only
    after the test has finished, while holding ``LOG_FILE_LOCK``, so that output from
    tests running concurrently in other threads is not interleaved in the log.

    :param test_name: the test goal passed to ``bin/pyspark``.
    :param pyspark_python: name of the Python executable; it is resolved with ``which``
        and exported to the child process as ``PYSPARK_PYTHON``.

    Returns nothing. If the test exits with a non-zero status, the log is echoed to
    stdout and the whole process is terminated with ``os._exit(-1)``.
    """
    # NOTE(review): the child receives only these two variables, not a copy of
    # os.environ -- confirm that bin/pyspark does not rely on anything else.
    env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}
    LOGGER.info("Starting test(%s): %s" % (pyspark_python, test_name))
    start_time = time.time()
    # The context manager closes the temporary file even if launching the test or
    # copying its output raises.
    with tempfile.TemporaryFile() as per_test_output:
        retcode = subprocess.Popen(
            [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
            stderr=per_test_output, stdout=per_test_output, env=env).wait()
        duration = time.time() - start_time
        with LOG_FILE_LOCK:
            with open(LOG_FILE, 'ab') as log_file:
                per_test_output.seek(0)
                # TODO: only need to print the log when the test fails.
                # Iterating the file streams it line by line instead of loading the
                # whole output into memory.
                log_file.writelines(per_test_output)
    # Exit on the first failure.
    if retcode != 0:
        with open(LOG_FILE, 'r') as log_file:
            for line in log_file:
                # Echo every log line except those that start with a digit.
                if not re.match('[0-9]+', line):
                    print(line, end='')
        print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
        # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
        # this code is invoked from a thread other than the main thread.
        os._exit(-1)
    else:
        LOGGER.info("Finished test(%s): %s (%is)" % (pyspark_python, test_name, duration))


def get_default_python_executables():
    """Return the Python executables to test against by default.

    Of ``python2.6``, ``python3.4`` and ``pypy``, only those that ``which`` can find are
    kept. When ``python2.6`` is not among them, a warning is logged once and plain
    ``python`` is placed at the front of the list instead.

    :return: list of executable names, in the order they should be tested.
    """
    python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)]
    if "python2.6" not in python_execs:
        LOGGER.warning("Not testing against `python2.6` because it could not be found; falling"
                       " back to `python` instead")
        python_execs.insert(0, "python")
    return python_execs

Expand All @@ -88,6 +104,10 @@ def parse_opts():
default=",".join(sorted(python_modules.keys())),
help="A comma-separated list of Python modules to test (default: %default)"
)
parser.add_option(
"-p", "--parallelism", type="int", default=4,
help="The number of suites to test in parallel (default %default)"
)

(opts, args) = parser.parse_args()
if args:
Expand All @@ -96,8 +116,9 @@ def parse_opts():


def main():
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, format="%(message)s")
opts = parse_opts()
print("Running PySpark tests. Output is in python/%s" % LOG_FILE)
LOGGER.info("Running PySpark tests. Output is in python/%s" % LOG_FILE)
if os.path.exists(LOG_FILE):
os.remove(LOG_FILE)
python_execs = opts.python_executables.split(',')
Expand All @@ -108,24 +129,38 @@ def main():
else:
print("Error: unrecognized module %s" % module_name)
sys.exit(-1)
print("Will test against the following Python executables: %s" % python_execs)
print("Will test the following Python modules: %s" % [x.name for x in modules_to_test])
LOGGER.info("Will test against the following Python executables: %s" % python_execs)
LOGGER.info("Will test the following Python modules: %s" % [x.name for x in modules_to_test])

start_time = time.time()
task_queue = Queue.Queue()
for python_exec in python_execs:
python_implementation = subprocess.check_output(
[python_exec, "-c", "import platform; print(platform.python_implementation())"],
universal_newlines=True).strip()
print("Testing with `%s`: " % python_exec, end='')
subprocess.call([python_exec, "--version"])

for module in modules_to_test:
if python_implementation not in module.blacklisted_python_implementations:
print("Running %s tests ..." % module.name)
for test_goal in module.python_test_goals:
run_individual_python_test(test_goal, python_exec)
for test_goal in module.python_test_goals:
task_queue.put((python_exec, test_goal))

def process_queue(task_queue):
while True:
try:
(python_exec, test_goal) = task_queue.get_nowait()
except Queue.Empty:
break
try:
run_individual_python_test(test_goal, python_exec)
finally:
task_queue.task_done()

start_time = time.time()
for _ in range(opts.parallelism):
worker = Thread(target=process_queue, args=(task_queue,))
worker.daemon = True
worker.start()
try:
task_queue.join()
except (KeyboardInterrupt, SystemExit):
print_red("Exiting due to interrupt")
sys.exit(-1)
total_duration = time.time() - start_time
print("Tests passed in %i seconds" % total_duration)
LOGGER.info("Tests passed in %i seconds" % total_duration)


if __name__ == "__main__":
Expand Down