Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Infer label names automatically
  • Loading branch information
sryza committed May 5, 2015
commit f383250d6342c13b28617d05d0d597f5a6bee814
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
Expand Down Expand Up @@ -47,45 +48,42 @@ class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
new BooleanParam(this, "includeFirst", "include first category")
setDefault(includeFirst -> true)

/**
* The names of the categories. Used to identify them in the attributes of the output column.
* This is a required parameter.
* @group param
*/
final val labelNames: Param[Array[String]] =
new Param[Array[String]](this, "labelNames", "categorical label names")
private var categories: Array[String] = _

/** @group setParam */
def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)

/** @group setParam */
def setLabelNames(attr: NominalAttribute): this.type = set(labelNames, attr.values.get)

/** @group setParam */
override def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
override def setOutputCol(value: String): this.type = set(outputCol, value)

override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = extractParamMap(paramMap)
SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
val outputColName = map(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
require(map.contains(labelNames), "OneHotEncoder missing category names")
val categories = map(labelNames)
val attrValues = (if (map(includeFirst)) categories else categories.drop(1)).toArray
val outputColName = $(outputCol)
require(inputFields.forall(_.name != $(outputCol)),
s"Output column ${$(outputCol)} already exists.")

val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
categories = inputColAttr match {
case nominal: NominalAttribute =>
nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
case _ =>
throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
}

val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}

protected def createTransformFunc(paramMap: ParamMap): (Double) => Vector = {
val map = extractParamMap(paramMap)
val first = map(includeFirst)
val vecLen = if (first) map(labelNames).length else map(labelNames).length - 1
protected override def createTransformFunc(): (Double) => Vector = {
val first = $(includeFirst)
val vecLen = if (first) categories.length else categories.length - 1
val oneValue = Array(1.0)
val emptyValues = Array[Double]()
val emptyIndices = Array[Int]()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

package org.apache.spark.ml.feature

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext

import org.apache.spark.sql.{DataFrame, SQLContext}

import org.scalatest.FunSuite

class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
private var sqlContext: SQLContext = _
Expand All @@ -33,23 +32,19 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
sqlContext = new SQLContext(sc)
}

def stringIndexed(): (DataFrame, NominalAttribute) = {
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
val transformed = indexer.transform(df)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
(transformed, attr)
indexer.transform(df)
}

test("OneHotEncoder includeFirst = true") {
val (transformed, attr) = stringIndexed()
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
.setLabelNames(attr)
.setInputCol("labelIndex")
.setOutputCol("labelVec")
val encoded = encoder.transform(transformed)
Expand All @@ -65,10 +60,9 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
}

test("OneHotEncoder includeFirst = false") {
val (transformed, attr) = stringIndexed()
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
.setIncludeFirst(false)
.setLabelNames(attr)
.setInputCol("labelIndex")
.setOutputCol("labelVec")
val encoded = encoder.transform(transformed)
Expand Down