diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index af558d6e5b474..190c8ea1e5d67 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
-import org.apache.spark.util.TaskCompletionListener
+import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener}
 
 
 object TaskContext {
@@ -108,6 +108,8 @@ abstract class TaskContext extends Serializable {
    * Adds a (Java friendly) listener to be executed on task completion.
    * This will be called in all situations - success, failure, or cancellation.
    * An example use is for HadoopRDD to register a callback to close the input stream.
+   *
+   * Exceptions thrown by the listener will result in failure of the task.
    */
   def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
 
@@ -115,8 +117,30 @@ abstract class TaskContext extends Serializable {
    * Adds a listener in the form of a Scala closure to be executed on task completion.
    * This will be called in all situations - success, failure, or cancellation.
    * An example use is for HadoopRDD to register a callback to close the input stream.
+   *
+   * Exceptions thrown by the listener will result in failure of the task.
+   */
+  def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = {
+    addTaskCompletionListener(new TaskCompletionListener {
+      override def onTaskCompletion(context: TaskContext): Unit = f(context)
+    })
+  }
+
+  /**
+   * Adds a listener to be executed on task failure.
+   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
    */
-  def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
+  def addTaskFailureListener(listener: TaskFailureListener): TaskContext
+
+  /**
+   * Adds a listener to be executed on task failure.
+   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+   */
+  def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
+    addTaskFailureListener(new TaskFailureListener {
+      override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error)
+    })
+  }
 
   /**
    * Adds a callback function to be executed on task completion. An example use
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index f0ae83a9341bd..123658721316f 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -23,7 +23,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
-import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+import org.apache.spark.util._
 
 private[spark] class TaskContextImpl(
     val stageId: Int,
@@ -41,9 +41,12 @@ private[spark] class TaskContextImpl(
   // For backwards-compatibility; this method is now deprecated as of 1.3.0.
   override def attemptId(): Long = taskAttemptId
 
-  // List of callback functions to execute when the task completes.
+  /** List of callback functions to execute when the task completes. */
   @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
 
+  /** List of callback functions to execute when the task fails. */
+  @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
+
   // Whether the corresponding task has been killed.
   @volatile private var interrupted: Boolean = false
 
@@ -55,10 +58,8 @@ private[spark] class TaskContextImpl(
     this
   }
 
-  override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
-    onCompleteCallbacks += new TaskCompletionListener {
-      override def onTaskCompletion(context: TaskContext): Unit = f(context)
-    }
+  override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
+    onFailureCallbacks += listener
     this
   }
 
@@ -69,7 +70,25 @@ private[spark] class TaskContextImpl(
     }
   }
 
-  /** Marks the task as completed and triggers the listeners. */
+  /** Marks the task as failed and triggers the failure listeners. */
+  private[spark] def markTaskFailed(error: Throwable): Unit = {
+    val errorMsgs = new ArrayBuffer[String](2)
+    // Process failure callbacks in the reverse order of registration
+    onFailureCallbacks.reverse.foreach { listener =>
+      try {
+        listener.onTaskFailure(this, error)
+      } catch {
+        case e: Throwable =>
+          errorMsgs += e.getMessage
+          logError("Error in TaskFailureListener", e)
+      }
+    }
+    if (errorMsgs.nonEmpty) {
+      throw new TaskCompletionListenerException(errorMsgs, Option(error))
+    }
+  }
+
+  /** Marks the task as completed and triggers the completion listeners. */
   private[spark] def markTaskCompleted(): Unit = {
     completed = true
     val errorMsgs = new ArrayBuffer[String](2)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 304f723e4924e..17304ea19204b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -87,7 +87,12 @@ private[spark] abstract class Task[T](
     }
     try {
       (runTask(context), context.collectAccumulators())
+    } catch { case e: Throwable =>
+      // Catch all errors; run task failure callbacks, and rethrow the exception.
+      context.markTaskFailed(e)
+      throw e
     } finally {
+      // Call the task completion callbacks.
       context.markTaskCompleted()
       try {
         Utils.tryLogNonFatalError {
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala
deleted file mode 100644
index f64e069cd1724..0000000000000
--- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-/**
- * Exception thrown when there is an exception in
- * executing the callback in TaskCompletionListener.
- */
-private[spark]
-class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception {
-
-  override def getMessage: String = {
-    if (errorMessages.size == 1) {
-      errorMessages.head
-    } else {
-      errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
-    }
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
similarity index 51%
rename from core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
rename to core/src/main/scala/org/apache/spark/util/taskListeners.scala
index c1b8bf052c0ca..1be31e88ab68e 100644
--- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
+++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
@@ -29,5 +29,40 @@ import org.apache.spark.annotation.DeveloperApi
  */
 @DeveloperApi
 trait TaskCompletionListener extends EventListener {
-  def onTaskCompletion(context: TaskContext)
+  def onTaskCompletion(context: TaskContext): Unit
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Listener providing a callback function to invoke when a task's execution encounters an error.
+ * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ */
+@DeveloperApi
+trait TaskFailureListener extends EventListener {
+  def onTaskFailure(context: TaskContext, error: Throwable): Unit
+}
+
+
+/**
+ * Exception thrown when there is an exception in executing the callback in TaskCompletionListener.
+ */
+private[spark]
+class TaskCompletionListenerException(
+    errorMessages: Seq[String],
+    val previousError: Option[Throwable] = None)
+  extends RuntimeException {
+
+  override def getMessage: String = {
+    if (errorMessages.size == 1) {
+      errorMessages.head
+    } else {
+      errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+    } +
+    previousError.map { e =>
+      "\n\nPrevious exception in task: " + e.getMessage + "\n" +
+        e.getStackTrace.mkString("\t", "\n\t", "")
+    }.getOrElse("")
+  }
 }
diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java
deleted file mode 100644
index e38bc38949d7c..0000000000000
--- a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package test.org.apache.spark;
-
-import org.apache.spark.TaskContext;
-import org.apache.spark.util.TaskCompletionListener;
-
-
-/**
- * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
- * TaskContext is Java friendly.
- */
-public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
-
-  @Override
-  public void onTaskCompletion(TaskContext context) {
-    context.isCompleted();
-    context.isInterrupted();
-    context.stageId();
-    context.partitionId();
-    context.isRunningLocally();
-    context.addTaskCompletionListener(this);
-  }
-}
diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
index 4a918f725dc91..f914081d7d5b2 100644
--- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
+++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
@@ -18,6 +18,8 @@
 package test.org.apache.spark;
 
 import org.apache.spark.TaskContext;
+import org.apache.spark.util.TaskCompletionListener;
+import org.apache.spark.util.TaskFailureListener;
 
 /**
  * Something to make sure that TaskContext can be used in Java.
@@ -32,10 +34,38 @@ public static void test() {
     tc.isRunningLocally();
 
     tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl());
+    tc.addTaskFailureListener(new JavaTaskFailureListenerImpl());
 
     tc.attemptNumber();
     tc.partitionId();
     tc.stageId();
     tc.taskAttemptId();
   }
+
+  /**
+   * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
+   * TaskContext are Java friendly.
+   */
+  static class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
+    @Override
+    public void onTaskCompletion(TaskContext context) {
+      context.isCompleted();
+      context.isInterrupted();
+      context.stageId();
+      context.partitionId();
+      context.isRunningLocally();
+      context.addTaskCompletionListener(this);
+    }
+  }
+
+  /**
+   * A simple implementation of TaskFailureListener that makes sure TaskFailureListener and
+   * TaskContext are Java friendly.
+   */
+  static class JavaTaskFailureListenerImpl implements TaskFailureListener {
+    @Override
+    public void onTaskFailure(TaskContext context, Throwable error) {
+    }
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 450ab7b9fe92b..9df605f7db0bf 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -17,16 +17,15 @@
 
 package org.apache.spark.scheduler
 
-import org.mockito.Mockito._
 import org.mockito.Matchers.any
-
+import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
 import org.apache.spark.metrics.source.JvmSource
-
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util._
 
 
 class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -66,6 +65,27 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(TaskContextSuite.completed === true)
   }
 
+  test("calls TaskFailureListeners after failure") {
+    TaskContextSuite.lastError = null
+    sc = new SparkContext("local", "test")
+    val rdd = new RDD[String](sc, List()) {
+      override def getPartitions = Array[Partition](StubPartition(0))
+      override def compute(split: Partition, context: TaskContext) = {
+        context.addTaskFailureListener((context, error) => TaskContextSuite.lastError = error)
+        sys.error("damn error")
+      }
+    }
+    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
+    val func = (c: TaskContext, i: Iterator[String]) => i.next()
+    val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
+    val task = new ResultTask[String, String](
+      0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
+    intercept[RuntimeException] {
+      task.run(0, 0, null)
+    }
+    assert(TaskContextSuite.lastError.getMessage == "damn error")
+  }
+
   test("all TaskCompletionListeners should be called even if some fail") {
     val context = TaskContext.empty()
     val listener = mock(classOf[TaskCompletionListener])
@@ -80,6 +100,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     verify(listener, times(1)).onTaskCompletion(any())
   }
 
+  test("all TaskFailureListeners should be called even if some fail") {
+    val context = TaskContext.empty()
+    val listener = mock(classOf[TaskFailureListener])
+    context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1"))
+    context.addTaskFailureListener(listener)
+    context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3"))
+
+    val e = intercept[TaskCompletionListenerException] {
+      context.markTaskFailed(new Exception("exception in task"))
+    }
+
+    // Make sure listener 2 was called.
+    verify(listener, times(1)).onTaskFailure(any(), any())
+
+    // Also make sure that a failure in a TaskFailureListener does not mask the earlier exception.
+    assert(e.getMessage.contains("exception in listener1"))
+    assert(e.getMessage.contains("exception in listener3"))
+    assert(e.getMessage.contains("exception in task"))
+  }
+
   test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
     sc = new SparkContext("local[1,2]", "test")  // use maxRetries = 2 because we test failed tasks
     // Check that attemptIds are 0 for all tasks' initial attempts
@@ -110,6 +150,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
 
 private object TaskContextSuite {
   @volatile var completed = false
+
+  @volatile var lastError: Throwable = _
 }
 
 private case class StubPartition(index: Int) extends Partition
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 08b4a2349ac4b..8b95909179036 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -158,7 +158,9 @@ object MimaExcludes {
     ) ++ Seq(
       // SPARK-3580 Add getNumPartitions method to JavaRDD
      ProblemFilters.exclude[MissingMethodProblem](
-        "org.apache.spark.api.java.JavaRDDLike.getNumPartitions")
+        "org.apache.spark.api.java.JavaRDDLike.getNumPartitions"),
+      // SPARK-13465 Add a task failure listener to TaskContext
+      ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener")
     ) ++ Seq(
       // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"),
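
---

Note for reviewers: below is a minimal end-to-end sketch of the API this patch adds, for trying it out locally. It is not part of the patch; the job, the `TaskListenerExample` object name, and the println-based logging are illustrative only. It exercises the two Scala closure overloads introduced above, `addTaskCompletionListener(f: (TaskContext) => Unit)` and `addTaskFailureListener(f: (TaskContext, Throwable) => Unit)`.

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object TaskListenerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("task-listener-example").setMaster("local[2]"))
    try {
      val processed = sc.parallelize(1 to 100, 4).mapPartitions { iter =>
        val ctx = TaskContext.get()
        // Completion listeners fire on success, failure, and cancellation,
        // so they are the natural place to release per-task resources.
        ctx.addTaskCompletionListener { _ =>
          println(s"partition ${ctx.partitionId()} finished; releasing resources")
        }
        // Failure listeners may be invoked more than once for the same task,
        // so whatever they do must be idempotent.
        ctx.addTaskFailureListener { (_, error) =>
          println(s"partition ${ctx.partitionId()} failed: ${error.getMessage}")
        }
        iter.map(_ * 2)
      }.count()
      println(s"processed $processed records")
    } finally {
      sc.stop()
    }
  }
}

Because `Task.run` invokes `markTaskFailed` in the catch block before the finally block calls `markTaskCompleted`, a failing task runs its failure listeners first and its completion listeners afterwards. A listener that itself throws is wrapped, together with the original task error, in `TaskCompletionListenerException` carrying `previousError`, which is exactly what the updated suite asserts.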