From 2f371d7f876631db224f59acd314b128567bd70e Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Sun, 6 Jan 2019 03:28:23 +0800
Subject: [PATCH 1/3] Fix Python worker reuse taking no effect

---
 python/pyspark/tests/test_worker.py | 9 +++++++++
 python/pyspark/worker.py            | 7 ++++++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index a33b77d98341..9a338df3d957 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -144,6 +144,15 @@ def test_with_different_versions_of_python(self):
         finally:
             self.sc.pythonVer = version
 
+    def test_reuse_worker(self):
+        def get_worker_pid(input_rdd):
+            return input_rdd.map(lambda x: os.getpid()).collect()
+        rdd = self.sc.parallelize(range(20), 20)
+        worker_pids = get_worker_pid(rdd)
+        pids = get_worker_pid(rdd)
+        for pid in pids:
+            self.assertTrue(pid in worker_pids)
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index bf007b0c62d8..01e54b402828 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -446,7 +446,12 @@ def process():
         pickleSer._write_with_length((aid, accum._value), outfile)
 
     # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+    res = read_int(infile)
+    if sys.version >= '3' and res == SpecialLengths.END_OF_DATA_SECTION:
+        # skip the END_OF_DATA_SECTION for Python 3, otherwise worker reuse will not
+        # take effect; see SPARK-26549 for more details.
+        res = read_int(infile)
+    if res == SpecialLengths.END_OF_STREAM:
         write_int(SpecialLengths.END_OF_STREAM, outfile)
     else:
         # write a different value to tell JVM to not reuse this worker

From ab451e5b4e152450e3fda7ef677deef52bf359a1 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Mon, 7 Jan 2019 21:59:25 +0800
Subject: [PATCH 2/3] Simplify the approach to cover only parallelize xrange

---
 python/pyspark/context.py           | 4 ++++
 python/pyspark/tests/test_worker.py | 4 ++--
 python/pyspark/worker.py            | 7 +------
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6137ed25a0dd..9c13c18781d1 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -493,6 +493,10 @@ def getStart(split):
                 return start0 + int((split * size / numSlices)) * step
 
             def f(split, iterator):
+                # it's an empty iterator here but we need this line to trigger the logic of
+                # checking END_OF_DATA_SECTION while loading the iterator at runtime, thus
+                # making sure worker reuse takes effect. See SPARK-26549 for more details.
+                assert len(list(iterator)) == 0
                 return xrange(getStart(split), getStart(split + 1), step)
 
             return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index 9a338df3d957..71fcc189cb10 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -144,10 +144,10 @@ def test_with_different_versions_of_python(self):
         finally:
             self.sc.pythonVer = version
 
-    def test_reuse_worker(self):
+    def test_reuse_worker_of_parallelize_xrange(self):
         def get_worker_pid(input_rdd):
             return input_rdd.map(lambda x: os.getpid()).collect()
-        rdd = self.sc.parallelize(range(20), 20)
+        rdd = self.sc.parallelize(xrange(20), 20)
         worker_pids = get_worker_pid(rdd)
         pids = get_worker_pid(rdd)
         for pid in pids:
             self.assertTrue(pid in worker_pids)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 01e54b402828..bf007b0c62d8 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -446,12 +446,7 @@ def process():
         pickleSer._write_with_length((aid, accum._value), outfile)
 
     # check end of stream
-    res = read_int(infile)
-    if sys.version >= '3' and res == SpecialLengths.END_OF_DATA_SECTION:
-        # skip the END_OF_DATA_SECTION for Python 3, otherwise worker reuse will not
-        # take effect; see SPARK-26549 for more details.
-        res = read_int(infile)
-    if res == SpecialLengths.END_OF_STREAM:
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
         write_int(SpecialLengths.END_OF_STREAM, outfile)
     else:
         # write a different value to tell JVM to not reuse this worker

From 4868e82256c08679e081dd9e92d5454056686de8 Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Wed, 9 Jan 2019 11:13:47 +0800
Subject: [PATCH 3/3] Address comments

---
 python/pyspark/context.py           | 10 +++++++---
 python/pyspark/tests/test_worker.py | 17 +++++++++--------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 9c13c18781d1..180a3e882dab 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -493,9 +493,13 @@ def getStart(split):
                 return start0 + int((split * size / numSlices)) * step
 
             def f(split, iterator):
-                # it's an empty iterator here but we need this line to trigger the logic of
-                # checking END_OF_DATA_SECTION while loading the iterator at runtime, thus
-                # making sure worker reuse takes effect. See SPARK-26549 for more details.
+                # it's an empty iterator here but we need this line for triggering the
+                # logic of signal handling in FramedSerializer.load_stream, for instance,
+                # SpecialLengths.END_OF_DATA_SECTION in _read_with_length. Since
+                # FramedSerializer.load_stream produces a generator, control needs to
+                # enter that function at least once. Here we do it by explicitly
+                # converting the empty iterator to a list, thus making sure worker reuse
+                # takes effect. See SPARK-26549 for more details.
assert len(list(iterator)) == 0 return xrange(getStart(split), getStart(split + 1), step) diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py index 71fcc189cb10..a4f108f18e17 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -22,7 +22,7 @@ from py4j.protocol import Py4JJavaError -from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest +from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest if sys.version_info[0] >= 3: xrange = range @@ -144,14 +144,15 @@ def test_with_different_versions_of_python(self): finally: self.sc.pythonVer = version + +class WorkerReuseTest(PySparkTestCase): + def test_reuse_worker_of_parallelize_xrange(self): - def get_worker_pid(input_rdd): - return input_rdd.map(lambda x: os.getpid()).collect() - rdd = self.sc.parallelize(xrange(20), 20) - worker_pids = get_worker_pid(rdd) - pids = get_worker_pid(rdd) - for pid in pids: - self.assertTrue(pid in worker_pids) + rdd = self.sc.parallelize(xrange(20), 8) + previous_pids = rdd.map(lambda x: os.getpid()).collect() + current_pids = rdd.map(lambda x: os.getpid()).collect() + for pid in current_pids: + self.assertTrue(pid in previous_pids) if __name__ == "__main__":