Add read/write support to ChiSqSelector, PCA, VectorIndexer
yanboliang committed Nov 19, 2015
commit 0778a49d67162682b1caf376fcd4c9b47985a073
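For context, here is a minimal usage sketch of the persistence API this commit wires up, mirroring the DataFrame construction used in the test suites below. The data, column names, and save paths are illustrative assumptions, not part of the patch:

```scala
import org.apache.spark.ml.feature.{ChiSqSelector, ChiSqSelectorModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

// Assumes an active SparkContext `sc`; data, column names, and paths are illustrative.
val sqlContext = SQLContext.getOrCreate(sc)
val df = sqlContext.createDataFrame(Seq(
  (Vectors.dense(0.0, 1.0, 2.0), 1.0),
  (Vectors.dense(3.0, 4.0, 5.0), 0.0)
)).toDF("features", "label")

val selector = new ChiSqSelector()
  .setNumTopFeatures(1)
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selectedFeatures")

// Estimator params round-trip via DefaultParamsWritable / DefaultParamsReadable.
selector.write.overwrite().save("/tmp/chisq-selector")
val selector2 = ChiSqSelector.load("/tmp/chisq-selector")

// The fitted model round-trips via MLWritable / MLReadable.
val model: ChiSqSelectorModel = selector.fit(df)
model.write.overwrite().save("/tmp/chisq-selector-model")
val model2 = ChiSqSelectorModel.load("/tmp/chisq-selector-model")
```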
mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -17,13 +17,14 @@

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.{AttributeGroup, _}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params
*/
@Experimental
final class ChiSqSelector(override val uid: String)
extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {
extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("chiSqSelector"))

@@ -95,6 +96,11 @@ final class ChiSqSelector(override val uid: String)
override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
}

object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] {

override def load(path: String): ChiSqSelector = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[ChiSqSelector]].
@@ -103,7 +109,12 @@ final class ChiSqSelector(override val uid: String)
final class ChiSqSelectorModel private[ml] (
override val uid: String,
private val chiSqSelector: feature.ChiSqSelectorModel)
extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {
extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable {

import ChiSqSelectorModel._

/** List of indices to select (filter). Must be in ascending order. */
val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures

/** @group setParam */
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -147,4 +158,42 @@ final class ChiSqSelectorModel private[ml] (
val copied = new ChiSqSelectorModel(uid, chiSqSelector)
copyValues(copied, extra).setParent(parent)
}

override def write: MLWriter = new ChiSqSelectorModelWriter(this)
}

object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {

private[ChiSqSelectorModel]
class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter {

private case class Data(selectedFeatures: Seq[Int])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.selectedFeatures.toSeq)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] {

private val className = classOf[ChiSqSelectorModel].getName

override def load(path: String): ChiSqSelectorModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head()
val selectedFeatures = data.getAs[Seq[Int]](0).toArray
val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
val model = new ChiSqSelectorModel(metadata.uid, oldModel)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader

override def load(path: String): ChiSqSelectorModel = super.load(path)
}
59 changes: 55 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -17,13 +17,15 @@

package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol
* PCA trains a model to project vectors to a low-dimensional space using PCA.
*/
@Experimental
class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams {
class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("pca"))

@@ -86,6 +89,11 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
override def copy(extra: ParamMap): PCA = defaultCopy(extra)
}

object PCA extends DefaultParamsReadable[PCA] {

override def load(path: String): PCA = super.load(path)
}

/**
* :: Experimental ::
* Model fitted by [[PCA]].
@@ -94,7 +102,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
class PCAModel private[ml] (
override val uid: String,
pcaModel: feature.PCAModel)
extends Model[PCAModel] with PCAParams {
extends Model[PCAModel] with PCAParams with MLWritable {

import PCAModel._

/** Principal components matrix. Each column is one principal component. */
val pc: DenseMatrix = pcaModel.pc

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -127,4 +140,42 @@ class PCAModel private[ml] (
val copied = new PCAModel(uid, pcaModel)
copyValues(copied, extra).setParent(parent)
}

override def write: MLWriter = new PCAModelWriter(this)
}

object PCAModel extends MLReadable[PCAModel] {

private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {

private case class Data(k: Int, pc: DenseMatrix)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.getK, instance.pc)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class PCAModelReader extends MLReader[PCAModel] {

private val className = classOf[PCAModel].getName

override def load(path: String): PCAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
.select("k", "pc")
.head()
val oldModel = new feature.PCAModel(k, pc)
val model = new PCAModel(metadata.uid, oldModel)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

override def read: MLReader[PCAModel] = new PCAModelReader

override def load(path: String): PCAModel = super.load(path)
}
mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -20,14 +20,16 @@ package org.apache.spark.ml.feature
import java.lang.{Double => JDouble, Integer => JInt}
import java.util.{Map => JMap}

import org.apache.hadoop.fs.Path
Review comment from a Contributor: organize imports
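The comment above presumably refers to Spark's import grouping convention (java, then scala, then third-party such as Hadoop, then org.apache.spark, separated by blank lines). A sketch of how the top of VectorIndexer.scala would look with the new Hadoop import moved into the third-party group; this is an interpretation of the review comment, not part of the patch:

```scala
// Sketch of the expected grouping: java, scala, third-party, org.apache.spark.
import java.lang.{Double => JDouble, Integer => JInt}
import java.util.{Map => JMap}

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
```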
import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.udf
@@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol
*/
@Experimental
class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
with VectorIndexerParams {
with VectorIndexerParams with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("vecIdx"))

@@ -136,7 +138,9 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
}

private object VectorIndexer {
object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {

override def load(path: String): VectorIndexer = super.load(path)

/**
* Helper class for tracking unique values for each feature.
@@ -146,7 +150,7 @@ private object VectorIndexer {
* @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures.
* @param maxCategories This class caps the number of unique values collected at maxCategories.
*/
class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
extends Serializable {

/** featureValueSets[feature index] = set of unique values */
@@ -252,7 +256,9 @@ class VectorIndexerModel private[ml] (
override val uid: String,
val numFeatures: Int,
val categoryMaps: Map[Int, Map[Double, Int]])
extends Model[VectorIndexerModel] with VectorIndexerParams {
extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable {

import VectorIndexerModel._

/** Java-friendly version of [[categoryMaps]] */
def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = {
@@ -408,4 +414,43 @@ class VectorIndexerModel private[ml] (
val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
copyValues(copied, extra).setParent(parent)
}

override def write: MLWriter = new VectorIndexerModelWriter(this)
}

object VectorIndexerModel extends MLReadable[VectorIndexerModel] {

private[VectorIndexerModel]
class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter {

private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]])

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.numFeatures, instance.categoryMaps)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] {

private val className = classOf[VectorIndexerModel].getName

override def load(path: String): VectorIndexerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val Row(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]]) =
sqlContext.read.parquet(dataPath)
.select("numFeatures", "categoryMaps")
.head()
val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader

override def load(path: String): VectorIndexerModel = super.load(path)
}
mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -18,13 +18,17 @@
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}

class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {

test("Test Chi-Square selector") {
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
@@ -58,4 +62,20 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(vec1 ~== vec2 absTol 1e-1)
}
}

test("ChiSqSelector read/write") {
val t = new ChiSqSelector()
.setFeaturesCol("myFeaturesCol")
.setLabelCol("myLabelCol")
.setOutputCol("myOutputCol")
.setNumTopFeatures(2)
testDefaultReadWrite(t)
}

test("ChiSqSelectorModel read/write") {
val oldModel = new feature.ChiSqSelectorModel(Array(1, 3))
val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel)
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.selectedFeatures === instance.selectedFeatures)
}
}
34 changes: 31 additions & 3 deletions mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -19,15 +19,15 @@ package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
import org.apache.spark.sql.Row

class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

test("params") {
ParamsSuite.checkParams(new PCA)
@@ -65,4 +65,32 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
}

test("PCA read/write") {
val t = new PCA()
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setK(3)
testDefaultReadWrite(t)
}

test("PCAModel read/write") {

def checkModelData(model1: PCAModel, model2: PCAModel): Unit = {
assert(model1.k === model2.k)
assert(model1.pc === model2.pc)
}

val data = Seq(
(0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
(1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
(2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
)
val df = sqlContext.createDataFrame(data).toDF("id", "features")
val pca = new PCA().setK(3)
val testParams: Map[String, Any] = Map("k" -> 3, "inputCol" -> "features",
"outputCol" -> "pca_features")

testEstimatorAndModelReadWrite(pca, df, testParams, checkModelData)
}
}
Review comment from the Contributor Author: As with StandardScalerModel, we cannot construct a PCAModel directly by specifying the variable k, so the estimator and the model are tested in one case with testEstimatorAndModelReadWrite.
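For readers unfamiliar with the helper, here is a rough, hand-rolled equivalent of that combined round-trip check, using the `df` and `pca` built in the test above. `tempDir` and the specific assertions are illustrative assumptions, not the actual DefaultReadWriteTest implementation:

```scala
// Sketch only: save/load the estimator and the fitted model, then compare
// the pieces that should survive the round trip.
val estimatorPath = new java.io.File(tempDir, "pca").getPath
pca.write.overwrite().save(estimatorPath)
val pca2 = PCA.load(estimatorPath)
assert(pca2.getK === pca.getK)

// PCAModel has no public constructor taking k and pc, so the model under test
// has to come from fit() rather than being built directly.
val model = pca.fit(df)
val modelPath = new java.io.File(tempDir, "pcaModel").getPath
model.write.overwrite().save(modelPath)
val model2 = PCAModel.load(modelPath)
assert(model2.pc === model.pc)
```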