diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index b61643eb0a16..98b505c9046b 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -147,8 +147,8 @@ def __init__(self):
     @classmethod
     def _getOrCreate(cls):
         """Internal function to get or create global BarrierTaskContext."""
-        if cls._taskContext is None:
-            cls._taskContext = BarrierTaskContext()
+        if not isinstance(cls._taskContext, BarrierTaskContext):
+            cls._taskContext = object.__new__(cls)
         return cls._taskContext
 
     @classmethod
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 050c2dd01836..131c51e108ca 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -614,6 +614,21 @@ def context_barrier(x):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
+    def test_barrier_with_python_worker_reuse(self):
+        """
+        Verify that BarrierTaskContext.barrier() works with a reused python worker.
+        """
+        self.sc._conf.set("spark.python.worker.reuse", "true")
+        rdd = self.sc.parallelize(range(4), 4)
+        # start a normal job first to start all workers
+        result = rdd.map(lambda x: x ** 2).collect()
+        self.assertEqual([0, 1, 4, 9], result)
+        # make sure `spark.python.worker.reuse=true`
+        self.assertEqual(self.sc._conf.get("spark.python.worker.reuse"), "true")
+
+        # workers will be reused in this barrier job
+        self.test_barrier()
+
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
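
For context on the `_getOrCreate` change: `TaskContext` overrides `__new__` to hand back a cached singleton, so on a reused python worker the cache can still hold a plain `TaskContext` left over from a previous, non-barrier task. The old `is None` check would then return that stale instance, and calling `BarrierTaskContext()` would not help because the overridden `__new__` returns the same cached object. Checking `isinstance` and allocating via `object.__new__(cls)` sidesteps both problems. Below is a minimal, standalone sketch of that interaction; the class bodies are simplified stand-ins for pyspark's internals, not the actual implementation.

```python
class TaskContext(object):
    _taskContext = None

    def __new__(cls):
        # Return the cached singleton if one exists, whatever its concrete type.
        if cls._taskContext is not None:
            return cls._taskContext
        cls._taskContext = object.__new__(cls)
        return cls._taskContext


class BarrierTaskContext(TaskContext):
    @classmethod
    def _getOrCreate(cls):
        # isinstance (not `is None`) catches the reused-worker case where the
        # cache still holds a plain TaskContext from an earlier non-barrier task.
        if not isinstance(cls._taskContext, BarrierTaskContext):
            # object.__new__ bypasses the overridden TaskContext.__new__, which
            # would otherwise just return the stale cached instance.
            cls._taskContext = object.__new__(cls)
        return cls._taskContext


# Simulate a reused worker: a normal task populated the cache first.
TaskContext()
assert type(TaskContext._taskContext) is TaskContext

# With the old `is None` check this would return the plain TaskContext;
# with the fix, a fresh BarrierTaskContext replaces it.
ctx = BarrierTaskContext._getOrCreate()
assert isinstance(ctx, BarrierTaskContext)
```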