core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala (152 changes: 71 additions & 81 deletions)
@@ -141,7 +141,13 @@ class DAGScheduler(

private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
private[scheduler] val shuffleToMapStage = new HashMap[Int, ShuffleMapStage]
/**
* Mapping from shuffle dependency ID to the ShuffleMapStage that will generate the data for
* that dependency. Only includes stages that are part of a currently running job (when the job(s)
* that require the shuffle stage complete, the mapping will be removed, and the only record of
* the shuffle data will be in the MapOutputTracker).
*/
private[scheduler] val shuffleIdToMapStage = new HashMap[Int, ShuffleMapStage]
private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]

// Stages we need to run whose parents aren't done
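
For illustration only (not part of the patch): a minimal sketch of the lifecycle the new comment describes. An entry lives in shuffleIdToMapStage only while some job still needs the stage; once the last such job finishes, the entry is dropped and the MapOutputTracker remains the only record of the shuffle output. The names registerShuffle and lastJobFinished are hypothetical stand-ins, not DAGScheduler methods.

import scala.collection.mutable

object ShuffleIdToMapStageLifecycleSketch {
  // Stage names stand in for ShuffleMapStage objects.
  val shuffleIdToMapStage = mutable.HashMap[Int, String]()
  // Shuffle ids with registered map outputs, standing in for the MapOutputTracker.
  val trackedShuffles = mutable.Set[Int]()

  def registerShuffle(shuffleId: Int): Unit = {
    shuffleIdToMapStage(shuffleId) = s"ShuffleMapStage($shuffleId)"
    trackedShuffles += shuffleId
  }

  def lastJobFinished(shuffleId: Int): Unit = {
    shuffleIdToMapStage.remove(shuffleId)  // the mapping goes away with the job...
    assert(trackedShuffles(shuffleId))     // ...but the shuffle output is still tracked
  }
}
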
@@ -276,86 +282,55 @@ class DAGScheduler(
}

/**
* Get or create a shuffle map stage for the given shuffle dependency's map side.
* Gets a shuffle map stage if one exists in shuffleIdToMapStage. Otherwise, creates the
* shuffle map stage, along with any missing ancestor shuffle map stages.
*/
private def getShuffleMapStage(
private def getOrCreateShuffleMapStage(
shuffleDep: ShuffleDependency[_, _, _],
firstJobId: Int): ShuffleMapStage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
shuffleIdToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) =>
stage

case None =>
// We are going to register ancestor shuffle dependencies
getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
if (!shuffleToMapStage.contains(dep.shuffleId)) {
shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
// Create stages for all missing ancestor shuffle dependencies.
getMissingAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
// Even though getMissingAncestorShuffleDependencies only returns shuffle dependencies
// that were not already in shuffleIdToMapStage, it's possible that by the time we
// get to a particular dependency in the foreach loop, it's been added to
// shuffleIdToMapStage by the stage creation process for an earlier dependency. See
// SPARK-13902 for more information.
if (!shuffleIdToMapStage.contains(dep.shuffleId)) {
createShuffleMapStage(dep, firstJobId)
}
}
// Then register current shuffleDep
val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
// Finally, create a stage for the given shuffle dependency.
createShuffleMapStage(shuffleDep, firstJobId)
}
}
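
For illustration only (not part of the patch): a toy model of the SPARK-13902 situation that the contains() re-check above guards against. Creating a stage for one missing ancestor transitively creates stages for that ancestor's own ancestors, so a dependency returned by getMissingAncestorShuffleDependencies may already be registered by the time the loop reaches it.

import scala.collection.mutable

object MissingAncestorSketch {
  // shuffleId -> parent shuffleIds, a stand-in for the RDD lineage.
  val parents = Map(2 -> Seq(1), 1 -> Seq.empty[Int])
  val registered = mutable.HashMap[Int, String]()

  def createStage(id: Int): Unit = {
    parents(id).foreach(createStage)       // ancestors are created first
    if (!registered.contains(id)) {
      registered(id) = s"stage-$id"
    }
  }

  def main(args: Array[String]): Unit = {
    val missing = Seq(2, 1)                // both were missing when this list was computed
    missing.foreach { id =>
      if (!registered.contains(id)) {      // 1 was already registered while creating 2
        createStage(id)
      }
    }
    println(registered.toMap)              // both stages exist, each created once
  }
}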

/**
* Helper function to eliminate some code re-use when creating new stages.
* Creates a ShuffleMapStage that generates the given shuffle dependency's partitions. If a
* previously run stage generated the same shuffle data, this function will copy the output
* locations that are still available from the previous shuffle to avoid unnecessarily
* regenerating data.
*/
private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = {
val parentStages = getParentStages(rdd, firstJobId)
def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = {
val rdd = shuffleDep.rdd
val numTasks = rdd.partitions.length
val parents = getOrCreateParentStages(rdd, jobId)
val id = nextStageId.getAndIncrement()
(parentStages, id)
}

/**
* Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in
* newOrUsedShuffleStage. The stage will be associated with the provided firstJobId.
* Production of shuffle map stages should always use newOrUsedShuffleStage, not
* newShuffleMapStage directly.
*/
private def newShuffleMapStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: ShuffleDependency[_, _, _],
firstJobId: Int,
callSite: CallSite): ShuffleMapStage = {
val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId)
val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages,
firstJobId, callSite, shuffleDep)

stageIdToStage(id) = stage
updateJobIdStageIdMaps(firstJobId, stage)
stage
}
val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep)

/**
* Create a ResultStage associated with the provided jobId.
*/
private def newResultStage(
rdd: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
jobId: Int,
callSite: CallSite): ResultStage = {
val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId)
val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite)
stageIdToStage(id) = stage
shuffleIdToMapStage(shuffleDep.shuffleId) = stage
updateJobIdStageIdMaps(jobId, stage)
stage
}

/**
* Create a shuffle map Stage for the given RDD. The stage will also be associated with the
* provided firstJobId. If a stage for the shuffleId existed previously so that the shuffleId is
* present in the MapOutputTracker, then the number and location of available outputs are
* recovered from the MapOutputTracker.
*/
private def newOrUsedShuffleStage(
shuffleDep: ShuffleDependency[_, _, _],
firstJobId: Int): ShuffleMapStage = {
val rdd = shuffleDep.rdd
val numTasks = rdd.partitions.length
val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite)
if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
// A previously run stage generated partitions for this shuffle, so for each output
// that's still available, copy information about that output location to the new stage
// (so we don't unnecessarily re-compute that data).
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
(0 until locs.length).foreach { i =>
@@ -373,18 +348,36 @@ class DAGScheduler(
stage
}
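
For illustration only (not part of the patch): a toy version of the recovery step above. When the MapOutputTracker already knows some output locations for the shuffle, the rebuilt stage starts with those partitions marked as done and only the lost ones are recomputed. Option[String] stands in for MapStatus here.

object ReuseShuffleOutputsSketch {
  def main(args: Array[String]): Unit = {
    // Statuses recovered from a previous run of the same shuffle; partition 1 was lost.
    val recovered: Array[Option[String]] = Array(Some("exec-1"), None, Some("exec-3"))

    val outputLocs = Array.fill[Option[String]](recovered.length)(None)
    for (i <- recovered.indices; loc <- recovered(i)) {
      outputLocs(i) = Some(loc)            // copy each output that is still available
    }

    val toRecompute = outputLocs.indices.filter(outputLocs(_).isEmpty)
    println(s"partitions to recompute: ${toRecompute.mkString(",")}")  // prints: 1
  }
}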

/**
* Create a ResultStage associated with the provided jobId.
*/
private def createResultStage(
rdd: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
jobId: Int,
callSite: CallSite): ResultStage = {
val parents = getOrCreateParentStages(rdd, jobId)
val id = nextStageId.getAndIncrement()
val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)
stage
}

/**
* Get or create the list of parent stages for a given RDD. The new Stages will be created with
* the provided firstJobId.
*/
private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
private def getOrCreateParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
getShuffleDependencies(rdd).map { shuffleDep =>
getShuffleMapStage(shuffleDep, firstJobId)
getOrCreateShuffleMapStage(shuffleDep, firstJobId)
}.toList
}
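
For illustration only (not part of the patch): getOrCreateParentStages only considers the nearest shuffle boundary above the RDD. getShuffleDependencies (not shown in this diff) is assumed here to walk narrow dependencies and stop at the first shuffle dependency on each branch, so a shuffle hidden behind another shuffle does not become a direct parent stage.

import scala.collection.mutable

object NearestShuffleBoundarySketch {
  sealed trait Dep
  case class Narrow(parent: Node) extends Dep
  case class Shuffle(id: Int, parent: Node) extends Dep
  case class Node(name: String, deps: List[Dep])

  // Collect the shuffle ids closest to `rdd`, never looking past a shuffle.
  def directShuffleIds(rdd: Node): Set[Int] = {
    val found = mutable.Set[Int]()
    val toVisit = mutable.Stack(rdd)
    while (toVisit.nonEmpty) {
      toVisit.pop().deps.foreach {
        case Shuffle(id, _) => found += id       // stage boundary: stop here
        case Narrow(parent) => toVisit.push(parent)
      }
    }
    found.toSet
  }

  def main(args: Array[String]): Unit = {
    val a = Node("A", Nil)
    val b = Node("B", List(Shuffle(0, a)))       // shuffle 0 sits behind shuffle 1
    val c = Node("C", List(Shuffle(1, b)))
    val d = Node("D", List(Narrow(c)))
    println(directShuffleIds(d))                 // Set(1)
  }
}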

/** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
private def getMissingAncestorShuffleDependencies(
rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
val ancestors = new Stack[ShuffleDependency[_, _, _]]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
@@ -396,7 +389,7 @@
if (!visited(toVisit)) {
visited += toVisit
getShuffleDependencies(toVisit).foreach { shuffleDep =>
if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) {
if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) {
ancestors.push(shuffleDep)
waitingForVisit.push(shuffleDep.rdd)
} // Otherwise, the dependency and its ancestors have already been registered.
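
For illustration only (not part of the patch): the reason for the manually maintained Stack noted in the comment above. A lineage with tens of thousands of narrow dependencies (say, a long loop of map calls) would blow the JVM call stack if visited recursively; the iterative version keeps its pending work on the heap.

import scala.collection.mutable

object DeepLineageTraversalSketch {
  case class Node(id: Int, parent: Option[Node])

  def main(args: Array[String]): Unit = {
    // A chain of 100000 nodes, node(99999) -> ... -> node(0).
    val deepest = (1 until 100000).foldLeft(Node(0, None)) {
      (parent, i) => Node(i, Some(parent))
    }

    var visited = 0
    val waitingForVisit = new mutable.Stack[Node]
    waitingForVisit.push(deepest)
    while (waitingForVisit.nonEmpty) {
      val node = waitingForVisit.pop()
      visited += 1
      node.parent.foreach(p => waitingForVisit.push(p))
    }
    println(visited)  // 100000; a naive recursive walk would throw StackOverflowError
  }
}
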
@@ -453,7 +446,7 @@ class DAGScheduler(
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
if (!mapStage.isAvailable) {
missing += mapStage
}
@@ -482,8 +475,7 @@ class DAGScheduler(
val s = stages.head
s.jobIds += jobId
jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
val parents: List[Stage] = getParentStages(s.rdd, jobId)
val parentsWithoutThisJobId = parents.filter { ! _.jobIds.contains(jobId) }
val parentsWithoutThisJobId = s.parents.filter { ! _.jobIds.contains(jobId) }
updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
}
}
@@ -516,8 +508,8 @@ class DAGScheduler(
logDebug("Removing running stage %d".format(stageId))
runningStages -= stage
}
for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
shuffleToMapStage.remove(k)
for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) {
shuffleIdToMapStage.remove(k)
}
if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
@@ -843,7 +835,7 @@ class DAGScheduler(
try {
// New stage creation may throw an exception if, for example, jobs are run on a
// HadoopRDD whose underlying HDFS files have been deleted.
finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite)
finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
} catch {
case e: Exception =>
logWarning("Creating new stage failed due to exception - job: " + jobId, e)
@@ -881,7 +873,7 @@ class DAGScheduler(
try {
// New stage creation may throw an exception if, for example, jobs are run on a
// HadoopRDD whose underlying HDFS files have been deleted.
finalStage = getShuffleMapStage(dependency, jobId)
finalStage = getOrCreateShuffleMapStage(dependency, jobId)
} catch {
case e: Exception =>
logWarning("Creating new stage failed due to exception - job: " + jobId, e)
@@ -966,7 +958,6 @@ class DAGScheduler(
case s: ShuffleMapStage =>
partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
case s: ResultStage =>
val job = s.activeJob.get
partitionsToCompute.map { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
@@ -1028,7 +1019,6 @@ class DAGScheduler(
}

case stage: ResultStage =>
val job = stage.activeJob.get
partitionsToCompute.map { id =>
val p: Int = stage.partitions(id)
val part = stage.rdd.partitions(p)
@@ -1244,7 +1234,7 @@ class DAGScheduler(

case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
val mapStage = shuffleIdToMapStage(shuffleId)

if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
@@ -1334,14 +1324,14 @@ class DAGScheduler(

if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) {
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
for ((shuffleId, stage) <- shuffleIdToMapStage) {
stage.removeOutputsOnExecutor(execId)
mapOutputTracker.registerMapOutputs(
shuffleId,
stage.outputLocInMapOutputTrackerFormat(),
changeEpoch = true)
}
if (shuffleToMapStage.isEmpty) {
if (shuffleIdToMapStage.isEmpty) {
mapOutputTracker.incrementEpoch()
}
clearCacheLocs()
@@ -1496,7 +1486,7 @@ class DAGScheduler(
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.firstJobId)
val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
if (!mapStage.isAvailable) {
waitingForVisit.push(mapStage.rdd)
} // Otherwise there's no need to follow the dependency back
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -360,12 +360,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou

submit(rddD, Array(0))

assert(scheduler.shuffleToMapStage.size === 3)
assert(scheduler.shuffleIdToMapStage.size === 3)
assert(scheduler.activeJobs.size === 1)

val mapStageA = scheduler.shuffleToMapStage(s_A)
val mapStageB = scheduler.shuffleToMapStage(s_B)
val mapStageC = scheduler.shuffleToMapStage(s_C)
val mapStageA = scheduler.shuffleIdToMapStage(s_A)
val mapStageB = scheduler.shuffleIdToMapStage(s_B)
val mapStageC = scheduler.shuffleIdToMapStage(s_C)
val finalStage = scheduler.activeJobs.head.finalStage

assert(mapStageA.parents.isEmpty)
@@ -2072,7 +2072,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
assert(scheduler.jobIdToStageIds.isEmpty)
assert(scheduler.stageIdToStage.isEmpty)
assert(scheduler.runningStages.isEmpty)
assert(scheduler.shuffleToMapStage.isEmpty)
assert(scheduler.shuffleIdToMapStage.isEmpty)
assert(scheduler.waitingStages.isEmpty)
assert(scheduler.outputCommitCoordinator.isEmpty)
}