From 0cb2cf6e9ece66861073c31b579b595a9de5ce81 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 7 Nov 2018 18:01:54 +0800 Subject: [PATCH 1/2] fix BarrierTaskContext while python worker reuse --- python/pyspark/taskcontext.py | 11 ++++++++++- python/pyspark/tests.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index b61643eb0a16..091d45f0605c 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -144,10 +144,19 @@ def __init__(self): """Construct a BarrierTaskContext, use get instead""" pass + def __new__(cls): + """ + Rewrite __new__ method to BarrierTaskContext for _getOrCreate called when _taskContext + is not instance of BarrierTaskContext. + """ + if not isinstance(cls._taskContext, BarrierTaskContext): + cls._taskContext = object.__new__(cls) + return cls._taskContext + @classmethod def _getOrCreate(cls): """Internal function to get or create global BarrierTaskContext.""" - if cls._taskContext is None: + if not isinstance(cls._taskContext, BarrierTaskContext): cls._taskContext = BarrierTaskContext() return cls._taskContext diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 050c2dd01836..1cf9b14ee74c 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -614,6 +614,18 @@ def context_barrier(x): times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() self.assertTrue(max(times) - min(times) < 1) + def test_barrier_with_python_worker_reuse(self): + """ + Verify that BarrierTaskContext.barrier() with reused python worker. + """ + rdd = self.sc.parallelize(range(4), 4) + # start a normal job first to start all worker + result = rdd.map(lambda x: x ** 2).collect() + self.assertEqual([0, 1, 4, 9], result) + + # worker will be reused in this barrier job + self.test_barrier() + def test_barrier_infos(self): """ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the From 02555b8fbdf85c3f2b5a92420479c168e14b573c Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 12 Nov 2018 11:36:41 +0800 Subject: [PATCH 2/2] UT enhance and address comment --- python/pyspark/taskcontext.py | 11 +---------- python/pyspark/tests.py | 3 +++ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 091d45f0605c..98b505c9046b 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -144,20 +144,11 @@ def __init__(self): """Construct a BarrierTaskContext, use get instead""" pass - def __new__(cls): - """ - Rewrite __new__ method to BarrierTaskContext for _getOrCreate called when _taskContext - is not instance of BarrierTaskContext. - """ - if not isinstance(cls._taskContext, BarrierTaskContext): - cls._taskContext = object.__new__(cls) - return cls._taskContext - @classmethod def _getOrCreate(cls): """Internal function to get or create global BarrierTaskContext.""" if not isinstance(cls._taskContext, BarrierTaskContext): - cls._taskContext = BarrierTaskContext() + cls._taskContext = object.__new__(cls) return cls._taskContext @classmethod diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1cf9b14ee74c..131c51e108ca 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -618,10 +618,13 @@ def test_barrier_with_python_worker_reuse(self): """ Verify that BarrierTaskContext.barrier() with reused python worker. """ + self.sc._conf.set("spark.python.work.reuse", "true") rdd = self.sc.parallelize(range(4), 4) # start a normal job first to start all worker result = rdd.map(lambda x: x ** 2).collect() self.assertEqual([0, 1, 4, 9], result) + # make sure `spark.python.work.reuse=true` + self.assertEqual(self.sc._conf.get("spark.python.work.reuse"), "true") # worker will be reused in this barrier job self.test_barrier()