@@ -697,9 +697,12 @@ private[spark] class TaskSchedulerImpl(
    * do not also submit those same tasks. That also means that a task completion from an earlier
    * attempt can lead to the entire stage getting marked as successful.
    */
-  private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = {
+  private[scheduler] def markPartitionCompletedInAllTaskSets(
+      stageId: Int,
+      partitionId: Int,
+      taskInfo: TaskInfo) = {
     taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId)
+      tsm.markPartitionCompleted(partitionId, taskInfo)
     }
   }

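Note: the extra taskInfo parameter lets every TaskSetManager registered for the stage learn not only that the partition finished, but also how long the winning task ran. Below is a minimal, self-contained Scala sketch of this propagation pattern; the names (TaskInfoLike, ManagerLike, PropagationSketch) are illustrative stand-ins and not Spark's real classes.

case class TaskInfoLike(duration: Long)  // illustrative stand-in for org.apache.spark.scheduler.TaskInfo

class ManagerLike {
  // Tracks which partitions this (possibly zombie) attempt already considers done.
  private val completed = scala.collection.mutable.Set[Int]()
  val durations = scala.collection.mutable.ArrayBuffer[Long]()

  def markPartitionCompleted(partitionId: Int, info: TaskInfoLike): Unit = {
    if (completed.add(partitionId)) {
      durations += info.duration  // the new argument lets this attempt record the runtime too
    }
  }
}

object PropagationSketch extends App {
  // Two attempts (task set managers) registered for the same stage.
  val managersByStage = Map(0 -> Seq(new ManagerLike, new ManagerLike))

  def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int, info: TaskInfoLike): Unit =
    managersByStage.getOrElse(stageId, Seq.empty).foreach(_.markPartitionCompleted(partitionId, info))

  // A task finishing in one attempt is broadcast to every attempt of the stage.
  markPartitionCompletedInAllTaskSets(stageId = 0, partitionId = 3, info = TaskInfoLike(duration = 1200L))
  println(managersByStage(0).map(_.durations))  // both attempts now know the 1200 ms duration
}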
@@ -758,7 +758,7 @@ private[spark] class TaskSetManager(
     }
     // There may be multiple tasksets for this stage -- we let all of them know that the partition
     // was completed. This may result in some of the tasksets getting completed.
-    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId)
+    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
     // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
     // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
     // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -769,9 +769,12 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }

-  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
+        if (speculationEnabled && !isZombie) {
+          successfulTaskDurations.insert(taskInfo.duration)
+        }
         tasksSuccessful += 1
         successful(index) = true
         if (tasksSuccessful == numTasks) {
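Note: this hunk is the core of the fix. Before it, a partition completed by another attempt bumped tasksSuccessful without ever inserting a duration into successfulTaskDurations, so checkSpeculatableTasks could ask an empty MedianHeap for its median and hit a NoSuchElementException. The hedged sketch below uses a simplified stand-in (MedianHeapLike, not Spark's internal MedianHeap) to illustrate the failure mode and why recording the duration here avoids it.

import scala.collection.mutable

// Simplified stand-in for Spark's internal MedianHeap: it throws when asked for the
// median of zero elements, which is the NoSuchElementException behind SPARK-24677.
class MedianHeapLike {
  private val values = mutable.ArrayBuffer[Double]()
  def insert(v: Double): Unit = values += v
  def isEmpty: Boolean = values.isEmpty
  def median: Double = {
    if (values.isEmpty) throw new NoSuchElementException("MedianHeap is empty")
    val sorted = values.sorted
    sorted(sorted.size / 2)
  }
}

object SpeculationSketch extends App {
  val durations = new MedianHeapLike
  val tasksSuccessful = 2  // partitions finished on behalf of another attempt

  // Before the fix: those completions inserted nothing, so taking the median here would throw.
  // After the fix: each cross-attempt completion also records the finished task's duration.
  Seq(900.0, 1100.0).foreach(durations.insert)
  println(s"$tasksSuccessful tasks successful, median duration = ${durations.median} ms")
}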
@@ -1365,6 +1365,55 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging {
     assert(taskOption4.get.addedJars === addedJarsMidTaskSet)
   }

+  test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") {
+    val conf = new SparkConf().set("spark.speculation", "true")
+    sc = new SparkContext("local", "test", conf)
+    // Set the speculation multiplier to be 0 so speculative tasks are launched immediately
+    sc.conf.set("spark.speculation.multiplier", "0.0")
+    sc.conf.set("spark.speculation.quantile", "0.1")
+    sc.conf.set("spark.speculation", "true")
+
+    sched = new FakeTaskScheduler(sc)
+    sched.initialize(new FakeSchedulerBackend())
+
+    val dagScheduler = new FakeDAGScheduler(sc, sched)
+    sched.setDAGScheduler(dagScheduler)
+
+    val taskSet1 = FakeTask.createTaskSet(10)
+    val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task =>
+      task.metrics.internalAccums
+    }
+
+    sched.submitTasks(taskSet1)
+    sched.resourceOffers(
+      (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })
+
+    val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get
+
+    // fail fetch
+    taskSetManager1.handleFailedTask(
+      taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED,
+      FetchFailed(null, 0, 0, 0, "fetch failed"))
+
+    assert(taskSetManager1.isZombie)
+    assert(taskSetManager1.runningTasks === 9)
+
+    val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1)
+    sched.submitTasks(taskSet2)
+    sched.resourceOffers(
+      (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })
+
+    // Complete 2 tasks and leave 8 tasks running
+    for (id <- Set(0, 1)) {
+      taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
+      assert(sched.endedTasks(id) === Success)
+    }
+
+    val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get
+    assert(!taskSetManager2.successfulTaskDurations.isEmpty())
+    taskSetManager2.checkSpeculatableTasks(0)
+  }
+
   private def createTaskResult(
       id: Int,
       accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {
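Note: the new test reproduces the scenario end to end: attempt 0 turns zombie after a fetch failure, attempt 1 is resubmitted, attempt 0 then finishes two tasks on attempt 1's behalf, and the assertion on successfulTaskDurations plus the call to checkSpeculatableTasks(0) verifies the speculation path no longer sees an empty heap. The sketch below models, with illustrative names and a simplified formula (not copied from TaskSetManager), the threshold computation speculation performs once enough tasks have succeeded.

object SpeculationThresholdSketch extends App {
  // Values taken from the test above; the formula is a simplification of what
  // TaskSetManager.checkSpeculatableTasks does, not a copy of it.
  val speculationQuantile = 0.1     // spark.speculation.quantile
  val speculationMultiplier = 0.0   // spark.speculation.multiplier
  val numTasks = 10
  val tasksSuccessful = 2
  val successfulDurations = Seq(900L, 1100L)  // non-empty only because of the fix above

  val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt
  if (tasksSuccessful >= minFinishedForSpeculation && successfulDurations.nonEmpty) {
    val medianDuration = successfulDurations.sorted.apply(successfulDurations.size / 2)
    val threshold = speculationMultiplier * medianDuration
    println(s"speculate on running tasks that have exceeded $threshold ms")
  }
  // Before SPARK-24677, tasksSuccessful could reach the quantile threshold while the
  // duration heap stayed empty, and asking for its median is what blew up.
}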