Hide details of getNarrowAncestors from outsiders
andrewor14 committed Apr 23, 2014
commit 9d0e2b8da6ceb25a726c1b7bd1a24c848e4d7945
24 changes: 14 additions & 10 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -241,17 +241,21 @@ abstract class RDD[T: ClassTag](
    * narrow dependencies. This traverses the given RDD's dependency tree using DFS, but maintains
    * no ordering on the RDDs returned.
    */
-  private[spark] def getNarrowAncestors(
-      ancestors: mutable.Set[RDD[_]] = mutable.Set.empty): mutable.Set[RDD[_]] = {
-    val narrowDependencies = dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
-    val narrowParents = narrowDependencies.map(_.rdd)
-    val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
-    narrowParentsNotVisited.foreach { parent =>
-      ancestors.add(parent)
-      parent.getNarrowAncestors(ancestors)
+  private[spark] def getNarrowAncestors: Seq[RDD[_]] = {
+    val ancestors = new mutable.HashSet[RDD[_]]
+
+    def visit(rdd: RDD[_]) {
+      val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
+      val narrowParents = narrowDependencies.map(_.rdd)
+      val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
+      narrowParentsNotVisited.foreach { parent =>
+        ancestors.add(parent)
+        visit(parent)
+      }
     }
-    // In case there is a cycle, do not include the root itself
-    ancestors.filterNot(_ == this)
+
+    visit(this)
+    ancestors.filterNot(_ == this).toSeq
   }
 
   /**
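The point of the change above: the old signature exposed the traversal's mutable accumulator as a default parameter, so any caller could pass in (and mutate) the method's internal state; the new version keeps the visited set local to a nested visit function and returns a plain Seq. A minimal before/after sketch from a caller's perspective (`rdd` and `someRdd` are hypothetical values, not part of the patch):

// Before: the accumulator leaked into the API, inviting misuse such as
//   rdd.getNarrowAncestors(mutable.Set(someRdd))  // silently skips someRdd's subtree
// After: traversal state is internal, and the result type no longer exposes it:
//   val lineage: Seq[RDD[_]] = rdd.getNarrowAncestors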
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -50,7 +50,7 @@ private[spark] object StageInfo {
    * sequence of narrow dependencies should also be associated with this Stage.
    */
   def fromStage(stage: Stage): StageInfo = {
-    val ancestorRddInfos = stage.rdd.getNarrowAncestors().map(RDDInfo.fromRdd)
+    val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
     val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos
     new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos)
   }
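The call sites here and in the test file below drop the empty parentheses because the method is now declared without a parameter list, and Scala convention reserves `()` for side-effecting methods; a parameterless `def` cannot be invoked with parentheses at all. A toy illustration (the class and members are hypothetical, not Spark code):

class Lineage {
  def size: Int = 42        // pure accessor: declared and called without parens
  def refresh(): Unit = {}  // side-effecting: declared and called with parens
}

val l = new Lineage
l.size        // compiles
// l.size()   // does not compile: "Int does not take parameters"
l.refresh()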
30 changes: 15 additions & 15 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -573,11 +573,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     val rdd3 = rdd2.map(_ - 1).filter(_ < 50).map(i => (i, i))
     val rdd4 = rdd3.reduceByKey(_ + _)
     val rdd5 = rdd4.mapValues(_ + 1).mapValues(_ + 2).mapValues(_ + 3)
-    val ancestors1 = rdd1.getNarrowAncestors()
-    val ancestors2 = rdd2.getNarrowAncestors()
-    val ancestors3 = rdd3.getNarrowAncestors()
-    val ancestors4 = rdd4.getNarrowAncestors()
-    val ancestors5 = rdd5.getNarrowAncestors()
+    val ancestors1 = rdd1.getNarrowAncestors
+    val ancestors2 = rdd2.getNarrowAncestors
+    val ancestors3 = rdd3.getNarrowAncestors
+    val ancestors4 = rdd4.getNarrowAncestors
+    val ancestors5 = rdd5.getNarrowAncestors
 
     // Simple dependency tree with a single branch
     assert(ancestors1.size === 0)
@@ -608,10 +608,10 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     val rdd7 = sc.union(rdd1, rdd2, rdd3)
     val rdd8 = sc.union(rdd6, rdd7)
     val rdd9 = rdd4.join(rdd5)
-    val ancestors6 = rdd6.getNarrowAncestors()
-    val ancestors7 = rdd7.getNarrowAncestors()
-    val ancestors8 = rdd8.getNarrowAncestors()
-    val ancestors9 = rdd9.getNarrowAncestors()
+    val ancestors6 = rdd6.getNarrowAncestors
+    val ancestors7 = rdd7.getNarrowAncestors
+    val ancestors8 = rdd8.getNarrowAncestors
+    val ancestors9 = rdd9.getNarrowAncestors
 
     // Simple dependency tree with multiple branches
     assert(ancestors6.size === 3)
@@ -649,8 +649,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     // Simple cyclical dependency
     rdd1.addDependency(new OneToOneDependency[Int](rdd2))
     rdd2.addDependency(new OneToOneDependency[Int](rdd1))
-    val ancestors1 = rdd1.getNarrowAncestors()
-    val ancestors2 = rdd2.getNarrowAncestors()
+    val ancestors1 = rdd1.getNarrowAncestors
+    val ancestors2 = rdd2.getNarrowAncestors
     assert(ancestors1.size === 1)
     assert(ancestors1.count(_ == rdd2) === 1)
     assert(ancestors1.count(_ == rdd1) === 0)
@@ -660,8 +660,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
 
     // Cycle involving a longer chain
     rdd3.addDependency(new OneToOneDependency[Int](rdd4))
-    val ancestors3 = rdd3.getNarrowAncestors()
-    val ancestors4 = rdd4.getNarrowAncestors()
+    val ancestors3 = rdd3.getNarrowAncestors
+    val ancestors4 = rdd4.getNarrowAncestors
     assert(ancestors3.size === 4)
     assert(ancestors3.count(_.isInstanceOf[MappedRDD[_, _]]) === 2)
     assert(ancestors3.count(_.isInstanceOf[FilteredRDD[_]]) === 2)
@@ -674,15 +674,15 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(ancestors4.count(_ == rdd4) === 0)
 
     // Cycles that do not involve the root
-    val ancestors5 = rdd5.getNarrowAncestors()
+    val ancestors5 = rdd5.getNarrowAncestors
     assert(ancestors5.size === 6)
     assert(ancestors5.count(_.isInstanceOf[MappedRDD[_, _]]) === 3)
     assert(ancestors5.count(_.isInstanceOf[FilteredRDD[_]]) === 2)
     assert(ancestors5.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 1)
     assert(ancestors4.count(_ == rdd3) === 1)
 
     // Complex cyclical dependency graph (combination of all of the above)
-    val ancestors6 = rdd6.getNarrowAncestors()
+    val ancestors6 = rdd6.getNarrowAncestors
     assert(ancestors6.size === 12)
     assert(ancestors6.count(_.isInstanceOf[UnionRDD[_]]) === 2)
     assert(ancestors6.count(_.isInstanceOf[MappedRDD[_, _]]) === 4)
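The cycle tests above all rely on the same property of the refactored method: because visit only recurses into parents that are not already in the visited set, the DFS terminates even when dependencies form a cycle, and the root is filtered out of its own ancestor set afterwards. A self-contained sketch of that pattern over a toy graph (reachable and the integer graph are illustrative, not Spark code):

import scala.collection.mutable

def reachable[A](start: A, neighbors: A => Seq[A]): Set[A] = {
  val seen = mutable.HashSet[A]()
  def visit(node: A): Unit = {
    // Only recurse into unseen nodes, so a cycle cannot loop forever.
    neighbors(node).filterNot(seen.contains).foreach { n =>
      seen.add(n)
      visit(n)
    }
  }
  visit(start)
  // If a cycle leads back to the start, exclude the start itself,
  // mirroring ancestors.filterNot(_ == this) in the patch.
  seen.filterNot(_ == start).toSet
}

// A two-node cycle, like the rdd1 <-> rdd2 test: each node reaches only the other.
val graph = Map(1 -> Seq(2), 2 -> Seq(1))
assert(reachable(1, graph) == Set(2))
assert(reachable(2, graph) == Set(1))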