File tree Expand file tree Collapse file tree 2 files changed +16
-11
lines changed
Expand file tree Collapse file tree 2 files changed +16
-11
lines changed Original file line number Diff line number Diff line change @@ -493,9 +493,13 @@ def getStart(split):
493493 return start0 + int ((split * size / numSlices )) * step
494494
495495 def f (split , iterator ):
496- # it's an empty iterator here but we need this line for triggering the logic of
497- # checking END_OF_DATA_SECTION during load iterator in runtime, thus make sure
498- # worker reuse takes effect. See more details in SPARK-26549.
496+ # it's an empty iterator here but we need this line for triggering the
497+ # logic of signal handling in FramedSerializer.load_stream, for instance,
498+ # SpecialLengths.END_OF_DATA_SECTION in _read_with_length. Since
499+ # FramedSerializer.load_stream produces a generator, the control should
500+ # at least be in that function once. Here we do it by explicitly converting
501+ # the empty iterator to a list, thus make sure worker reuse takes effect.
502+ # See more details in SPARK-26549.
499503 assert len (list (iterator )) == 0
500504 return xrange (getStart (split ), getStart (split + 1 ), step )
501505
Original file line number Diff line number Diff line change 2222
2323from py4j .protocol import Py4JJavaError
2424
25- from pyspark .testing .utils import ReusedPySparkTestCase , QuietTest
25+ from pyspark .testing .utils import ReusedPySparkTestCase , PySparkTestCase , QuietTest
2626
2727if sys .version_info [0 ] >= 3 :
2828 xrange = range
@@ -144,14 +144,15 @@ def test_with_different_versions_of_python(self):
144144 finally :
145145 self .sc .pythonVer = version
146146
147+
148+ class WorkerReuseTest (PySparkTestCase ):
149+
147150 def test_reuse_worker_of_parallelize_xrange (self ):
148- def get_worker_pid (input_rdd ):
149- return input_rdd .map (lambda x : os .getpid ()).collect ()
150- rdd = self .sc .parallelize (xrange (20 ), 20 )
151- worker_pids = get_worker_pid (rdd )
152- pids = get_worker_pid (rdd )
153- for pid in pids :
154- self .assertTrue (pid in worker_pids )
151+ rdd = self .sc .parallelize (xrange (20 ), 8 )
152+ previous_pids = rdd .map (lambda x : os .getpid ()).collect ()
153+ current_pids = rdd .map (lambda x : os .getpid ()).collect ()
154+ for pid in current_pids :
155+ self .assertTrue (pid in previous_pids )
155156
156157
157158if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments