diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 9cf722e12169..82a344b7677b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -37,6 +37,7 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
 
   /**
    * The number of principal components.
+   *
    * @group param
    */
   final val k: IntParam = new IntParam(this, "k", "the number of principal components")
@@ -44,6 +45,16 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
   /** @group getParam */
   def getK: Int = $(k)
 
+  /**
+   * Minimal variance retained by principal components.
+   *
+   * @group param
+   */
+  final val requiredVariance: DoubleParam = new DoubleParam(this, "requiredVariance",
+    "minimal variance retained by principal components")
+
+  /** @group getParam */
+  def getRequiredVariance: Double = $(requiredVariance)
 }
 
 /**
@@ -63,7 +74,16 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   /** @group setParam */
-  def setK(value: Int): this.type = set(k, value)
+  def setK(value: Int): this.type = {
+    if (isSet(requiredVariance)) clear(requiredVariance)
+    set(k, value)
+  }
+
+  /** @group setParam */
+  def setRequiredVariance(value: Double): this.type = {
+    if (isSet(k)) clear(k)
+    set(requiredVariance, value)
+  }
 
   /**
    * Computes a [[PCAModel]] that contains the principal components of the input vectors.
@@ -72,7 +92,11 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   override def fit(dataset: Dataset[_]): PCAModel = {
     transformSchema(dataset.schema, logging = true)
     val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v}
-    val pca = new feature.PCA(k = $(k))
+    val pca = if (isSet(k)) {
+      new feature.PCA(k = $(k))
+    } else {
+      new feature.PCA(requiredVariance = $(requiredVariance))
+    }
     val pcaModel = pca.fit(input)
     copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this))
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
index 30c403e547be..d4a6904a4fff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.feature
 
+import java.util.Arrays
+
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg._
@@ -31,7 +33,14 @@ import org.apache.spark.rdd.RDD
 @Since("1.4.0")
 class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
   require(k > 0,
-    s"Number of principal components must be positive but got ${k}")
+    s"Number of principal components must be positive but got $k")
+
+  var pcFilter: Either[Int, Double] = Left(k)
+
+  def this(requiredVariance: Double) = {
+    this(k = 1)
+    pcFilter = Right(requiredVariance)
+  }
 
   /**
    * Computes a [[PCAModel]] that contains the principal components of the input vectors.
@@ -44,7 +53,7 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
       s"source vector size is ${sources.first().size} must be greater than k=$k")
 
     val mat = new RowMatrix(sources)
-    val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
+    val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(pcFilter)
     val densePC = pc match {
       case dm: DenseMatrix =>
         dm
@@ -66,7 +75,7 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
       case sv: SparseVector =>
         sv.toDense
     }
-    new PCAModel(k, densePC, denseExplainedVariance)
+    new PCAModel(explainedVariance.size, densePC, denseExplainedVariance)
   }
 
   /**
@@ -109,4 +118,4 @@ class PCAModel private[spark] (
         s"SparseVector or DenseVector. Instead got: ${vector.getClass}")
     }
   }
-}
+}
\ No newline at end of file
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index f6183a5eaadc..d2fc3a861b58 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -170,8 +170,7 @@ class RowMatrix @Since("1.0.0") (
    *
    * @note The conditions that decide which method to use internally and the default parameters are
    * subject to change.
-   *
-   * @param k number of leading singular values to keep (0 < k <= n).
+   * @param k number of leading singular values to keep (0 < k <= n).
    *          It might return less than k if
    *          there are numerically zero singular values or there are not enough Ritz values
    *          converged before the maximum number of Arnoldi update iterations is reached (in case
@@ -321,7 +320,8 @@
   /**
    * Computes the covariance matrix, treating each row as an observation. Note that this cannot
    * be computed on matrices with more than 65535 columns.
-   * @return a local dense matrix of size n x n
+   *
+   * @return a local dense matrix of size n x n
    */
   @Since("1.0.0")
   def computeCovariance(): Matrix = {
@@ -379,15 +379,21 @@
    *
    * Note that this cannot be computed on matrices with more than 65535 columns.
    *
-   * @param k number of top principal components.
+   * @param filter either the number of top principal components or variance
+   *               retained by the minimal set of principal components.
    * @return a matrix of size n-by-k, whose columns are principal components, and
    *         a vector of values which indicate how much variance each principal component
    *         explains
    */
   @Since("1.6.0")
-  def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = {
+  def computePrincipalComponentsAndExplainedVariance(filter: Either[Int, Double])
+    : (Matrix, Vector) = {
     val n = numCols().toInt
-    require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]")
+    filter match {
+      case Left(k) => require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]")
+      case Right(requiredVariance) => require(requiredVariance > 0.0 && requiredVariance <= 1.0,
+        s"requiredVariance = $requiredVariance out of range (0, 1.0]")
+    }
 
     val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]]
 
@@ -396,6 +402,17 @@
     val eigenSum = s.data.sum
     val explainedVariance = s.data.map(_ / eigenSum)
 
+    val k = filter match {
+      case Left(k) => k
+      case Right(requiredVariance) =>
+        val minFeatures = explainedVariance
+          .scanLeft(0.0)(_ + _)
+          .indexWhere(_ >= requiredVariance)
+        require(minFeatures > 0 && minFeatures <= n, s"minFeatures computed using " +
+          s"requiredVariance was $minFeatures and was out of range (0, n = $n]")
+        minFeatures
+    }
+
     if (k == n) {
       (Matrices.dense(n, k, u.data), Vectors.dense(explainedVariance))
     } else {
@@ -413,7 +430,7 @@
    */
   @Since("1.0.0")
   def computePrincipalComponents(k: Int): Matrix = {
-    computePrincipalComponentsAndExplainedVariance(k)._1
+    computePrincipalComponentsAndExplainedVariance(Left(k))._1
   }
 
   /**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index f372ec58269e..75781fc82edd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -28,6 +28,13 @@ import org.apache.spark.sql.Row
 
 class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
+  val data = Array(
+    Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
+    Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+    Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+  )
+  lazy val dataRDD = sc.parallelize(data, 2)
+
   test("params") {
     ParamsSuite.checkParams(new PCA)
     val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
@@ -37,14 +44,6 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
   }
 
   test("pca") {
-    val data = Array(
-      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
-      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
-      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
-    )
-
-    val dataRDD = sc.parallelize(data, 2)
-
     val mat = new RowMatrix(dataRDD)
     val pc = mat.computePrincipalComponents(3)
     val expected = mat.multiply(pc).rows
@@ -81,4 +80,25 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     val newInstance = testDefaultReadWrite(instance)
     assert(newInstance.pc === instance.pc)
   }
+
+  test("should return model with minimal number of features that retain given level of variance") {
+    // given
+    val df = sqlContext.createDataFrame(dataRDD.zipWithIndex()).toDF("features", "index")
+
+    // when
+    val trimmed = new PCA()
+      .setInputCol("features")
+      .setOutputCol("pca_features")
+      .setRequiredVariance(0.9)
+      .fit(df)
+
+    // then
+    val pcaWithExpectedK = new PCA()
+      .setInputCol("features")
+      .setOutputCol("pca_features")
+      .setK(2)
+      .fit(df)
+    assert(trimmed.explainedVariance === pcaWithExpectedK.explainedVariance)
+    assert(trimmed.pc === pcaWithExpectedK.pc)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
index a8d82932d390..e559d075fa30 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
@@ -37,7 +37,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
     val pca = new PCA(k).fit(dataRDD)
 
     val mat = new RowMatrix(dataRDD)
-    val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
+    val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(Left(k))
 
     val pca_transform = pca.transform(dataRDD).collect()
     val mat_multiply = mat.multiply(pc).rows.collect()
@@ -45,4 +45,15 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(pca_transform.toSet === mat_multiply.toSet)
     assert(pca.explainedVariance === explainedVariance)
   }
+
+  test("should return model with minimal number of features that retain given level of variance") {
+    // when
+    val trimmed = new PCA(requiredVariance = 0.90).fit(dataRDD)
+
+    // then
+    val pcaWithExpectedK = new PCA(k = 2).fit(dataRDD)
+    assert(trimmed.k === pcaWithExpectedK.k)
+    assert(trimmed.explainedVariance === pcaWithExpectedK.explainedVariance)
+    assert(trimmed.pc === pcaWithExpectedK.pc)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index 2dff52c601d8..5f22f0131c84 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -204,7 +204,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("pca") {
     for (mat <- Seq(denseMat, sparseMat); k <- 1 to n) {
-      val (pc, expVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
+      val (pc, expVariance) = mat.computePrincipalComponentsAndExplainedVariance(Left(k))
       assert(pc.numRows === n)
       assert(pc.numCols === k)
       assertColumnEqualUpToSign(pc.toBreeze.asInstanceOf[BDM[Double]], principalComponents, k)
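
Usage sketch (not part of the patch): the snippet below shows how the new requiredVariance option introduced by this diff is meant to be driven, first through the RDD-based org.apache.spark.mllib.feature.PCA constructor and then through the spark.ml estimator. It assumes a spark-shell style session where `sc` and `sqlContext` are already defined; the input vectors and column names mirror the test fixture above and are illustrative only.

import org.apache.spark.mllib.linalg.Vectors

// Same three 5-dimensional vectors used as the fixture in the suites.
val data = Array(
  Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
  Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
  Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
)
val dataRDD = sc.parallelize(data, 2)

// RDD-based API: keep the smallest number of components whose cumulative
// explained variance reaches 90%, instead of fixing k up front.
val mllibModel = new org.apache.spark.mllib.feature.PCA(requiredVariance = 0.9).fit(dataRDD)
println(mllibModel.k)                  // number of components actually kept
println(mllibModel.explainedVariance)  // variance ratio of each retained component

// DataFrame-based API: setRequiredVariance clears k (and setK clears
// requiredVariance), so only one selection criterion is in effect at a time.
val df = sqlContext.createDataFrame(dataRDD.zipWithIndex()).toDF("features", "index")
val model = new org.apache.spark.ml.feature.PCA()
  .setInputCol("features")
  .setOutputCol("pca_features")
  .setRequiredVariance(0.9)
  .fit(df)
model.transform(df).select("pca_features").show(false)

On this fixture the ml suite above asserts that requiredVariance = 0.9 and k = 2 produce the same model, so either call should end up retaining two principal components.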