add read/write to IDF
mengxr committed Nov 18, 2015
commit 60b6a101c0ddb90257d507b86c51d4f24e2b9e5b
71 changes: 67 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -17,11 +17,13 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
@@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Compute the Inverse Document Frequency (IDF) given a collection of documents.
*/
@Experimental
final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase {
final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable {

def this() = this(Identifiable.randomUID("idf"))

@@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
}

override def copy(extra: ParamMap): IDF = defaultCopy(extra)

@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}

@Since("1.6.0")
object IDF extends Readable[IDF] {

@Since("1.6.0")
override def read: Reader[IDF] = new DefaultParamsReader

@Since("1.6.0")
override def load(path: String): IDF = super.load(path)
}
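
With the Writable/Readable wiring above, the IDF estimator itself can now be saved and reloaded; DefaultParamsWriter and DefaultParamsReader round-trip only the uid and param values. A minimal usage sketch (the path and column names are placeholders, not part of this patch):

import org.apache.spark.ml.feature.IDF

val idf = new IDF()
  .setInputCol("tf")
  .setOutputCol("tfidf")
  .setMinDocFreq(2)

// DefaultParamsWriter stores the uid and the explicitly set params as JSON metadata.
idf.write.save("/tmp/idf-estimator")

// DefaultParamsReader rebuilds an IDF with the same uid and param values.
val sameIdf: IDF = IDF.load("/tmp/idf-estimator")
assert(sameIdf.getMinDocFreq == idf.getMinDocFreq)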

/**
@@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
class IDFModel private[ml] (
override val uid: String,
idfModel: feature.IDFModel)
extends Model[IDFModel] with IDFBase {
extends Model[IDFModel] with IDFBase with Writable {

import IDFModel._

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -117,4 +134,50 @@ class IDFModel private[ml] (
val copied = new IDFModel(uid, idfModel)
copyValues(copied, extra).setParent(parent)
}

/** Returns the IDF vector. */
@Since("1.6.0")
def idf: Vector = idfModel.idf

@Since("1.6.0")
override def write: Writer = new IDFModelWriter(this)
}

@Since("1.6.0")
object IDFModel extends Readable[IDFModel] {

private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer {

private case class Data(idf: Vector)

override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.idf)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

private class IDFModelReader extends Reader[IDFModel] {

    private val className = "org.apache.spark.ml.feature.IDFModel"

override def load(path: String): IDFModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath)
.select("idf")
.head()
val idf = data.getAs[Vector](0)
val model = new IDFModel(metadata.uid, new feature.IDFModel(idf))
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}

@Since("1.6.0")
override def read: Reader[IDFModel] = new IDFModelReader

@Since("1.6.0")
override def load(path: String): IDFModel = super.load(path)
}
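
Taken together, the writer and reader give IDFModel the standard two-part layout: the params go to <path>/metadata as a single JSON line, and the idf vector goes to <path>/data as a one-row Parquet file; the repartition(1) keeps that tiny dataset from being scattered across many near-empty part files. A rough end-to-end sketch, assuming a local Spark setup and placeholder paths and column names:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.{IDF, IDFModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("idf-io-sketch").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)

// A tiny term-frequency DataFrame; "tf" is an illustrative column name.
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.0, 2.0)),
  (1, Vectors.dense(0.0, 1.0, 3.0)),
  (2, Vectors.dense(4.0, 0.0, 0.0))
)).toDF("id", "tf")

val model: IDFModel = new IDF().setInputCol("tf").setOutputCol("tfidf").fit(df)

// Writes <path>/metadata (JSON params) and <path>/data (one-row Parquet with the "idf" column).
model.write.save("/tmp/idf-model")

// Reloading restores the uid, the params, and the underlying idf vector.
val restored = IDFModel.load("/tmp/idf-model")
assert(restored.idf == model.idf)
restored.transform(df).select("tfidf").show()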
2 changes: 1 addition & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -210,7 +210,7 @@ object StringIndexerModel extends Readable[StringIndexerModel] {
DefaultParamsWriter.saveMetadata(instance, path, sc)
val data = Data(instance.labels)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).write.parquet(dataPath)
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}

19 changes: 18 additions & 1 deletion mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row

class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
dataSet.map {
@@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
}

test("IDF read/write") {
val t = new IDF()
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
.setMinDocFreq(5)
testDefaultReadWrite(t)
}

test("IDFModel read/write") {
val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0)))
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.idf === instance.idf)
}
}
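
The tests above drive the round trip through testDefaultReadWrite; the saved artifacts can also be inspected directly with the plain SQL readers. A small sketch, reusing the sqlContext and the /tmp/idf-model path from the earlier sketch (again illustrative, not part of this patch):

import org.apache.spark.mllib.linalg.Vector

// metadata/ holds one JSON line with fields such as "class", "uid", and "paramMap".
val metadata = sqlContext.read.json("/tmp/idf-model/metadata").head()
println(metadata.getAs[String]("class"))

// data/ holds a single-row Parquet file with one VectorUDT column named "idf".
val idfVector = sqlContext.read.parquet("/tmp/idf-model/data")
  .select("idf")
  .head()
  .getAs[Vector](0)
println(idfVector)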