Skip to content

Commit 4868e82

Browse files
committed
Address comments
1 parent ab451e5 commit 4868e82

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

python/pyspark/context.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,13 @@ def getStart(split):
493493
return start0 + int((split * size / numSlices)) * step
494494

495495
def f(split, iterator):
496-
# it's an empty iterator here but we need this line for triggering the logic of
497-
# checking END_OF_DATA_SECTION during load iterator in runtime, thus make sure
498-
# worker reuse takes effect. See more details in SPARK-26549.
496+
# It's an empty iterator here, but we need this line to trigger the
497+
# logic of signal handling in FramedSerializer.load_stream, for instance,
498+
# SpecialLengths.END_OF_DATA_SECTION in _read_with_length. Since
499+
# FramedSerializer.load_stream produces a generator, control should enter
500+
# that function at least once. Here we do it by explicitly converting
501+
# the empty iterator to a list, thus making sure worker reuse takes effect.
502+
# See more details in SPARK-26549.
499503
assert len(list(iterator)) == 0
500504
return xrange(getStart(split), getStart(split + 1), step)
501505

python/pyspark/tests/test_worker.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from py4j.protocol import Py4JJavaError
2424

25-
from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest
25+
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
2626

2727
if sys.version_info[0] >= 3:
2828
xrange = range
@@ -144,14 +144,15 @@ def test_with_different_versions_of_python(self):
144144
finally:
145145
self.sc.pythonVer = version
146146

147+
148+
class WorkerReuseTest(PySparkTestCase):
149+
147150
def test_reuse_worker_of_parallelize_xrange(self):
148-
def get_worker_pid(input_rdd):
149-
return input_rdd.map(lambda x: os.getpid()).collect()
150-
rdd = self.sc.parallelize(xrange(20), 20)
151-
worker_pids = get_worker_pid(rdd)
152-
pids = get_worker_pid(rdd)
153-
for pid in pids:
154-
self.assertTrue(pid in worker_pids)
151+
rdd = self.sc.parallelize(xrange(20), 8)
152+
previous_pids = rdd.map(lambda x: os.getpid()).collect()
153+
current_pids = rdd.map(lambda x: os.getpid()).collect()
154+
for pid in current_pids:
155+
self.assertTrue(pid in previous_pids)
155156

156157

157158
if __name__ == "__main__":

0 commit comments

Comments
 (0)