diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index fd115fd2cb8eb..ed781be299b71 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -103,6 +103,11 @@ abstract class TaskContext extends Serializable { * This will be called in all situations - success, failure, or cancellation. Adding a listener * to an already completed task will result in that listener being called immediately. * + * Two listeners registered in the same thread will be invoked in reverse order of registration if + * the task completes after both are registered. There are no ordering guarantees for listeners + * registered in different threads, or for listeners registered after the task completes. + * Listeners are guaranteed to execute sequentially. + * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 7d909a56774cc..cb7f4304d07cb 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,7 +17,7 @@ package org.apache.spark -import java.util.Properties +import java.util.{Properties, Stack} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -39,9 +39,9 @@ import org.apache.spark.util._ * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes * sure that updates are always visible across threads. The complete & failed flags and their * callbacks are protected by locking on the context instance. 
For instance, this ensures - * that you cannot add a completion listener in one thread while we are completing (and calling - * the completion listeners) in another thread. Other state is immutable, however the exposed - * `TaskMetrics` & `MetricsSystem` objects are not thread safe. + * that you cannot add a completion listener in one thread while we are completing in another + * thread. Other state is immutable, however the exposed `TaskMetrics` & `MetricsSystem` objects are + * not thread safe. */ private[spark] class TaskContextImpl( override val stageId: Int, @@ -59,11 +59,23 @@ private[spark] class TaskContextImpl( extends TaskContext with Logging { - /** List of callback functions to execute when the task completes. */ - @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + /** + * List of callback functions to execute when the task completes. + * + * Using a stack causes us to process listeners in reverse order of registration. As listeners are + * invoked, they are popped from the stack. + */ + @transient private val onCompleteCallbacks = new Stack[TaskCompletionListener] /** List of callback functions to execute when the task fails. */ - @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] + @transient private val onFailureCallbacks = new Stack[TaskFailureListener] + + /** + * The thread currently executing task completion or failure listeners, if any. + * + * `invokeListeners()` uses this to ensure listeners are called sequentially. + */ + @transient private var listenerInvocationThread: Option[Thread] = None // If defined, the corresponding task has been killed and this option contains the reason. @volatile private var reasonIfKilled: Option[String] = None @@ -71,35 +83,36 @@ private[spark] class TaskContextImpl( // Whether the task has completed. private var completed: Boolean = false - // Whether the task has failed. 
- private var failed: Boolean = false - - // Throwable that caused the task to fail - private var failure: Throwable = _ + // If defined, the task has failed and this option contains the Throwable that caused the task to + // fail. + private var failureCauseOpt: Option[Throwable] = None // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't // hide the exception. See SPARK-19276 @volatile private var _fetchFailedException: Option[FetchFailedException] = None - @GuardedBy("this") - override def addTaskCompletionListener(listener: TaskCompletionListener) - : this.type = synchronized { - if (completed) { - listener.onTaskCompletion(this) - } else { - onCompleteCallbacks += listener + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + val needToCallListener = synchronized { + // If there is already a thread invoking listeners, adding the new listener to + // `onCompleteCallbacks` will cause that thread to execute the new listener, and the call to + // `invokeTaskCompletionListeners()` below will be a no-op. + // + // If there is no such thread, the call to `invokeTaskCompletionListeners()` below will + // execute all listeners, including the new listener. 
+ onCompleteCallbacks.push(listener) + completed + } + if (needToCallListener) { + invokeTaskCompletionListeners(None) } this } - @GuardedBy("this") - override def addTaskFailureListener(listener: TaskFailureListener) - : this.type = synchronized { - if (failed) { - listener.onTaskFailure(this, failure) - } else { - onFailureCallbacks += listener - } + override def addTaskFailureListener(listener: TaskFailureListener): this.type = { + synchronized { + onFailureCallbacks.push(listener) + failureCauseOpt + }.foreach(invokeTaskFailureListeners) this } @@ -107,33 +120,80 @@ private[spark] class TaskContextImpl( resources.asJava } - @GuardedBy("this") - private[spark] override def markTaskFailed(error: Throwable): Unit = synchronized { - if (failed) return - failed = true - failure = error - invokeListeners(onFailureCallbacks.toSeq, "TaskFailureListener", Option(error)) { - _.onTaskFailure(this, error) + private[spark] override def markTaskFailed(error: Throwable): Unit = { + synchronized { + if (failureCauseOpt.isDefined) return + failureCauseOpt = Some(error) } + invokeTaskFailureListeners(error) } - @GuardedBy("this") - private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { - if (completed) return - completed = true - invokeListeners(onCompleteCallbacks.toSeq, "TaskCompletionListener", error) { + private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = { + synchronized { + if (completed) return + completed = true + } + invokeTaskCompletionListeners(error) + } + + private def invokeTaskCompletionListeners(error: Option[Throwable]): Unit = { + // It is safe to access the reference to `onCompleteCallbacks` without holding the TaskContext + // lock. `invokeListeners()` acquires the lock before accessing the contents. 
+ invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { _.onTaskCompletion(this) } } + private def invokeTaskFailureListeners(error: Throwable): Unit = { + // It is safe to access the reference to `onFailureCallbacks` without holding the TaskContext + // lock. `invokeListeners()` acquires the lock before accessing the contents. + invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) { + _.onTaskFailure(this, error) + } + } + private def invokeListeners[T]( - listeners: Seq[T], + listeners: Stack[T], name: String, error: Option[Throwable])( callback: T => Unit): Unit = { + // This method is subject to two constraints: + // + // 1. Listeners must be run sequentially to uphold the guarantee provided by the TaskContext + // API. + // + // 2. Listeners may spawn threads that call methods on this TaskContext. To avoid deadlock, we + // cannot call listeners while holding the TaskContext lock. + // + // We meet these constraints by ensuring there is at most one thread invoking listeners at any + // point in time. + synchronized { + if (listenerInvocationThread.nonEmpty) { + // If another thread is already invoking listeners, do nothing. + return + } else { + // If no other thread is invoking listeners, register this thread as the listener invocation + // thread. This prevents other threads from invoking listeners until this thread is + // deregistered. + listenerInvocationThread = Some(Thread.currentThread()) + } + } + + def getNextListenerOrDeregisterThread(): Option[T] = synchronized { + if (listeners.empty()) { + // We have executed all listeners that have been added so far. Deregister this thread as the + // callback invocation thread. 
+ listenerInvocationThread = None + None + } else { + Some(listeners.pop()) + } + } + val errorMsgs = new ArrayBuffer[String](2) - // Process callbacks in the reverse order of registration - listeners.reverse.foreach { listener => + var listenerOption: Option[T] = None + while ({listenerOption = getNextListenerOrDeregisterThread(); listenerOption.nonEmpty}) { + val listener = listenerOption.get try { callback(listener) } catch { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index c50a8b9a78b1d..bbe55cb3ba9f4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -239,10 +239,20 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( /** Contains the throwable thrown while writing the parent iterator to the Python process. */ def exception: Option[Throwable] = Option(_exception) - /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ + /** + * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur + * due to cleanup. + */ def shutdownOnTaskCompletion(): Unit = { assert(context.isCompleted) this.interrupt() + // Task completion listeners that run after this method returns may invalidate + // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized + // reader, a task completion listener will free the underlying off-heap buffers. If the writer + // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free + // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer + // thread to exit before returning. 
+ this.join() } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 2200b5b175119..693841d843f0b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -18,6 +18,9 @@ package org.apache.spark.scheduler import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.ArrayBuffer import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ @@ -334,6 +337,124 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(e.getMessage.contains("exception in task")) } + test("listener registers another listener (reentrancy)") { + val context = TaskContext.empty() + var invocations = 0 + val simpleListener = new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + invocations += 1 + } + } + + // Create a listener that registers another listener. + val reentrantListener = new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + context.addTaskCompletionListener(simpleListener) + invocations += 1 + } + } + context.addTaskCompletionListener(reentrantListener) + + // Ensure the listener can execute without encountering deadlock. + assert(invocations == 0) + context.markTaskCompleted(None) + assert(invocations == 2) + } + + test("listener registers another listener using a second thread") { + val context = TaskContext.empty() + val invocations = new AtomicInteger(0) + val simpleListener = new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + invocations.getAndIncrement() + } + } + + // Create a listener that registers another listener using a second thread. 
+ val multithreadedListener = new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + val thread = new Thread(new Runnable { + override def run(): Unit = { + context.addTaskCompletionListener(simpleListener) + } + }) + thread.start() + invocations.getAndIncrement() + thread.join() + } + } + context.addTaskCompletionListener(multithreadedListener) + + // Ensure the listener can execute without encountering deadlock. + assert(invocations.get() == 0) + context.markTaskCompleted(None) + assert(invocations.get() == 2) + } + + test("listeners registered from different threads are called sequentially") { + val context = TaskContext.empty() + val invocations = new AtomicInteger(0) + val numRunningListeners = new AtomicInteger(0) + + // Create a listener that will throw if more than one instance is running at the same time. + val registerExclusiveListener = new Runnable { + override def run(): Unit = { + context.addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + if (numRunningListeners.getAndIncrement() != 0) throw new Exception() + Thread.sleep(100) + if (numRunningListeners.decrementAndGet() != 0) throw new Exception() + invocations.getAndIncrement() + } + }) + } + } + + // Register it multiple times from different threads before and after the task completes. 
+ assert(invocations.get() == 0) + assert(numRunningListeners.get() == 0) + val thread1 = new Thread(registerExclusiveListener) + val thread2 = new Thread(registerExclusiveListener) + thread1.start() + thread2.start() + thread1.join() + thread2.join() + assert(invocations.get() == 0) + context.markTaskCompleted(None) + assert(invocations.get() == 2) + val thread3 = new Thread(registerExclusiveListener) + val thread4 = new Thread(registerExclusiveListener) + thread3.start() + thread4.start() + thread3.join() + thread4.join() + assert(invocations.get() == 4) + assert(numRunningListeners.get() == 0) + } + + test("listeners registered from same thread are called in reverse order") { + val context = TaskContext.empty() + val invocationOrder = ArrayBuffer.empty[String] + + // Create listeners that log an id to `invocationOrder` when they are invoked. + def makeLoggingListener(id: String): TaskCompletionListener = new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = { + invocationOrder += id + } + } + context.addTaskCompletionListener(makeLoggingListener("A")) + context.addTaskCompletionListener(makeLoggingListener("B")) + context.addTaskCompletionListener(makeLoggingListener("C")) + + // Ensure the listeners are called in reverse order of registration, except when they are called + // after the task is complete. + assert(invocationOrder === Seq.empty) + context.markTaskCompleted(None) + assert(invocationOrder === Seq("C", "B", "A")) + context.addTaskCompletionListener(makeLoggingListener("D")) + assert(invocationOrder === Seq("C", "B", "A", "D")) + } + } private object TaskContextSuite { diff --git a/python/pyspark/sql/tests/test_pandas_map.py b/python/pyspark/sql/tests/test_pandas_map.py index e8f92de417dda..a5e07b4e53ef9 100644 --- a/python/pyspark/sql/tests/test_pandas_map.py +++ b/python/pyspark/sql/tests/test_pandas_map.py @@ -15,9 +15,12 @@ # limitations under the License. 
#
 
 import os
+import shutil
+import tempfile
 import time
 import unittest
 
+from pyspark.sql import Row
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 
@@ -120,6 +123,24 @@ def test_self_join(self):
         expected = df1.join(df1).collect()
         self.assertEqual(sorted(actual), sorted(expected))
 
+    # SPARK-33277: mapInPandas over Parquet must work with on-heap and off-heap column vectors
+    def test_map_in_pandas_with_column_vector(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(0, 200000, 1, 1).write.parquet(path)
+
+            def func(iterator):
+                for pdf in iterator:
+                    yield pd.DataFrame({'id': [0] * len(pdf)})
+
+            for offheap in ["true", "false"]:
+                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
+                    self.assertEqual(
+                        self.spark.read.parquet(path).mapInPandas(func, 'id long').head(), Row(0))
+        finally:
+            shutil.rmtree(path)
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_map import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 2eb2dec00106e..a170f5532ee98 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -1152,6 +1152,25 @@ def test_datasource_with_udf(self):
         finally:
             shutil.rmtree(path)
 
+    # SPARK-33277: pandas UDF over Parquet must work with on-heap and off-heap column vectors
+    def test_pandas_udf_with_column_vector(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(0, 200000, 1, 1).write.parquet(path)
+
+            @pandas_udf(LongType())
+            def udf(x):
+                return pd.Series([0] * len(x))
+
+            for offheap in ["true", "false"]:
+                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
+                    self.assertEqual(
+                        self.spark.read.parquet(path).select(udf('id')).head(), Row(0))
+        finally:
+            shutil.rmtree(path)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_udf_scalar import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_udf.py
b/python/pyspark/sql/tests/test_udf.py
index 98d193f94dedf..be1e917172a46 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -703,6 +703,26 @@ def f(e):
         self.assertEqual(result.collect(),
                          [Row(c1=Row(_1=1.0, _2=1.0), c2=Row(_1=1, _2=1), c3=1.0, c4=1)])
 
+    # SPARK-33277: Python UDF over Parquet must work with on-heap and off-heap column vectors
+    def test_udf_with_column_vector(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(0, 100000, 1, 1).write.parquet(path)
+
+            def f(x):
+                return 0
+
+            fUdf = udf(f, LongType())
+
+            for offheap in ["true", "false"]:
+                with self.sql_conf({"spark.sql.columnVector.offheap.enabled": offheap}):
+                    self.assertEqual(
+                        self.spark.read.parquet(path).select(fUdf('id')).head(), Row(0))
+        finally:
+            shutil.rmtree(path)
+
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):