@@ -1433,17 +1433,18 @@ class DAGScheduler(
val failedStage = stageIdToStage(task.stageId)
logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " +
"failed.")
val message = s"Stage failed because barrier task $task finished unsuccessfully. " +
val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" +
failure.toErrorString
try {
- // cancelTasks will fail if a SchedulerBackend does not implement killTask
- taskScheduler.cancelTasks(stageId, interruptThread = false)
+ // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask.
+ val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) failed."
+ taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason)
} catch {
case e: UnsupportedOperationException =>
// Cannot continue with barrier stage if failed to cancel zombie barrier tasks.
// TODO SPARK-24877 leave the zombie tasks and ignore their completion events.
logWarning(s"Could not cancel tasks for stage $stageId", e)
abortStage(failedStage, "Could not cancel zombie barrier tasks for stage " +
logWarning(s"Could not kill all tasks for stage $stageId", e)
abortStage(failedStage, "Could not kill zombie barrier tasks for stage " +
s"$failedStage (${failedStage.name})", Some(e))
}
markStageAsFinished(failedStage, Some(message))
@@ -1457,7 +1458,8 @@ class DAGScheduler(

if (shouldAbortStage) {
val abortMessage = if (disallowStageRetryForTest) {
"Barrier stage will not retry stage due to testing config"
"Barrier stage will not retry stage due to testing config. Most recent failure " +
s"reason: $message"
} else {
s"""$failedStage (${failedStage.name})
|has failed the maximum allowable number of
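
To summarize the intent of the two hunks above, here is a minimal, self-contained sketch (not Spark source; `Scheduler`, `Stage`, `handleBarrierTaskFailure` and the `println` calls are invented stand-ins) of the control flow the change establishes: on a barrier task failure the scheduler now kills the remaining task attempts without failing the dependent jobs, and only aborts the stage when the backend cannot kill tasks.

```scala
// Illustrative sketch only; simplified stand-ins, not Spark's DAGScheduler.
object BarrierFailureSketch {

  // Stand-in for a TaskScheduler that may or may not support killing tasks.
  trait Scheduler {
    def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit
  }

  final case class Stage(id: Int, name: String)

  def handleBarrierTaskFailure(scheduler: Scheduler, failedStage: Stage, error: String): Unit = {
    val message = s"Stage failed because a barrier task finished unsuccessfully.\n$error"
    try {
      // Kill the surviving barrier tasks, but do not fail the jobs: the stage is only
      // marked as finished with `message`, so it can be resubmitted later.
      scheduler.killAllTaskAttempts(failedStage.id, interruptThread = false,
        reason = s"A task from barrier stage ${failedStage.name} failed.")
    } catch {
      case e: UnsupportedOperationException =>
        // The backend cannot kill tasks, so zombie barrier tasks would linger: abort instead.
        println(s"Could not kill zombie barrier tasks for stage ${failedStage.id}: ${e.getMessage}")
    }
    println(s"Marked stage ${failedStage.id} as finished: $message")
  }
}
```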
@@ -51,16 +51,22 @@ private[spark] trait TaskScheduler {
// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit

- // Cancel a stage.
+ // Kill all the tasks in a stage, fail the stage, and fail all the jobs that depend on the stage.
> Review comment (Contributor): Is it guaranteed to work for any backend, like YARN, Mesos, K8s?
>
> Reply (Contributor Author): Updated the comment to note that if the backend doesn't support killing a task, the method throws UnsupportedOperationException.

+ // Throws UnsupportedOperationException if the backend doesn't support killing tasks.
def cancelTasks(stageId: Int, interruptThread: Boolean): Unit

/**
* Kills a task attempt.
+ * Throws UnsupportedOperationException if the backend doesn't support killing a task.
*
* @return Whether the task was successfully killed.
*/
def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean

+ // Kill all the running task attempts in a stage.
+ // Throws UnsupportedOperationException if the backend doesn't support killing tasks.
+ def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit

// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
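
To make the contract of the two methods concrete, here is a hedged toy sketch (assumed names only; `ToyScheduler`, `runningTasks`, and `failedStages` are not Spark classes): both kill running attempts, but only `cancelTasks` additionally fails the stage, and both surface `UnsupportedOperationException` when the backend cannot kill tasks.

```scala
import scala.collection.mutable

// Toy sketch of the contract above; not TaskSchedulerImpl.
class ToyScheduler(backendSupportsKill: Boolean) {
  private val runningTasks = mutable.Map.empty[Int, mutable.Set[Long]] // stageId -> task ids
  private val failedStages = mutable.Set.empty[Int]

  // Kill running attempts only; the stage remains eligible for resubmission.
  def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit = {
    if (!backendSupportsKill) {
      throw new UnsupportedOperationException("This backend cannot kill tasks")
    }
    runningTasks.get(stageId).foreach(_.clear())
  }

  // Kill running attempts and additionally fail the stage (and, in Spark, its dependent jobs).
  def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
    killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled")
    failedStages += stageId
  }
}
```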

@@ -222,18 +222,11 @@ private[spark] class TaskSchedulerImpl(

override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
+ // Kill all running tasks for the stage.
+ killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled")
+ // Cancel all attempts for the stage.
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
attempts.foreach { case (_, tsm) =>
- // There are two possible cases here:
- // 1. The task set manager has been created and some tasks have been scheduled.
- // In this case, send a kill signal to the executors to kill the task and then abort
- // the stage.
- // 2. The task set manager has been created but no tasks have been scheduled. In this case,
- // simply abort the stage.
- tsm.runningTasksSet.foreach { tid =>
- taskIdToExecutorId.get(tid).foreach(execId =>
- backend.killTask(tid, execId, interruptThread, reason = "Stage cancelled"))
- }
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
}
@@ -252,6 +245,27 @@
}
}

+ override def killAllTaskAttempts(
+ stageId: Int,
+ interruptThread: Boolean,
+ reason: String): Unit = synchronized {
+ logInfo(s"Killing all running tasks in stage $stageId: $reason")
+ taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
> Review comment (Contributor): This is some duplicated code, and we dropped the useful comments from cancelTasks. It would be great to move the common code here with the comments and let cancelTasks call this method.
+ attempts.foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply continue.
+ tsm.runningTasksSet.foreach { tid =>
+ taskIdToExecutorId.get(tid).foreach { execId =>
+ backend.killTask(tid, execId, interruptThread, reason)
+ }
+ }
+ }
+ }
+ }
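
As a side note on the loop just added, the "two possible cases" in the comments reduce to iterating a possibly empty runningTasksSet. A hedged, self-contained restatement follows (all names here, `KillRequest`, `killRequestsFor`, are invented for illustration):

```scala
// Toy restatement of the kill loop above; not Spark code.
object KillLoopSketch {
  final case class KillRequest(taskId: Long, executorId: String, reason: String)

  def killRequestsFor(
      runningTasks: Set[Long],
      taskIdToExecutorId: Map[Long, String],
      reason: String): Seq[KillRequest] = {
    // Case 1: tasks were scheduled -> one kill request per running task with a known executor.
    // Case 2: no tasks were scheduled -> runningTasks is empty and this returns Nil.
    runningTasks.toSeq.flatMap { tid =>
      taskIdToExecutorId.get(tid).map(execId => KillRequest(tid, execId, reason))
    }
  }
}
```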

/**
* Called to indicate that all task attempts (including speculated tasks) associated with the
* given TaskSetManager have completed, so state associated with the TaskSetManager should be
@@ -131,6 +131,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+ override def killAllTaskAttempts(
+ stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@@ -629,6 +631,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
throw new UnsupportedOperationException
}
+ override def killAllTaskAttempts(
+ stageId: Int, interruptThread: Boolean, reason: String): Unit = {
+ throw new UnsupportedOperationException
+ }
override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
override def defaultParallelism(): Int = 2
override def executorHeartbeatReceived(
@@ -81,6 +81,8 @@ private class DummyTaskScheduler extends TaskScheduler {
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {}
override def killTaskAttempt(
taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+ override def killAllTaskAttempts(
+ stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
override def defaultParallelism(): Int = 2
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@@ -1055,4 +1055,66 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
assert(3 === taskDescriptions.length)
}

test("cancelTasks shall kill all the running tasks and fail the stage") {
val taskScheduler = setupScheduler()

taskScheduler.initialize(new FakeSchedulerBackend {
override def killTask(
taskId: Long,
executorId: String,
interruptThread: Boolean,
reason: String): Unit = {
// Since we only submit one stage attempt, the following call is sufficient to mark the
// task as killed.
taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId)
}
})

val attempt1 = FakeTask.createTaskSet(10, 0)
taskScheduler.submitTasks(attempt1)

val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
new WorkerOffer("executor1", "host1", 1))
val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
assert(2 === taskDescriptions.length)
val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
assert(2 === tsm.runningTasks)

taskScheduler.cancelTasks(0, false)
assert(0 === tsm.runningTasks)
assert(tsm.isZombie)
assert(taskScheduler.taskSetManagerForAttempt(0, 0).isEmpty)
}

test("killAllTaskAttempts shall kill all the running tasks and not fail the stage") {
val taskScheduler = setupScheduler()

taskScheduler.initialize(new FakeSchedulerBackend {
override def killTask(
taskId: Long,
executorId: String,
interruptThread: Boolean,
reason: String): Unit = {
// Since we only submit one stage attempt, the following call is sufficient to mark the
// task as killed.
taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId)
}
})

val attempt1 = FakeTask.createTaskSet(10, 0)
taskScheduler.submitTasks(attempt1)

val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
new WorkerOffer("executor1", "host1", 1))
val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
assert(2 === taskDescriptions.length)
val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
assert(2 === tsm.runningTasks)

taskScheduler.killAllTaskAttempts(0, false, "test")
assert(0 === tsm.runningTasks)
assert(!tsm.isZombie)
assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined)
}
}