SPARK-14661: PCA trimmed by variance
Piotr Suszyński committed Apr 15, 2016
commit ac1b9fe9c920171b6baf1a3bb12a4c7ea7d79ec1
19 changes: 19 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -17,6 +17,8 @@

package org.apache.spark.mllib.feature

import java.util.Arrays

import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg._
@@ -109,4 +111,21 @@ class PCAModel private[spark] (
s"SparseVector or DenseVector. Instead got: ${vector.getClass}")
}
}

  /**
   * Returns a copy of this model trimmed to the smallest number of principal
   * components whose cumulative explained variance reaches the given threshold.
   */
  def minimalByVarianceExplained(requiredVarianceRetained: Double): PCAModel = {
    val minFeaturesNum = explainedVariance
@srowen (Member) commented on Apr 15, 2016:

How about `explainedVariance.values.scanLeft(0.0)(_ + _).indexWhere(_ >= requiredVarianceRetained) + 1`? Eh, OK — to make that robust you'd have to handle the case where that returns 0 (which means you need to keep all the PCs, so you could just return `this`), and also arg-check the required variance to be in [0, 1]. (A sketch of this suggestion follows the diff below.)

      .values.zipWithIndex
      .foldLeft((0.0, 0)) { case ((varianceSum, bestIndex), (variance, index)) =>
        // Accumulate per-component variances until the threshold is met; the
        // index of the last component consumed, plus one, is the minimal k.
        if (varianceSum >= requiredVarianceRetained) {
          (varianceSum, bestIndex)
        } else {
          (varianceSum + variance, index)
        }
      }._2 + 1
    // pc is stored column-major, so the first minFeaturesNum columns occupy the
    // first pc.numRows * minFeaturesNum entries of its value array.
    val trimmedPc = Arrays.copyOfRange(pc.values, 0, pc.numRows * minFeaturesNum)
    val trimmedExplainedVariance = Arrays.copyOfRange(explainedVariance.values, 0, minFeaturesNum)
    new PCAModel(minFeaturesNum,
      Matrices.dense(pc.numRows, minFeaturesNum, trimmedPc).asInstanceOf[DenseMatrix],
      Vectors.dense(trimmedExplainedVariance).asInstanceOf[DenseVector])
  }
}
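For context, here is a runnable sketch of the reviewer's one-liner with the two robustness fixes he mentions (the [0, 1] arg-check and the "keep all the PCs" fallback). `minimalK` and its parameters are illustrative names, not part of this patch:

```scala
// Sketch only: pick k via cumulative sums, per the review comment above.
def minimalK(variances: Array[Double], requiredVarianceRetained: Double): Int = {
  require(requiredVarianceRetained >= 0.0 && requiredVarianceRetained <= 1.0,
    s"requiredVarianceRetained must be in [0, 1], got $requiredVarianceRetained")
  // scanLeft(0.0)(_ + _) yields the cumulative sums (0.0, v0, v0 + v1, ...), so
  // the index of the first sum meeting the threshold is the number of
  // components to keep; indexWhere returns -1 when no prefix sum qualifies.
  val idx = variances.scanLeft(0.0)(_ + _).indexWhere(_ >= requiredVarianceRetained)
  if (idx == -1) variances.length // threshold never reached: keep every component
  else math.max(idx, 1)           // keep at least one component
}
```

On variances (0.6, 0.3, 0.1) with a 0.90 threshold the cumulative sums are (0.0, 0.6, 0.9, 1.0), so `idx = 2` and two components are kept — the same k the patch's foldLeft produces.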
14 changes: 14 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
@@ -45,4 +45,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(pca_transform.toSet === mat_multiply.toSet)
assert(pca.explainedVariance === explainedVariance)
}

test("should return model with minimal number of features that retain given level of variance") {
// given
val pca = new PCA(4).fit(dataRDD)

// when
val trimmed = pca.minimalByVarianceExplained(0.90)

// then
val pcaWithExpectedK = new PCA(2).fit(dataRDD)
assert(trimmed.k === pcaWithExpectedK.k)
assert(trimmed.explainedVariance === pcaWithExpectedK.explainedVariance)
assert(trimmed.pc === pcaWithExpectedK.pc)
}
}
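A hypothetical usage sketch of the new API, assuming this patch is applied and a SparkContext `sc` is in scope (the data values are made up):

```scala
import org.apache.spark.mllib.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

// Fit with all four components, then trim to the smallest model that
// still retains 90% of the variance.
val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 7.0, 0.0),
  Vectors.dense(2.0, 0.0, 3.0, 4.0),
  Vectors.dense(4.0, 0.0, 0.0, 8.0)))
val full = new PCA(4).fit(data)
val trimmed = full.minimalByVarianceExplained(0.90)
println(s"components kept: ${trimmed.k} of ${full.k}")
```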