swap training and validation order
holdenk committed Apr 10, 2014
commit 6ddbf05de0b0749cfc90f40e02a1864081aaa676
@@ -179,8 +179,8 @@ object MLUtils {

/**
* Return a k element array of pairs of RDDs with the first element of each pair
- * containing the validation data, a unique 1/Kth of the data and the second
- * element, the training data, contain the complement of that.
+ * containing the training data, a complement of the validation data and the second
+ * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds.
*/
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {

Contributor comment:

It is natural to have training before validation in the return array because you always need to use the training set first.

val numFoldsF = numFolds.toFloat
@@ -189,7 +189,7 @@ object MLUtils {
complement = false)
val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
- (validation, training)
+ (training, validation)
}.toArray
}

@@ -120,7 +120,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
for (seed <- 1 to 5) {
val foldedRdds = MLUtils.kFold(data, folds, seed)
assert(foldedRdds.size === folds)
- foldedRdds.map { case (validation, training) =>
+ foldedRdds.map { case (training, validation) =>
val result = validation.union(training).collect().sorted
val validationSize = validation.collect().size.toFloat
assert(validationSize > 0, "empty validation data")
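
To make the new return order concrete, here is a minimal Spark-free sketch of the contract described in the doc comment: each pair holds the training data first and the validation data second, and the two halves together cover the whole input. The names `KFoldSketch` and `kFoldLocal` are hypothetical and not part of MLlib. The sketch also swaps in a seeded shuffle with index-modulo assignment, whereas the patched `kFold` samples each partition through `PartitionwiseSampledRDD`, so it illustrates the shape of the result only, not the sampling behaviour.

```scala
// Hypothetical, Spark-free illustration; KFoldSketch and kFoldLocal are not MLlib names.
object KFoldSketch {

  /** Returns numFolds pairs of (training, validation), training first. */
  def kFoldLocal[T](data: Seq[T], numFolds: Int, seed: Int): Array[(Seq[T], Seq[T])] = {
    require(numFolds > 1, "numFolds must be at least 2")
    // Shuffle once with a seeded RNG so the folds are reproducible for a given seed.
    val indexed = new scala.util.Random(seed).shuffle(data).zipWithIndex
    (0 until numFolds).map { fold =>
      // Element i of the shuffled data is validation data for fold (i % numFolds)
      // and training data for every other fold.
      val (validation, training) = indexed.partition { case (_, i) => i % numFolds == fold }
      (training.map(_._1), validation.map(_._1))
    }.toArray
  }

  def main(args: Array[String]): Unit = {
    val data = (1 to 100).toVector
    val folds = kFoldLocal(data, numFolds = 3, seed = 42)
    folds.zipWithIndex.foreach { case ((training, validation), i) =>
      // The same kind of checks the test suite makes on the RDD version.
      assert(validation.nonEmpty, "empty validation data")
      assert((training ++ validation).sorted.toList == data.toList, "fold does not cover the data")
      println(s"fold $i: training=${training.size}, validation=${validation.size}")
    }
  }
}
```

Callers that destructure the result as `case (training, validation)` read in the order the data is used: fit on the first element, then evaluate on the second.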