@@ -772,6 +772,12 @@ private[spark] class TaskSetManager(
  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
    partitionToIndex.get(partitionId).foreach { index =>
      if (!successful(index)) {
        if (speculationEnabled) {
Contributor:
IIUC, in this case no task in this taskSet actually finishes successfully; it's another task attempt from another taskSet for the same stage that succeeded. Instead of changing this code path, I'd suggest we add another flag to show whether any task succeeded in the current taskSet, and if no task has succeeded, skip L987.

WDYT @squito ?
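
A rough sketch of that flag idea (the name taskSucceededInCurrentTaskSet is hypothetical; nothing like it exists in the code today):

  // Hypothetical flag, set in handleSuccessfulTask when a task from *this*
  // task set finishes and therefore records a duration.
  private var taskSucceededInCurrentTaskSet = false

  // checkSpeculatableTasks could then guard the median lookup on it:
  if (taskSucceededInCurrentTaskSet &&
      tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
    val medianDuration = successfulTaskDurations.median
    // ...
  }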

Contributor:
Yeah, that is sort of what I was suggesting -- but I was thinking that rather than just a flag, maybe we separate tasksSuccessful into tasksCompletedSuccessfully (from this taskset) and tasksNoLongerNecessary (from any taskset), perhaps with better names. If you just had a flag, you would avoid the exception from the empty heap, but you might still decide to enable speculation prematurely, since you really haven't finished enough tasks for SPECULATION_QUANTILE:

if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
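
A sketch of that split, using the names proposed above (hypothetical -- neither counter exists in the code):

  // Counts only tasks that finished in *this* task set; each of these
  // inserted a duration into successfulTaskDurations.
  private var tasksCompletedSuccessfully = 0
  // Counts partitions finished by *any* attempt of the stage; used to decide
  // when this task set no longer needs to run anything.
  private var tasksNoLongerNecessary = 0

  // Speculation then keys off completions this task set actually observed,
  // so the heap cannot be empty here and SPECULATION_QUANTILE is measured
  // against real durations:
  if (tasksCompletedSuccessfully >= minFinishedForSpeculation &&
      tasksCompletedSuccessfully > 0) {
    val medianDuration = successfulTaskDurations.median
    // ...
  }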

Contributor:
speculationEnabled && !isZombie
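
In context, that condition would make the new branch look roughly like this (a sketch of the review suggestion, not necessarily the merged diff):

        // A zombie task set never launches speculative tasks
        // (checkSpeculatableTasks bails out for zombies), so skip the
        // duration bookkeeping for it entirely.
        if (speculationEnabled && !isZombie) {
          taskAttempts(index).headOption.map { info =>
            info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
            successfulTaskDurations.insert(info.duration)
          }
        }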

          taskAttempts(index).headOption.map { info =>
            info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
            successfulTaskDurations.insert(info.duration)
Contributor:
What's the normal code path to update task durations?

Contributor (author):
TaskSetManager#handleSuccessfulTask updates successful task durations and writes them to successfulTaskDurations.

When there are multiple task sets for this stage, markPartitionCompletedInAllTaskSets accumulates the value of tasksSuccessful in each of them.

In this case, when checkSpeculatableTasks is called, the value of tasksSuccessful satisfies the condition, but successfulTaskDurations is empty.

https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L723

  def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = {
    val info = taskInfos(tid)
    val index = info.index
    info.markFinished(TaskState.FINISHED, clock.getTimeMillis())
    if (speculationEnabled) {
      successfulTaskDurations.insert(info.duration)
    }
    // ...
    // There may be multiple tasksets for this stage -- we let all of them know that the partition
    // was completed.  This may result in some of the tasksets getting completed.
    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId)

https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L987

  override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = {
    // ...
    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
      val time = clock.getTimeMillis()
      val medianDuration = successfulTaskDurations.median
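
To see the failure concretely: successfulTaskDurations is a MedianHeap, and its median has nothing to return when the heap is empty (the reported stack trace is a java.util.NoSuchElementException: "MedianHeap is empty."). A minimal reproduction of just that step, assuming access to Spark's private[spark] MedianHeap:

  import org.apache.spark.util.collection.MedianHeap

  val durations = new MedianHeap()
  // tasksSuccessful was bumped by completions from another task set, so the
  // quantile check above passes, but nothing was ever inserted here...
  assert(durations.isEmpty())
  // ...and this call throws java.util.NoSuchElementException: MedianHeap is empty.
  durations.median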

          }
        }
        tasksSuccessful += 1
        successful(index) = true
        if (tasksSuccessful == numTasks) {
@@ -1365,6 +1365,55 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
assert(taskOption4.get.addedJars === addedJarsMidTaskSet)
}

test("[SPARK-24677] MedianHeap should not be empty when speculation is enabled") {
val conf = new SparkConf().set("spark.speculation", "true")
sc = new SparkContext("local", "test", conf)
// Set the speculation multiplier to be 0 so speculative tasks are launched immediately
sc.conf.set("spark.speculation.multiplier", "0.0")
sc.conf.set("spark.speculation.quantile", "0.1")
sc.conf.set("spark.speculation", "true")

sched = new FakeTaskScheduler(sc)
sched.initialize(new FakeSchedulerBackend())

val dagScheduler = new FakeDAGScheduler(sc, sched)
sched.setDAGScheduler(dagScheduler)

val taskSet1 = FakeTask.createTaskSet(10)
val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task =>
task.metrics.internalAccums
}

sched.submitTasks(taskSet1)
sched.resourceOffers(
(0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })

val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get

// fail fetch
taskSetManager1.handleFailedTask(
taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED,
FetchFailed(null, 0, 0, 0, "fetch failed"))

assert(taskSetManager1.isZombie)
assert(taskSetManager1.runningTasks === 9)

val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1)
sched.submitTasks(taskSet2)
sched.resourceOffers(
(11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })

// Complete the 2 tasks and leave 8 task in running
for (id <- Set(0, 1)) {
taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
assert(sched.endedTasks(id) === Success)
}

val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get
assert(!taskSetManager2.successfulTaskDurations.isEmpty())
taskSetManager2.checkSpeculatableTasks(0)
}

  private def createTaskResult(
      id: Int,
      accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {