SPARK-1310: Start adding k-fold cross validation to MLLib [adds kFold to MLUtils & fixes bug in BernoulliSampler] #18
Changes from 1 commit
MLUtils.scala

```diff
@@ -22,6 +22,15 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import scala.reflect._
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.PartitionwiseSampledRDD
+import org.apache.spark.SparkContext._
+import org.apache.spark.util.random.BernoulliSampler
+import org.jblas.DoubleMatrix
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.regression.RegressionModel
```

Contributor review comment: remove unused imports
```diff
@@ -176,6 +185,21 @@ object MLUtils {
     (a-b)*(a-b)
   }

+  /**
+   * Return a k element list of pairs of RDDs with the first element of each pair
+   * containing a unique 1/Kth of the data and the second element containing the complement.
+   */
+  def kFoldRdds[T : ClassTag](rdd: RDD[T], folds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = {
+    val foldsF = folds.toFloat
+    1.to(folds).map(fold => ((
+      new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF, fold/foldsF, false),
+        seed),
+      new PartitionwiseSampledRDD(rdd, new BernoulliSampler[T]((fold-1)/foldsF, fold/foldsF, true),
+        seed)
+    ))).toList
+  }

   /**
    * Function to perform cross validation on a single learner.
    *
```
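The fold construction above leans on `BernoulliSampler`'s interval bounds: each element draws a uniform value in [0, 1) from the shared seed and is kept when the draw falls in [(fold-1)/K, fold/K), while the complement sampler (third constructor argument `true`) keeps everything outside that interval, so each pair partitions the data. Here is a minimal local sketch of the same idea with no Spark dependency (`kFoldLocal` is illustrative only, not part of the patch):

```scala
import scala.util.Random

// Sketch: assign each element to exactly one of `folds` buckets by mapping a
// seeded uniform draw in [0, 1) onto the interval [(fold-1)/K, fold/K).
def kFoldLocal[T](data: Seq[T], folds: Int, seed: Int): Seq[(Seq[T], Seq[T])] = {
  val rng = new Random(seed)
  // Fix one draw per element up front, mirroring how the two sampled RDDs in
  // kFoldRdds see identical draws because they share a seed.
  val draws = data.map(x => (x, rng.nextDouble()))
  (1 to folds).map { fold =>
    val lb = (fold - 1).toDouble / folds
    val ub = fold.toDouble / folds
    val (test, train) = draws.partition { case (_, d) => d >= lb && d < ub }
    (test.map(_._1), train.map(_._1)) // (held-out fold, its complement)
  }
}
```

Because the sampler and its complement share a seed, every element lands in exactly one held-out fold; that is the invariant the new `kfoldRdd` test below verifies, and presumably the `BernoulliSampler` behavior the bug fix in the PR title addresses.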
```diff
@@ -192,14 +216,14 @@ object MLUtils {
     if (folds <= 1) {
       throw new IllegalArgumentException("Cross validation requires more than one fold")
     }
-    val rdds = data.kFoldRdds(folds, seed)
+    val rdds = kFoldRdds(data, folds, seed)
     val errorRates = rdds.map{case (testData, trainingData) =>
       val model = learner(trainingData)
-      val predictions = model.predict(testData.map(_.features))
-      val errors = predictions.zip(testData.map(_.label)).map{case (x,y) => errorFunction(x,y)}
+      val predictions = testData.map(data => (data.label, model.predict(data.features)))
+      val errors = predictions.map{case (x, y) => errorFunction(x, y)}
       errors.sum()
     }
-    val averageError = errorRates.sum / data.count
+    val averageError = errorRates.sum / data.count.toFloat
     averageError
   }
```
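One detail of this change deserves a note. The old code zipped two separately derived RDDs, and `RDD.zip` is only correct when both sides have identical partitioning and identical per-partition counts; it also made two passes over the test data. Producing `(label, prediction)` pairs in a single `map` sidesteps both issues. A hedged sketch of the two patterns (`Point` and `PointModel` are illustrative stand-ins, not MLlib classes):

```scala
import org.apache.spark.rdd.RDD

// Illustrative stand-ins for a labeled example and a fitted model.
case class Point(label: Double, features: Array[Double])
trait PointModel extends Serializable {
  def predict(features: Array[Double]): Double
}

// Fragile: correct only while both derived RDDs keep identical partitioning
// and per-partition counts, the contract RDD.zip requires.
def pairedByZip(test: RDD[Point], model: PointModel): RDD[(Double, Double)] =
  test.map(_.label).zip(test.map(p => model.predict(p.features)))

// Robust: one pass over the data, no cross-RDD alignment assumption.
def pairedByMap(test: RDD[Point], model: PointModel): RDD[(Double, Double)] =
  test.map(p => (p.label, model.predict(p.features)))
```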
MLUtilsSuite.scala

```diff
@@ -43,7 +43,6 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
     System.clearProperty("spark.driver.port")
   }
-

   test("epsilon computation") {
     assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
     assert(1.0 + EPSILON / 2.0 === 1.0, s"EPSILON is too big: $EPSILON.")
```
```diff
@@ -137,20 +136,11 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
     new LinearRegressionModel(Array(1.0), 0)
   }

-  test("Test cross validation with a terrible learner") {
-    val data = sc.parallelize(1.to(100).zip(1.to(100))).map(
-      x => LabeledPoint(x._1, Array(x._2)))
-    val expectedError = 1.to(100).map(x => x*x).sum / 100.0
-    for (seed <- 1 to 5) {
-      for (folds <- 2 to 5) {
-        val avgError = MLUtils.crossValidate(data, folds, seed, terribleLearner)
-        avgError should equal (expectedError)
-      }
-    }
-  }
   test("Test cross validation with a reasonable learner") {
     val data = sc.parallelize(1.to(100).zip(1.to(100))).map(
       x => LabeledPoint(x._1, Array(x._2)))
+    val features = data.map(_.features)
+    val labels = data.map(_.label)
     for (seed <- 1 to 5) {
       for (folds <- 2 to 5) {
         val avgError = MLUtils.crossValidate(data, folds, seed, exactLearner)
```
```diff
@@ -163,8 +153,33 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
     val data = sc.parallelize(1.to(100).zip(1.to(100))).map(
       x => LabeledPoint(x._1, Array(x._2)))
     val thrown = intercept[java.lang.IllegalArgumentException] {
-      val avgError = MLUtils.crossValidate(data, 1, 1, terribleLearner)
+      val avgError = MLUtils.crossValidate(data, 1, 1, exactLearner)
     }
     assert(thrown.getClass === classOf[IllegalArgumentException])
   }

+  test("kfoldRdd") {
+    val data = sc.parallelize(1 to 100, 2)
+    val collectedData = data.collect().sorted
+    val twoFoldedRdd = MLUtils.kFoldRdds(data, 2, 1)
+    assert(twoFoldedRdd(0)._1.collect().sorted === twoFoldedRdd(1)._2.collect().sorted)
+    assert(twoFoldedRdd(0)._2.collect().sorted === twoFoldedRdd(1)._1.collect().sorted)
+    for (folds <- 2 to 10) {
+      for (seed <- 1 to 5) {
+        val foldedRdds = MLUtils.kFoldRdds(data, folds, seed)
+        assert(foldedRdds.size === folds)
+        foldedRdds.map{case (test, train) =>
+          val result = test.union(train).collect().sorted
+          assert(test.collect().size > 0, "Non empty test data")
+          assert(train.collect().size > 0, "Non empty training data")
+          assert(result === collectedData,
+            "Each training+test set combined contains all of the data")
+        }
+        // K fold cross validation should only have each element in the test set exactly once
+        assert(foldedRdds.map(_._1).reduce((x,y) => x.union(y)).collect().sorted ===
+          data.collect().sorted)
+      }
+    }
+  }
 }
```
Contributor review comment (on the import changes above): Please move this line to the block of breeze imports, and merge the `spark` and `spark.mllib` imports into a single block.
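Finally, a hedged end-to-end sketch of driving the new utility directly, under stated assumptions: `train` stands in for any user-supplied fitting function (a real MLlib learner would be substituted), and the error is plain squared loss as in the suite above:

```scala
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils

// Train on each fold's complement, score on the held-out fold, then average
// over every point; each point appears in exactly one held-out fold.
def meanSquaredErrorKFold(
    data: RDD[LabeledPoint],
    folds: Int,
    seed: Int,
    train: RDD[LabeledPoint] => (Array[Double] => Double)): Double = {
  val perFoldError = MLUtils.kFoldRdds(data, folds, seed).map {
    case (testData, trainingData) =>
      val model = train(trainingData)
      testData.map(p => math.pow(p.label - model(p.features), 2)).sum()
  }
  perFoldError.sum / data.count.toDouble
}
```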