Commit 391755d

rxin authored and davies committed
[SPARK-13465] Add a task failure listener to TaskContext
## What changes were proposed in this pull request?

TaskContext supports a task completion callback, which gets called regardless of task failures. However, there is no way for the listener to know whether there was an error. This patch adds a new listener that gets called when a task fails.

## How was this patch tested?

New unit test case and integration test case covering the code path.

Author: Reynold Xin <[email protected]>

Closes #11340 from rxin/SPARK-13465.
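To make the new hook concrete, here is a minimal usage sketch against the public API added by this patch (the job body and the printed messages are illustrative, not part of the change):

```scala
import org.apache.spark.{SparkContext, TaskContext}

// A hypothetical partition-processing job that registers both kinds of
// listeners. Assumes an existing SparkContext `sc`.
def process(sc: SparkContext): Unit = {
  sc.parallelize(1 to 100, numSlices = 4).foreachPartition { iter =>
    val ctx = TaskContext.get()
    // Completion listener: fires on success, failure, or cancellation.
    ctx.addTaskCompletionListener { _ =>
      println(s"partition ${ctx.partitionId()} finished")
    }
    // New failure listener: fires only when the task fails. Keep the body
    // idempotent, since onTaskFailure may be invoked multiple times.
    ctx.addTaskFailureListener { (_, error) =>
      println(s"partition ${ctx.partitionId()} failed: ${error.getMessage}")
    }
    iter.foreach(_ => ())
  }
}
```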
1 parent 0598a2b commit 391755d

File tree

9 files changed: +169, -85 lines changed

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 26 additions & 2 deletions
```diff
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
-import org.apache.spark.util.TaskCompletionListener
+import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener}


 object TaskContext {
@@ -106,15 +106,39 @@ abstract class TaskContext extends Serializable {
    * Adds a (Java friendly) listener to be executed on task completion.
    * This will be called in all situations - success, failure, or cancellation.
    * An example use is for HadoopRDD to register a callback to close the input stream.
+   *
+   * Exceptions thrown by the listener will result in failure of the task.
    */
   def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext

   /**
    * Adds a listener in the form of a Scala closure to be executed on task completion.
    * This will be called in all situations - success, failure, or cancellation.
    * An example use is for HadoopRDD to register a callback to close the input stream.
+   *
+   * Exceptions thrown by the listener will result in failure of the task.
    */
-  def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
+  def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = {
+    addTaskCompletionListener(new TaskCompletionListener {
+      override def onTaskCompletion(context: TaskContext): Unit = f(context)
+    })
+  }
+
+  /**
+   * Adds a listener to be executed on task failure.
+   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+   */
+  def addTaskFailureListener(listener: TaskFailureListener): TaskContext
+
+  /**
+   * Adds a listener to be executed on task failure.
+   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+   */
+  def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
+    addTaskFailureListener(new TaskFailureListener {
+      override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error)
+    })
+  }

   /**
    * The ID of the stage that this task belongs to.
```
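Note that both Scala-closure overloads are now concrete methods that simply wrap the closure in the corresponding Java-friendly trait, which is why `TaskContextImpl` (next file) can drop its own closure override. As a sketch of the intended failure-listener use, registering a reusable listener class (the class name and scratch file are illustrative):

```scala
import java.io.File

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskFailureListener

// Hypothetical listener that logs diagnostics and removes a scratch file
// when a task fails. The body is idempotent, as the contract requires:
// deleting an already-deleted file is a no-op, so repeated onTaskFailure
// calls are safe.
class CleanupOnFailure(scratch: File) extends TaskFailureListener {
  override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
    System.err.println(s"stage ${context.stageId()}, partition " +
      s"${context.partitionId()} failed: ${error.getMessage}")
    scratch.delete()
  }
}
```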

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 26 additions & 7 deletions
```diff
@@ -23,7 +23,7 @@ import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
-import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+import org.apache.spark.util._

 private[spark] class TaskContextImpl(
     val stageId: Int,
@@ -41,9 +41,12 @@ private[spark] class TaskContextImpl(
    */
   override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators)

-  // List of callback functions to execute when the task completes.
+  /** List of callback functions to execute when the task completes. */
   @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]

+  /** List of callback functions to execute when the task fails. */
+  @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
+
   // Whether the corresponding task has been killed.
   @volatile private var interrupted: Boolean = false

@@ -55,14 +58,30 @@ private[spark] class TaskContextImpl(
     this
   }

-  override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
-    onCompleteCallbacks += new TaskCompletionListener {
-      override def onTaskCompletion(context: TaskContext): Unit = f(context)
-    }
+  override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
+    onFailureCallbacks += listener
     this
   }

-  /** Marks the task as completed and triggers the listeners. */
+  /** Marks the task as failed and triggers the failure listeners. */
+  private[spark] def markTaskFailed(error: Throwable): Unit = {
+    val errorMsgs = new ArrayBuffer[String](2)
+    // Process failure callbacks in the reverse order of registration
+    onFailureCallbacks.reverse.foreach { listener =>
+      try {
+        listener.onTaskFailure(this, error)
+      } catch {
+        case e: Throwable =>
+          errorMsgs += e.getMessage
+          logError("Error in TaskFailureListener", e)
+      }
+    }
+    if (errorMsgs.nonEmpty) {
+      throw new TaskCompletionListenerException(errorMsgs, Option(error))
+    }
+  }
+
+  /** Marks the task as completed and triggers the completion listeners. */
   private[spark] def markTaskCompleted(): Unit = {
     completed = true
     val errorMsgs = new ArrayBuffer[String](2)
```
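`markTaskFailed` follows the same discipline as `markTaskCompleted`: every registered listener runs even if an earlier one throws, listeners run in reverse registration order, and any listener failures are collected and rethrown as a single `TaskCompletionListenerException` carrying the original task error. The pattern in isolation, as a standalone sketch rather than Spark code:

```scala
import scala.collection.mutable.ArrayBuffer

// Run all callbacks in reverse registration order. A failing callback never
// prevents the remaining ones from running; all callback failures are
// surfaced together at the end.
def runAllReversed[A](callbacks: Seq[A => Unit], arg: A): Unit = {
  val errorMsgs = new ArrayBuffer[String]
  callbacks.reverse.foreach { cb =>
    try cb(arg) catch { case e: Throwable => errorMsgs += e.getMessage }
  }
  if (errorMsgs.nonEmpty) {
    throw new RuntimeException(errorMsgs.mkString("\n"))
  }
}
```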

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 5 additions & 0 deletions
```diff
@@ -80,7 +80,12 @@ private[spark] abstract class Task[T](
     }
     try {
       runTask(context)
+    } catch { case e: Throwable =>
+      // Catch all errors; run task failure callbacks, and rethrow the exception.
+      context.markTaskFailed(e)
+      throw e
     } finally {
+      // Call the task completion callbacks.
       context.markTaskCompleted()
       try {
         Utils.tryLogNonFatalError {
```
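Note the ordering this gives on the failure path: `markTaskFailed` runs inside `catch` before the `finally` block invokes `markTaskCompleted`, so failure listeners always fire before completion listeners, and completion listeners still fire on every path. A minimal sketch of the same control flow:

```scala
// Skeleton of the try/catch/finally structure used in Task.run above.
def runGuarded(body: => Unit)(onFailure: Throwable => Unit)(onComplete: => Unit): Unit = {
  try {
    body
  } catch {
    case e: Throwable =>
      onFailure(e) // failure callbacks first...
      throw e
  } finally {
    onComplete // ...then completion callbacks, on every path
  }
}
```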

core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala

Lines changed: 0 additions & 34 deletions
This file was deleted.

core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala renamed to core/src/main/scala/org/apache/spark/util/taskListeners.scala

Lines changed: 36 additions & 1 deletion
```diff
@@ -29,5 +29,40 @@ import org.apache.spark.annotation.DeveloperApi
  */
 @DeveloperApi
 trait TaskCompletionListener extends EventListener {
-  def onTaskCompletion(context: TaskContext)
+  def onTaskCompletion(context: TaskContext): Unit
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Listener providing a callback function to invoke when a task's execution encounters an error.
+ * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ */
+@DeveloperApi
+trait TaskFailureListener extends EventListener {
+  def onTaskFailure(context: TaskContext, error: Throwable): Unit
+}
+
+
+/**
+ * Exception thrown when there is an exception in executing the callback in TaskCompletionListener.
+ */
+private[spark]
+class TaskCompletionListenerException(
+    errorMessages: Seq[String],
+    val previousError: Option[Throwable] = None)
+  extends RuntimeException {
+
+  override def getMessage: String = {
+    if (errorMessages.size == 1) {
+      errorMessages.head
+    } else {
+      errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+    } +
+      previousError.map { e =>
+        "\n\nPrevious exception in task: " + e.getMessage + "\n" +
+          e.getStackTrace.mkString("\t", "\n\t", "")
+      }.getOrElse("")
+  }
 }
```
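For concreteness, given the `getMessage` implementation above, two listener failures stacked on an original task error render as shown below. This is a sketch: `TaskCompletionListenerException` is `private[spark]`, so constructing it directly only compiles from code inside Spark's own packages.

```scala
// Values mirror the "all TaskFailureListeners should be called" test
// in TaskContextSuite further below.
val e = new TaskCompletionListenerException(
  errorMessages = Seq("exception in listener1", "exception in listener3"),
  previousError = Some(new RuntimeException("exception in task")))

println(e.getMessage)
// Exception 0: exception in listener1
// Exception 1: exception in listener3
//
// Previous exception in task: exception in task
// 	<tab-indented stack trace of the original exception>
```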

core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java

Lines changed: 0 additions & 39 deletions
This file was deleted.

core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java

Lines changed: 30 additions & 0 deletions
```diff
@@ -18,6 +18,8 @@
 package test.org.apache.spark;

 import org.apache.spark.TaskContext;
+import org.apache.spark.util.TaskCompletionListener;
+import org.apache.spark.util.TaskFailureListener;

 /**
  * Something to make sure that TaskContext can be used in Java.
@@ -32,10 +34,38 @@ public static void test() {
     tc.isRunningLocally();

     tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl());
+    tc.addTaskFailureListener(new JavaTaskFailureListenerImpl());

     tc.attemptNumber();
     tc.partitionId();
     tc.stageId();
     tc.taskAttemptId();
   }
+
+  /**
+   * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
+   * TaskContext are Java friendly.
+   */
+  static class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
+    @Override
+    public void onTaskCompletion(TaskContext context) {
+      context.isCompleted();
+      context.isInterrupted();
+      context.stageId();
+      context.partitionId();
+      context.isRunningLocally();
+      context.addTaskCompletionListener(this);
+    }
+  }
+
+  /**
+   * A simple implementation of TaskFailureListener that makes sure TaskFailureListener and
+   * TaskContext are Java friendly.
+   */
+  static class JavaTaskFailureListenerImpl implements TaskFailureListener {
+    @Override
+    public void onTaskFailure(TaskContext context, Throwable error) {
+    }
+  }
 }
```

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 43 additions & 1 deletion
```diff
@@ -27,7 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+import org.apache.spark.util._

 class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

@@ -66,6 +66,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
     assert(TaskContextSuite.completed === true)
   }

+  test("calls TaskFailureListeners after failure") {
+    TaskContextSuite.lastError = null
+    sc = new SparkContext("local", "test")
+    val rdd = new RDD[String](sc, List()) {
+      override def getPartitions = Array[Partition](StubPartition(0))
+      override def compute(split: Partition, context: TaskContext) = {
+        context.addTaskFailureListener((context, error) => TaskContextSuite.lastError = error)
+        sys.error("damn error")
+      }
+    }
+    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
+    val func = (c: TaskContext, i: Iterator[String]) => i.next()
+    val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
+    val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+    intercept[RuntimeException] {
+      task.run(0, 0, null)
+    }
+    assert(TaskContextSuite.lastError.getMessage == "damn error")
+  }
+
   test("all TaskCompletionListeners should be called even if some fail") {
     val context = TaskContext.empty()
     val listener = mock(classOf[TaskCompletionListener])
@@ -80,6 +100,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
     verify(listener, times(1)).onTaskCompletion(any())
   }

+  test("all TaskFailureListeners should be called even if some fail") {
+    val context = TaskContext.empty()
+    val listener = mock(classOf[TaskFailureListener])
+    context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1"))
+    context.addTaskFailureListener(listener)
+    context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3"))
+
+    val e = intercept[TaskCompletionListenerException] {
+      context.markTaskFailed(new Exception("exception in task"))
+    }
+
+    // Make sure listener 2 was called.
+    verify(listener, times(1)).onTaskFailure(any(), any())
+
+    // Also check that a failure in a TaskFailureListener does not mask the earlier exceptions.
+    assert(e.getMessage.contains("exception in listener1"))
+    assert(e.getMessage.contains("exception in listener3"))
+    assert(e.getMessage.contains("exception in task"))
+  }
+
   test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
     sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks
     // Check that attemptIds are 0 for all tasks' initial attempts
@@ -153,6 +193,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

 private object TaskContextSuite {
   @volatile var completed = false
+
+  @volatile var lastError: Throwable = _
 }

 private case class StubPartition(index: Int) extends Partition
```
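Because `markTaskFailed` can be driven directly against an empty context, failure listeners are easy to exercise without scheduling a real task. A minimal sketch (like the tests above, it relies on `TaskContext.empty()` and `markTaskFailed` being `private[spark]`, so it only works from a suite inside the `org.apache.spark` package):

```scala
// Directly trigger the failure callbacks without running a task.
val context = TaskContext.empty()
var seen: Throwable = null
context.addTaskFailureListener((_, error) => seen = error)
context.markTaskFailed(new RuntimeException("boom"))
assert(seen.getMessage == "boom")
```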

project/MimaExcludes.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -271,7 +271,9 @@ object MimaExcludes {
     ) ++ Seq(
       // SPARK-13220 Deprecate yarn-client and yarn-cluster mode
       ProblemFilters.exclude[MissingMethodProblem](
-        "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler")
+        "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler"),
+      // SPARK-13465 Add a task failure listener to TaskContext
+      ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener")
     ) ++ Seq (
       // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab
       ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this")
```
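The exclude is needed because adding the abstract `addTaskFailureListener` method changes the binary interface of the public `TaskContext` class; MiMa would otherwise flag it as a backwards-incompatibility, even though `TaskContext` is not meant to be implemented outside Spark.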
