add read/write to StringIndexer
mengxr committed Nov 18, 2015
commit 7286bbb56f8c58ff9428ab07cbeb4f7441ad147a
StringIndexer.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model, Transformer}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
@@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
  */
 @Experimental
 class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
-  with StringIndexerBase {
+  with StringIndexerBase with Writable {
 
   def this() = this(Identifiable.randomUID("strIdx"))

@@ -92,6 +93,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
   }
 
   override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
 }
+
+@Since("1.6.0")
+object StringIndexer extends Readable[StringIndexer] {
+
+  @Since("1.6.0")
+  override def read: Reader[StringIndexer] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): StringIndexer = super.load(path)
+}
 
 /**
@@ -107,7 +121,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
 @Experimental
 class StringIndexerModel (
     override val uid: String,
-    val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
+    val labels: Array[String])
+  extends Model[StringIndexerModel] with StringIndexerBase with Writable {
+
+  import StringIndexerModel._
 
   def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)

@@ -176,6 +193,49 @@ class StringIndexerModel (
     val copied = new StringIndexerModel(uid, labels)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: StringIndexModelWriter = new StringIndexModelWriter(this)
 }
+
+@Since("1.6.0")
+object StringIndexerModel extends Readable[StringIndexerModel] {
+
+  private[StringIndexerModel]
+  class StringIndexModelWriter(instance: StringIndexerModel) extends Writer {
+
+    private case class Data(labels: Array[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.labels)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).write.parquet(dataPath)
+    }
+  }
+
+  private class StringIndexerModelReader extends Reader[StringIndexerModel] {
+
+    private val className = "org.apache.spark.ml.feature.StringIndexerModel"
+
+    override def load(path: String): StringIndexerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("labels")
+        .head()
+      val labels = data.getAs[Seq[String]](0).toArray
+      val model = new StringIndexerModel(metadata.uid, labels)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[StringIndexerModel] = new StringIndexerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): StringIndexerModel = super.load(path)
+}
 
 /**
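Taken together, the two Writable mix-ins give both the estimator and the fitted model a save/load round trip: params are written as JSON metadata via DefaultParamsWriter.saveMetadata, and the model's labels array is stored separately as Parquet under data/. A minimal usage sketch against this API (the DataFrame df and the /tmp paths are illustrative assumptions, not part of the patch):

  // Persist the estimator (params only) and restore it with identical params.
  val indexer = new StringIndexer()
    .setInputCol("category")
    .setOutputCol("categoryIndex")
  indexer.write.save("/tmp/string-indexer")          // hypothetical path
  val sameParams = StringIndexer.load("/tmp/string-indexer")

  // Persist a fitted model; the labels travel in the Parquet "data" file.
  val model = indexer.fit(df)                        // df is an assumed input DataFrame
  model.write.save("/tmp/string-indexer-model")
  val restored = StringIndexerModel.load("/tmp/string-indexer-model")
  assert(restored.labels.sameElements(model.labels))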
StringIndexerSuite.scala
@@ -118,6 +118,23 @@ class StringIndexerSuite
     assert(indexerModel.transform(df).eq(df))
   }
 
+  test("StringIndexer read/write") {
+    val t = new StringIndexer()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setHandleInvalid("skip")
+    testDefaultReadWrite(t)
+  }
+
+  test("StringIndexerModel read/write") {
+    val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c"))
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setHandleInvalid("skip")
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.labels === instance.labels)
+  }
+
   test("IndexToString params") {
     val idxToStr = new IndexToString()
     ParamsSuite.checkParams(idxToStr)
@@ -175,7 +192,7 @@ class StringIndexerSuite
     assert(outSchema("output").dataType === StringType)
   }
 
-  test("read/write") {
+  test("IndexToString read/write") {
     val t = new IndexToString()
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
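Both new tests lean on testDefaultReadWrite from the suite's shared DefaultReadWriteTest helper trait, which saves the instance to a temporary directory, loads it back through the companion object, and checks that the params survive before returning the loaded copy. Roughly, the contract it exercises looks like the following hedged sketch (names and structure are illustrative, not the trait's actual code):

  // A sketch of the round-trip contract behind testDefaultReadWrite.
  def roundTrip[T <: Params with Writable](instance: T, load: String => T): T = {
    val dir = java.nio.file.Files.createTempDirectory("read-write-test")
    val path = new java.io.File(dir.toFile, "instance").getPath
    instance.write.save(path)            // params (and any extra data) hit disk
    val loaded = load(path)              // e.g. StringIndexerModel.load
    assert(loaded.uid == instance.uid)   // uid must survive the round trip
    loaded
  }

The StringIndexerModel test then adds one model-specific check on top of that generic contract: the labels array, which is persisted separately from the params, must also come back intact.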