Switch FoldedRDD to use BernoulliSampler and PartitionwiseSampledRDD
holdenk committed Apr 9, 2014
commit c0b7fa4d06dec185cc1695121411c709b904967c
core/src/main/scala/org/apache/spark/rdd/FoldedRDD.scala (7 additions, 26 deletions)
@@ -24,6 +24,7 @@ import cern.jet.random.Poisson
 import cern.jet.random.engine.DRand
 
 import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.util.random.BernoulliSampler
 
 private[spark]
 class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
@@ -32,24 +33,10 @@ class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition w
 
 class FoldedRDD[T: ClassTag](
     prev: RDD[T],
-    fold: Int,
-    folds: Int,
+    fold: Float,
+    folds: Float,
     seed: Int)
-  extends RDD[T](prev) {
-
-  override def getPartitions: Array[Partition] = {
-    val rg = new Random(seed)
-    firstParent[T].partitions.map(x => new FoldedRDDPartition(x, rg.nextInt))
-  }
-
-  override def getPreferredLocations(split: Partition): Seq[String] =
-    firstParent[T].preferredLocations(split.asInstanceOf[FoldedRDDPartition].prev)
-
-  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
-    val split = splitIn.asInstanceOf[FoldedRDDPartition]
-    val rand = new Random(split.seed)
-    firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) == fold-1))
-  }
+  extends PartitionwiseSampledRDD[T, T](prev, new BernoulliSampler((fold-1)/folds, fold/folds, false), seed) {
 }
 
 /**
@@ -58,14 +45,8 @@ class FoldedRDD[T: ClassTag](
  */
 class CompositeFoldedRDD[T: ClassTag](
     prev: RDD[T],
-    fold: Int,
-    folds: Int,
+    fold: Float,
+    folds: Float,
     seed: Int)
-  extends FoldedRDD[T](prev, fold, folds, seed) {
-
-  override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
-    val split = splitIn.asInstanceOf[FoldedRDDPartition]
-    val rand = new Random(split.seed)
-    firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) != fold-1))
-  }
+  extends PartitionwiseSampledRDD[T, T](prev, new BernoulliSampler((fold-1)/folds, fold/folds, true), seed) {
 }
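For context on the change above: a BernoulliSampler(lb, ub, complement) keeps an element when a per-element uniform draw in [0, 1) falls inside [lb, ub), or outside that range when complement is true, so fold k of n maps to the range [(k-1)/n, k/n), and running both samplers with the same seed makes FoldedRDD and CompositeFoldedRDD exact complements. A minimal sketch of that selection logic in plain Scala, assuming a simplified single-sequence model (bernoulliFold and the example values are illustrative, not Spark API):

import scala.util.Random

// Simplified stand-in (not Spark's implementation) for what
// PartitionwiseSampledRDD does with a BernoulliSampler(lb, ub, complement):
// each element consumes exactly one uniform draw in [0, 1) and is kept when
// the draw lands in [lb, ub), or outside that range when complement is true.
def bernoulliFold[T](items: Seq[T], lb: Double, ub: Double,
                     complement: Boolean, seed: Long): Seq[T] = {
  val rng = new Random(seed)
  items.filter { _ =>
    val x = rng.nextDouble()
    val inRange = lb <= x && x < ub
    if (complement) !inRange else inRange
  }
}

// Fold k of n corresponds to the range [(k-1)/n, k/n).
val data = 1 to 100
val (fold, folds) = (1.0, 2.0)
val test = bernoulliFold(data, (fold - 1) / folds, fold / folds, complement = false, seed = 1L)
val train = bernoulliFold(data, (fold - 1) / folds, fold / folds, complement = true, seed = 1L)
// Same seed and one draw per element, so test and train partition the data,
// which is the property the FoldedRDD/CompositeFoldedRDD pair relies on.
assert((test ++ train).sorted == data.sorted)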
core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala (16 additions, 2 deletions)
@@ -513,14 +513,28 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("FoldedRDD") {
+    val data = sc.parallelize(1 to 100, 2)
+    val lowerFoldedRdd = new FoldedRDD(data, 1, 2, 1)
+    val upperFoldedRdd = new FoldedRDD(data, 2, 2, 1)
+    val lowerCompositeFoldedRdd = new CompositeFoldedRDD(data, 1, 2, 1)
+    assert(lowerFoldedRdd.collect().sorted.size == 50)
+    assert(lowerCompositeFoldedRdd.collect().sorted.size == 50)
+    assert(lowerFoldedRdd.subtract(lowerCompositeFoldedRdd).collect().sorted ===
+      lowerFoldedRdd.collect().sorted)
+    assert(upperFoldedRdd.collect().sorted.size == 50)
+  }
+
   test("kfoldRdd") {
     val data = sc.parallelize(1 to 100, 2)
-    for (folds <- 1 to 10) {
+    val collectedData = data.collect().sorted
+    for (folds <- 2 to 10) {
       for (seed <- 1 to 5) {
         val foldedRdds = data.kFoldRdds(folds, seed)
         assert(foldedRdds.size === folds)
         foldedRdds.map{case (test, train) =>
-          assert(test.union(train).collect().sorted === data.collect().sorted,
+          val result = test.union(train).collect().sorted
+          assert(result === collectedData,
             "Each training+test set combined contains all of the data")
         }
         // K fold cross validation should only have each element in the test set exactly once
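The assertion the excerpt cuts off after the comment above checks the k-fold invariant that every element lands in exactly one test set. A small plain-Scala illustration of why range-based sampling gives that invariant, assuming the simplified model from the sketch earlier (n and the draw count are arbitrary illustrative values):

import scala.util.Random

// The ranges [(k-1)/n, k/n) for k = 1..n tile [0, 1) without overlap, so the
// single uniform draw an element consumes assigns it to exactly one test fold.
val n = 10
val rng = new Random(1)
val draws = Seq.fill(100)(rng.nextDouble())
// Every draw falls inside exactly one of the n fold ranges.
draws.foreach { x =>
  val hits = (1 to n).count(k => (k - 1).toDouble / n <= x && x < k.toDouble / n)
  assert(hits == 1)
}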