Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.io.{NotSerializableException, PrintWriter, StringWriter}
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
Expand Down Expand Up @@ -195,11 +195,15 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
// We are going to register ancestor shuffle dependencies
registerShuffleDependencies(shuffleDep, jobId)
// Then register current shuffleDep
val stage =
newOrUsedStage(
shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId,
shuffleDep.rdd.creationSite)
shuffleToMapStage(shuffleDep.shuffleId) = stage

stage
}
}
Expand Down Expand Up @@ -265,6 +269,9 @@ class DAGScheduler(
private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment that we are manually maintaining a stack to prevent StackOverflowError

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Same on the other methods that use a stack)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I've added a commit for that.

def visit(r: RDD[_]) {
if (!visited(r)) {
visited += r
Expand All @@ -275,18 +282,69 @@ class DAGScheduler(
case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, jobId)
case _ =>
visit(dep.rdd)
waitingForVisit.push(dep.rdd)
}
}
}
}
visit(rdd)
waitingForVisit.push(rdd)
while (!waitingForVisit.isEmpty) {
visit(waitingForVisit.pop())
}
parents.toList
}

/**
 * Creates a shuffle map stage for each ancestor shuffle dependency of `shuffleDep` that is
 * not yet present in `shuffleToMapStage`, and records it there under its shuffle id.
 *
 * Only the dependencies reachable from `shuffleDep.rdd` are handled; `shuffleDep` itself is
 * not registered by this method.
 *
 * @param shuffleDep the shuffle dependency whose ancestors should be registered
 * @param jobId the job id passed through to `newOrUsedStage` for every stage created here
 */
private def registerShuffleDependencies(
    shuffleDep: ShuffleDependency[_, _, _],
    jobId: Int): Unit = {
  val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd)
  while (!parentsWithNoMapStage.isEmpty) {
    val currentShufDep = parentsWithNoMapStage.pop()
    // The candidates were collected once, before any stage was created. Re-check the map
    // for each one so that an entry registered in the meantime is never overwritten with a
    // second, duplicate stage for the same shuffle id.
    // NOTE(review): newOrUsedStage is not visible here; presumably it resolves parent stages
    // through getShuffleMapStage, which registers dependencies as a side effect - confirm.
    if (!shuffleToMapStage.contains(currentShufDep.shuffleId)) {
      val stage =
        newOrUsedStage(
          currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId,
          currentShufDep.rdd.creationSite)
      shuffleToMapStage(currentShufDep.shuffleId) = stage
    }
  }
}

// Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet
private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
val parents = new Stack[ShuffleDependency[_, _, _]]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
def visit(r: RDD[_]) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why define a function here? It seems to be used only once. Why not just inline it in the while loop?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left the code as the function getParentShuffleDependencies because it contains multiple levels of indentation, so putting it under the case statement would not be readable. I can inline it if this is an issue.

if (!visited(r)) {
visited += r
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
parents.push(shufDep)
}

waitingForVisit.push(shufDep.rdd)
case _ =>
waitingForVisit.push(dep.rdd)
}
}
}
}

waitingForVisit.push(rdd)
while (!waitingForVisit.isEmpty) {
visit(waitingForVisit.pop())
}
parents
}

private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
Expand All @@ -299,13 +357,16 @@ class DAGScheduler(
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
visit(narrowDep.rdd)
waitingForVisit.push(narrowDep.rdd)
}
}
}
}
}
visit(stage.rdd)
waitingForVisit.push(stage.rdd)
while (!waitingForVisit.isEmpty) {
visit(waitingForVisit.pop())
}
missing.toList
}

Expand Down Expand Up @@ -1099,6 +1160,9 @@ class DAGScheduler(
}
val visitedRdds = new HashSet[RDD[_]]
val visitedStages = new HashSet[Stage]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visitedRdds(rdd)) {
visitedRdds += rdd
Expand All @@ -1108,15 +1172,18 @@ class DAGScheduler(
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
visitedStages += mapStage
visit(mapStage.rdd)
waitingForVisit.push(mapStage.rdd)
} // Otherwise there's no need to follow the dependency back
case narrowDep: NarrowDependency[_] =>
visit(narrowDep.rdd)
waitingForVisit.push(narrowDep.rdd)
}
}
}
}
visit(stage.rdd)
waitingForVisit.push(stage.rdd)
while (!waitingForVisit.isEmpty) {
visit(waitingForVisit.pop())
}
visitedRdds.contains(target.rdd)
}

Expand Down