diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index 5c19c77f37a81..1eed0ff3ee5e8 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -297,11 +297,9 @@ via `sc.setJobGroup` in a separate PVM thread, which also disallows to cancel the
 later.
 
 In order to synchronize PVM threads with JVM threads, you should set `PYSPARK_PIN_THREAD` environment variable
-to `true`. This pinned thread mode allows one PVM thread has one corresponding JVM thread.
-
-However, currently it cannot inherit the local properties from the parent thread although it isolates
-each thread with its own local properties. To work around this, you should manually copy and set the
-local properties from the parent thread to the child thread when you create another thread in PVM.
+to `true`. This pinned thread mode allows one PVM thread to have one corresponding JVM thread. In this mode,
+it is recommended to use `pyspark.InheritableThread` together so that a PVM thread inherits the inheritable
+attributes such as local properties from the corresponding JVM thread.
 
 Note that `PYSPARK_PIN_THREAD` is currently experimental and not recommended for use in production.
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index ee153af18c88c..61e38fdb2a57b 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -42,6 +42,8 @@
         A :class:`TaskContext` that provides extra info and tooling for barrier execution.
     - :class:`BarrierTaskInfo`:
         Information about a barrier task.
+    - :class:`InheritableThread`:
+        An inheritable thread to use in Spark when the pinned thread mode is on.
 """
 
 from functools import wraps
@@ -51,6 +53,7 @@
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD, RDDBarrier
 from pyspark.files import SparkFiles
+from pyspark.util import InheritableThread
 from pyspark.storagelevel import StorageLevel
 from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
@@ -118,5 +121,5 @@ def wrapper(self, *args, **kwargs):
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
     "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
-    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo",
+    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "InheritableThread",
 ]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 2e105cc38260d..5ddce9f4584c4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1013,8 +1013,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
         .. note:: Currently, setting a group ID (set to local properties) with multiple threads
             does not properly work. Internally threads on PVM and JVM are not synced, and JVM
             thread can be reused for multiple threads on PVM, which fails to isolate local
-            properties for each thread on PVM. To work around this, You can use
-            :meth:`RDD.collectWithJobGroup` for now.
+            properties for each thread on PVM.
+
+            To avoid this, enable the pinned thread mode by setting the ``PYSPARK_PIN_THREAD``
+            environment variable to ``true`` and use :class:`pyspark.InheritableThread`.
         """
         self._jsc.setJobGroup(groupId, description, interruptOnCancel)
 
@@ -1026,8 +1028,10 @@ def setLocalProperty(self, key, value):
         .. note:: Currently, setting a local property with multiple threads does not properly work.
            Internally threads on PVM and JVM are not synced, and JVM thread can be reused
             for multiple threads on PVM, which fails to isolate local properties
-            for each thread on PVM. To work around this, You can use
-            :meth:`RDD.collectWithJobGroup`.
+            for each thread on PVM.
+
+            To avoid this, enable the pinned thread mode by setting the ``PYSPARK_PIN_THREAD``
+            environment variable to ``true`` and use :class:`pyspark.InheritableThread`.
         """
         self._jsc.setLocalProperty(key, value)
 
@@ -1045,8 +1049,10 @@ def setJobDescription(self, value):
         .. note:: Currently, setting a job description (set to local properties) with multiple
             threads does not properly work. Internally threads on PVM and JVM are not synced,
             and JVM thread can be reused for multiple threads on PVM, which fails to isolate
-            local properties for each thread on PVM. To work around this, You can use
-            :meth:`RDD.collectWithJobGroup` for now.
+            local properties for each thread on PVM.
+
+            To avoid this, enable the pinned thread mode by setting the ``PYSPARK_PIN_THREAD``
+            environment variable to ``true`` and use :class:`pyspark.InheritableThread`.
         """
         self._jsc.setJobDescription(value)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 437b2c446529a..4ee486800f882 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -859,12 +859,18 @@ def collect(self):
 
     def collectWithJobGroup(self, groupId, description, interruptOnCancel=False):
         """
-        .. note:: Experimental
-        When collect rdd, use this method to specify job group.
+        .. note:: Deprecated in 3.1.0. Use :class:`pyspark.InheritableThread` with
+            the pinned thread mode enabled.
+
         .. versionadded:: 3.0.0
         """
+        warnings.warn(
+            "Deprecated in 3.1. Use pyspark.InheritableThread with "
+            "the pinned thread mode enabled.",
+            DeprecationWarning)
+
         with SCCallSiteSync(self.context) as css:
             sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup(
                 self._jrdd.rdd(), groupId, description, interruptOnCancel)
diff --git a/python/pyspark/tests/test_pin_thread.py b/python/pyspark/tests/test_pin_thread.py
index 657d129fe63bb..50eb8e0ec8b50 100644
--- a/python/pyspark/tests/test_pin_thread.py
+++ b/python/pyspark/tests/test_pin_thread.py
@@ -20,7 +20,7 @@
 import threading
 import unittest
 
-from pyspark import SparkContext, SparkConf
+from pyspark import SparkContext, SparkConf, InheritableThread
 
 
 class PinThreadTests(unittest.TestCase):
@@ -143,6 +143,27 @@ def run_job(job_group, index):
                 is_job_cancelled[i],
                 "Thread {i}: Job in group B did not succeeded.".format(i=i))
 
+    def test_inheritable_local_property(self):
+        self.sc.setLocalProperty("a", "hi")
+        expected = []
+
+        def get_inner_local_prop():
+            expected.append(self.sc.getLocalProperty("b"))
+
+        def get_outer_local_prop():
+            expected.append(self.sc.getLocalProperty("a"))
+            self.sc.setLocalProperty("b", "hello")
+            t2 = InheritableThread(target=get_inner_local_prop)
+            t2.start()
+            t2.join()
+
+        t1 = InheritableThread(target=get_outer_local_prop)
+        t1.start()
+        t1.join()
+
+        self.assertEqual(self.sc.getLocalProperty("b"), None)
+        self.assertEqual(expected, ["hi", "hello"])
+
 
 if __name__ == "__main__":
     import unittest
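Since the diff above deprecates `collectWithJobGroup` in favor of `pyspark.InheritableThread` but shows no migration path outside the test, here is a minimal, hedged migration sketch. The group id `"my-group"`, the description string, and the helper `collect_in_group` are illustrative only, and `PYSPARK_PIN_THREAD=true` is assumed to be exported before the SparkContext (and hence the JVM) starts:

# Hedged migration sketch replacing rdd.collectWithJobGroup(...), assuming
# PYSPARK_PIN_THREAD=true was set in the environment before startup.
from pyspark import SparkContext, InheritableThread

sc = SparkContext.getOrCreate()
results = []

def collect_in_group():
    # In pinned thread mode, this PVM thread maps to exactly one JVM thread,
    # so the job group set here stays isolated to this thread.
    sc.setJobGroup("my-group", "illustrative description", interruptOnCancel=True)
    results.append(sc.parallelize(range(100)).collect())

t = InheritableThread(target=collect_in_group)
t.start()
# Calling sc.cancelJobGroup("my-group") from another thread would interrupt the job.
t.join()

This mirrors what `collectWithJobGroup` did in one call, but composes with any action, not just `collect`.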
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index c003586e9c03b..86e5ab5a01585 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -16,10 +16,13 @@
 # limitations under the License.
 #
+import threading
 import re
 import sys
 import traceback
 
+from py4j.clientserver import ClientServer
+
 __all__ = []
 
@@ -114,6 +117,64 @@ def _parse_memory(s):
         raise ValueError("invalid format: " + s)
     return int(float(s[:-1]) * units[s[-1].lower()])
 
+
+class InheritableThread(threading.Thread):
+    """
+    Thread that is recommended to be used in PySpark instead of :class:`threading.Thread`
+    when the pinned thread mode is enabled. The usage of this class is exactly the same as
+    :class:`threading.Thread` but it correctly inherits the inheritable properties specific
+    to the JVM thread such as ``InheritableThreadLocal``.
+
+    Also, note that the pinned thread mode does not close the connection from Python
+    to the JVM when the thread finishes on the Python side. With this class, Python
+    garbage-collects the Python thread instance and also closes the connection,
+    which finishes the JVM thread correctly.
+
+    When the pinned thread mode is off, this works as :class:`threading.Thread`.
+
+    .. note:: Experimental
+
+    .. versionadded:: 3.1.0
+    """
+    def __init__(self, target, *args, **kwargs):
+        from pyspark import SparkContext
+
+        sc = SparkContext._active_spark_context
+
+        if isinstance(sc._gateway, ClientServer):
+            # The pinned thread mode (PYSPARK_PIN_THREAD) is on: snapshot the parent's local properties.
+            properties = sc._jsc.sc().getLocalProperties().clone()
+            self._sc = sc
+
+            def copy_local_properties(*a, **k):
+                sc._jsc.sc().setLocalProperties(properties)
+                return target(*a, **k)
+
+            super(InheritableThread, self).__init__(
+                target=copy_local_properties, *args, **kwargs)
+        else:
+            super(InheritableThread, self).__init__(target=target, *args, **kwargs)
+
+    def __del__(self):
+        from pyspark import SparkContext
+
+        if isinstance(SparkContext._gateway, ClientServer):
+            # Close this thread's dedicated JVM connection so the corresponding JVM thread can finish.
+            thread_connection = self._sc._jvm._gateway_client.thread_connection.connection()
+            if thread_connection is not None:
+                connections = self._sc._jvm._gateway_client.deque
+
+                # Reuse the lock for Py4J in PySpark
+                with SparkContext._lock:
+                    for i in range(len(connections)):
+                        if connections[i] is thread_connection:
+                            connections[i].close()
+                            del connections[i]
+                            break
+                    else:
+                        # Just in case the connection was not closed but removed from the queue.
+                        thread_connection.close()
+
 
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
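One behavioral consequence of the implementation above is worth spelling out: the parent's local properties are cloned when the `InheritableThread` is constructed (`getLocalProperties().clone()` in `__init__`), not when it starts. A minimal sketch of the expected semantics, assuming the pinned thread mode is on; the property keys `"before"` and `"after"` are illustrative only:

# Sketch of the snapshot-at-construction semantics, assuming the pinned
# thread mode (PYSPARK_PIN_THREAD=true) is enabled before startup.
from pyspark import SparkContext, InheritableThread

sc = SparkContext.getOrCreate()
sc.setLocalProperty("before", "visible")

seen = []

def read_props():
    seen.append((sc.getLocalProperty("before"), sc.getLocalProperty("after")))

# The parent thread's local properties are cloned here, inside __init__.
t = InheritableThread(target=read_props)

# Set after construction: not part of the cloned snapshot, so the child
# thread is not expected to observe it.
sc.setLocalProperty("after", "invisible")

t.start()
t.join()
print(seen)  # Expected under these assumptions: [("visible", None)]

Under these assumptions, only properties set before the thread is constructed are inherited; properties set afterwards stay on the parent thread.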