diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 71a219a4f341..85b6893245a2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -855,7 +855,7 @@ class DAGScheduler(
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
-    stage.pendingTasks.clear()
+    stage.pendingPartitions.clear()
 
     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = {
@@ -938,8 +938,8 @@ class DAGScheduler(
 
     if (tasks.size > 0) {
       logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
-      stage.pendingTasks ++= tasks
-      logDebug("New pending tasks: " + stage.pendingTasks)
+      stage.pendingPartitions ++= tasks.map(_.partitionId)
+      logDebug("New pending partitions: " + stage.pendingPartitions)
       taskScheduler.submitTasks(new TaskSet(
         tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties))
       stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
@@ -1027,7 +1027,7 @@ class DAGScheduler(
       case Success =>
         listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
           event.reason, event.taskInfo, event.taskMetrics))
-        stage.pendingTasks -= task
+        stage.pendingPartitions -= task.partitionId
         task match {
           case rt: ResultTask[_, _] =>
             // Cast to ResultStage here because it's part of the ResultTask
@@ -1073,7 +1073,7 @@ class DAGScheduler(
               shuffleStage.addOutputLoc(smt.partitionId, status)
             }
 
-            if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
+            if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
              markStageAsFinished(shuffleStage)
              logInfo("looking for newly runnable stages")
              logInfo("running: " + runningStages)
@@ -1126,7 +1126,7 @@ class DAGScheduler(
 
       case Resubmitted =>
         logInfo("Resubmitted " + task + ", so marking it as still running")
-        stage.pendingTasks += task
+        stage.pendingPartitions += task.partitionId
 
       case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index b86724de2cb7..f1c37fd82cfc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -60,7 +60,7 @@ private[spark] abstract class Stage(
   /** Set of jobs that this stage belongs to. */
   val jobIds = new HashSet[Int]
 
-  var pendingTasks = new HashSet[Task[_]]
+  var pendingPartitions = new HashSet[Int]
 
   /** The ID to use for the next new attempt for this stage. */
   private var nextAttemptId: Int = 0
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 82455b0426a5..5b79cdaf94b0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -487,8 +487,8 @@ private[spark] class TaskSetManager(
         // a good proxy to task serialization time.
         // val timeTaken = clock.getTime() - startTime
         val taskName = s"task ${info.id} in stage ${taskSet.id}"
-        logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format(
-          taskName, taskId, host, taskLocality, serializedTask.limit))
+        logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," +
+          s" $taskLocality, ${serializedTask.limit} bytes)")
         sched.dagScheduler.taskStarted(task, info)
         return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 86728cb2b62a..f389ce869dec 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import org.apache.spark.shuffle.MetadataFetchFailedException
+
 import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map}
 import scala.language.reflectiveCalls
 import scala.util.control.NonFatal
@@ -699,6 +701,7 @@ class DAGSchedulerSuite
     runEvent(ExecutorLost("exec-hostA"))
     val newEpoch = mapOutputTracker.getEpoch
     assert(newEpoch > oldEpoch)
+    val taskSet = taskSets(0)
     // should be ignored for being too old
     runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA",
@@ -739,6 +742,88 @@ class DAGSchedulerSuite
     assertDataStructuresEmpty()
   }
 
+  test("don't submit the next stage until the parent's map outputs are registered") {
+    val firstRDD = new MyRDD(sc, 3, Nil)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+    val firstShuffleId = firstShuffleDep.shuffleId
+    val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    submit(reduceRdd, Array(0))
+
+    // things start out smoothly, stage 0 completes with no issues
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.size)),
+      (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.size)),
+      (Success, makeMapStatus("hostA", shuffleMapRdd.partitions.size))
+    ))
+
+    // then one executor dies, and a task fails in stage 1
+    runEvent(ExecutorLost("exec-hostA"))
+    runEvent(CompletionEvent(taskSets(1).tasks(0),
+      FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"),
+      null, null, createFakeTaskInfo(), null))
+
+    // so we resubmit stage 0, which completes happily
+    Thread.sleep(1000)
+    val stage0Resubmit = taskSets(2)
+    assert(stage0Resubmit.stageId == 0)
+    assert(stage0Resubmit.stageAttemptId === 1)
+    val task = stage0Resubmit.tasks(0)
+    assert(task.partitionId === 2)
+    runEvent(CompletionEvent(task, Success,
+      makeMapStatus("hostC", shuffleMapRdd.partitions.size), null, createFakeTaskInfo(), null))
+
+    // now here is where things get tricky: we will now have a task set representing
+    // the second attempt for stage 1, but we *also* have some tasks for the first attempt for
+    // stage 1 still going
+    val stage1Resubmit = taskSets(3)
+    assert(stage1Resubmit.stageId == 1)
+    assert(stage1Resubmit.stageAttemptId === 1)
+    assert(stage1Resubmit.tasks.length === 3)
+
+    // we'll have some tasks finish from the first attempt, and some finish from the second
+    // attempt, so that we actually have all stage outputs, though no attempt has completed all
+    // its tasks
+    runEvent(CompletionEvent(taskSets(3).tasks(0), Success,
+      makeMapStatus("hostC", reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
+    runEvent(CompletionEvent(taskSets(3).tasks(1), Success,
+      makeMapStatus("hostC", reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
+    // late task finish from the first attempt
+    runEvent(CompletionEvent(taskSets(1).tasks(2), Success,
+      makeMapStatus("hostB", reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
+
+    // What should happen now is that we submit stage 2. However, we might not see an error
+    // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But
+    // we can check some conditions.
+    // Note that the really important thing here is not so much that we submit stage 2
+    // *immediately*, but that we don't end up with some error from these interleaved completions.
+    // It would also be OK (though sub-optimal) if stage 2 simply waited until the resubmission of
+    // stage 1 had all its tasks complete.
+
+    // check that we have all the map output for stage 0 (it should have been there even before
+    // the last round of completions from stage 1, but just to double check it hasn't been messed
+    // up)
+    (0 until 3).foreach { reduceIdx =>
+      val arr = mapOutputTracker.getServerStatuses(0, reduceIdx)
+      assert(arr != null)
+      assert(arr.nonEmpty)
+    }
+
+    // and check we have all the map output for stage 1
+    (0 until 1).foreach { reduceIdx =>
+      val arr = mapOutputTracker.getServerStatuses(1, reduceIdx)
+      assert(arr != null)
+      assert(arr.nonEmpty)
+    }
+
+    // and check that stage 2 has been submitted
+    assert(taskSets.size == 5)
+    val stage2TaskSet = taskSets(4)
+    assert(stage2TaskSet.stageId == 2)
+    assert(stage2TaskSet.stageAttemptId == 0)
+  }
+
   /**
    * Makes sure that failures of stage used by multiple jobs are correctly handled.
    *
@@ -749,7 +834,7 @@ class DAGSchedulerSuite
    *   |    \        |
    *   |     \       |
    *   |      \      |
    *   reduceRdd1  reduceRdd2
    *
    * We start both shuffleMapRdds and then fail shuffleMapRdd1. As a result, the job listeners for
    * reduceRdd1 and reduceRdd2 should both be informed that the job failed. shuffleMapRDD2 should
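
Aside, not part of the patch: the behavioral core of the change above is that a stage's completion is tracked by partition id rather than by the identity of the Task objects in the latest attempt, so a late success from an earlier stage attempt still counts toward the attempt that is currently running. The self-contained Scala sketch below illustrates only that bookkeeping idea; the names (StageState, onTaskSuccess, PendingPartitionsSketch) are invented for illustration and are not Spark's scheduler classes.

object PendingPartitionsSketch {
  import scala.collection.mutable

  // A stage over N partitions is finished once every partition id has produced output,
  // regardless of which task-set attempt produced it.
  final case class StageState(numPartitions: Int, pending: mutable.Set[Int])

  def newStage(numPartitions: Int): StageState =
    StageState(numPartitions, mutable.Set((0 until numPartitions): _*))

  // Record any successful task completion, including a late one from an older attempt.
  def onTaskSuccess(stage: StageState, partitionId: Int): Unit =
    stage.pending -= partitionId

  def isFinished(stage: StageState): Boolean = stage.pending.isEmpty

  def main(args: Array[String]): Unit = {
    val stage = newStage(3)
    onTaskSuccess(stage, 0)   // partition 0 finished by attempt 1
    onTaskSuccess(stage, 1)   // partition 1 finished by attempt 1
    onTaskSuccess(stage, 2)   // partition 2 finished by a late task from attempt 0
    assert(isFinished(stage)) // stage completes even though no single attempt ran every task
  }
}

Before the patch, pendingTasks stored the Task objects of the most recent submission, so a completion event carrying a task object from an older attempt could not clear the corresponding entry; keying on partition ids is what lets the interleaved completions in the new test finish the stage.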