Skip to content
Prev Previous commit
Next Next commit
Cleaned up the code to create new shuffle map stages
  • Loading branch information
kayousterhout committed Jun 15, 2016
commit 1468b91cc3ca493bc18e16525fe39444616f2bb2
62 changes: 22 additions & 40 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -287,46 +287,16 @@ class DAGScheduler(
// We are going to register ancestor shuffle dependencies
getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
if (!shuffleIdToMapStage.contains(dep.shuffleId)) {
shuffleIdToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
shuffleIdToMapStage(dep.shuffleId) = newShuffleMapStage(dep, firstJobId)
}
}
// Then register current shuffleDep
val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
val stage = newShuffleMapStage(shuffleDep, firstJobId)
shuffleIdToMapStage(shuffleDep.shuffleId) = stage
stage
}
}

/**
 * Shared helper for stage construction: looks up the parent stages of `rdd` and then
 * reserves the next stage id, so that callers do not have to duplicate these two steps.
 *
 * The parent stages are computed before the id is taken from `nextStageId`.
 *
 * @param rdd the RDD whose parent stages are requested
 * @param firstJobId the job id forwarded to `getParentStages`
 * @return a pair of (parent stages, newly reserved stage id)
 */
private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = {
  val parents = getParentStages(rdd, firstJobId)
  // Reserve the id only after the parents have been resolved (same order as before).
  val stageId = nextStageId.getAndIncrement()
  parents -> stageId
}

/**
 * Builds a ShuffleMapStage and registers it with the scheduler's bookkeeping maps. This is
 * the low-level step used by newOrUsedShuffleStage when it (re)-creates a shuffle map stage;
 * code that needs a shuffle map stage should always go through newOrUsedShuffleStage rather
 * than calling this method directly.
 *
 * @param rdd the RDD the stage is built for
 * @param numTasks the number of tasks passed to the ShuffleMapStage constructor
 * @param shuffleDep the shuffle dependency passed to the ShuffleMapStage constructor
 * @param firstJobId the job id the new stage is associated with
 * @param callSite the call site recorded on the stage
 * @return the newly created and registered ShuffleMapStage
 */
private def newShuffleMapStage(
    rdd: RDD[_],
    numTasks: Int,
    shuffleDep: ShuffleDependency[_, _, _],
    firstJobId: Int,
    callSite: CallSite): ShuffleMapStage = {
  val (parents, stageId) = getParentStagesAndId(rdd, firstJobId)
  val mapStage = new ShuffleMapStage(
    stageId, rdd, numTasks, parents, firstJobId, callSite, shuffleDep)

  // Make the stage discoverable by id and tie it to the job.
  stageIdToStage(stageId) = mapStage
  updateJobIdStageIdMaps(firstJobId, mapStage)
  mapStage
}

/**
* Create a ResultStage associated with the provided jobId.
*/
Expand All @@ -336,26 +306,38 @@ class DAGScheduler(
partitions: Array[Int],
jobId: Int,
callSite: CallSite): ResultStage = {
val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId)
val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite)
val id = nextStageId.getAndIncrement()
val stage = new ResultStage(
id, rdd, func, partitions, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
updateJobIdStageIdMaps(jobId, stage)
stage
}

/**
* Create a shuffle map Stage for the given RDD. The stage will also be associated with the
* provided firstJobId. If a stage for the shuffleId existed previously so that the shuffleId is
* present in the MapOutputTracker, then the number and location of available outputs are
* recovered from the MapOutputTracker.
* Creates a shuffle map stage to generate the given ShuffleDependency. If the given
* ShuffleDependency has already been generated by a past stage, the new shuffle map
* stage will copy output locations from the previous stage, so that tasks won't be launched
* to generate data that already exists (the MapOutputTracker is used to determine what
* previously-generated data is still available).
*
* The newly-created stage will be associated with the provided firstJobId.
*/
private def newOrUsedShuffleStage(
private def newShuffleMapStage(
shuffleDep: ShuffleDependency[_, _, _],
firstJobId: Int): ShuffleMapStage = {
val rdd = shuffleDep.rdd
val numTasks = rdd.partitions.length
val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite)
val id = nextStageId.getAndIncrement()
val stage = new ShuffleMapStage(
id, rdd, numTasks, getParentStages(rdd, firstJobId), firstJobId, rdd.creationSite, shuffleDep)

stageIdToStage(id) = stage
updateJobIdStageIdMaps(firstJobId, stage)
if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
// A previous stage generated partitions for this shuffle, so for each output that's still
// available, copy information about that output location to the new stage (so we don't
// unnecessarily re-compute that data).
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
(0 until locs.length).foreach { i =>
Expand Down