Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ private[feature] trait ChiSqSelectorParams extends Params

/** @group getParam */
def getNumTopFeatures: Int = $(numTopFeatures)

final val percentile = new DoubleParam(this, "percentile",
"Percentile of features that selector will select, ordered by statistics value descending.",
ParamValidators.gtEq(0))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it still okay when percentile is 0?

setDefault(percentile -> 10)

/** @group getParam */
def getPercentile: Double = $(percentile)

final val alpha = new DoubleParam(this, "alpha",
"The highest p-value for features to be kept.",
ParamValidators.gtEq(0))
setDefault(alpha -> 0.05)

/** @group getParam */
def getAlpha: Double = $(alpha)

final val selectorType = new Param[String](this, "selectorType",
"ChiSqSelector Type: KBest, Percentile, Fpr")
setDefault(selectorType -> "KBest")

/** @group getParam */
def getChiSqSelectorType: String = $(selectorType)
}

/**
Expand All @@ -67,9 +90,27 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
@Since("1.6.0")
def this() = this(Identifiable.randomUID("chiSqSelector"))

@Since("2.1.0")
var chiSqSelector: feature.ChiSqSelector = null

/** @group setParam */
@Since("1.6.0")
def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
@Since("2.1.0")
def setNumTopFeatures(value: Int): this.type = {
set(selectorType, "KBest")
set(numTopFeatures, value)
}

@Since("2.1.0")
def setPercentile(value: Double): this.type = {
set(selectorType, "Percentile")
set(percentile, value)
}

@Since("2.1.0")
def setAlpha(value: Double): this.type = {
set(selectorType, "Fpr")
set(alpha, value)
}

/** @group setParam */
@Since("1.6.0")
Expand All @@ -91,8 +132,38 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
case Row(label: Double, features: Vector) =>
OldLabeledPoint(label, OldVectors.fromML(features))
}
val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input)
copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this))
$(selectorType) match {
case "KBest" =>
chiSqSelector = new feature.ChiSqSelector().setNumTopFeatures($(numTopFeatures))
case "Percentile" =>
chiSqSelector = new feature.ChiSqSelector().setPercentile($(percentile))
case "Fpr" =>
chiSqSelector = new feature.ChiSqSelector().setAlpha($(alpha))
case _ => throw new Exception("Unknown ChiSqSelector Type.")
}
val model = chiSqSelector.fit(input)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}

@Since("2.1.0")
def selectKBest(value: Int): ChiSqSelectorModel = {
require(chiSqSelector != null, "ChiSqSelector has not been created.")
val model = chiSqSelector.selectKBest(value)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}

@Since("2.1.0")
def selectPercentile(value: Double): ChiSqSelectorModel = {
require(chiSqSelector != null, "ChiSqSelector has not been created.")
val model = chiSqSelector.selectPercentile(value)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}

@Since("2.1.0")
def selectFpr(value: Double): ChiSqSelectorModel = {
require(chiSqSelector != null, "ChiSqSelector has not been created.")
val model = chiSqSelector.selectFpr(value)
copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
}

@Since("1.6.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,27 @@ import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.mllib.stat.test.ChiSqTestResult
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}

@Since("2.1.0")
object ChiSqSelectorType extends Enumeration {
type SelectorType = Value
val KBest, Percentile, Fpr = Value
}

/**
* Chi Squared selector model.
*
* @param selectedFeatures list of indices to select (filter). Must be ordered asc
* @param selectedFeatures list of indices to select (filter).
*/
@Since("1.3.0")
class ChiSqSelectorModel @Since("1.3.0") (
@Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {

require(isSorted(selectedFeatures), "Array has to be sorted asc")

protected def isSorted(array: Array[Int]): Boolean = {
var i = 1
val len = array.length
Expand All @@ -69,21 +74,22 @@ class ChiSqSelectorModel @Since("1.3.0") (
* Preserves the order of filtered features the same as their indices are stored.
* Might be moved to Vector as .slice
* @param features vector
* @param filterIndices indices of features to filter, must be ordered asc
* @param filterIndices indices of features to filter
*/
private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
val orderedIndices = filterIndices.sorted
features match {
case SparseVector(size, indices, values) =>
val newSize = filterIndices.length
val newSize = orderedIndices.length
val newValues = new ArrayBuilder.ofDouble
val newIndices = new ArrayBuilder.ofInt
var i = 0
var j = 0
var indicesIdx = 0
var filterIndicesIdx = 0
while (i < indices.length && j < filterIndices.length) {
while (i < indices.length && j < orderedIndices.length) {
indicesIdx = indices(i)
filterIndicesIdx = filterIndices(j)
filterIndicesIdx = orderedIndices(j)
if (indicesIdx == filterIndicesIdx) {
newIndices += j
newValues += values(i)
Expand All @@ -101,7 +107,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
Vectors.sparse(newSize, newIndices.result(), newValues.result())
case DenseVector(values) =>
val values = features.toArray
Vectors.dense(filterIndices.map(i => values(i)))
Vectors.dense(orderedIndices.map(i => values(i)))
case other =>
throw new UnsupportedOperationException(
s"Only sparse and dense vectors are supported but got ${other.getClass}.")
Expand Down Expand Up @@ -171,14 +177,47 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {

/**
* Creates a ChiSquared feature selector.
* @param numTopFeatures number of features that selector will select
* (ordered by statistic value descending)
* Note that if the number of features is less than numTopFeatures,
* then this will select all features.
*/
@Since("1.3.0")
class ChiSqSelector @Since("1.3.0") (
@Since("1.3.0") val numTopFeatures: Int) extends Serializable {
@Since("2.1.0")
class ChiSqSelector @Since("2.1.0") () extends Serializable {
private var numTopFeatures: Int = 50
private var percentile: Double = 10.0
private var alpha: Double = 0.05
private var selectorType = ChiSqSelectorType.KBest
private var chiSqTestResult: Array[ChiSqTestResult] = _

@Since("1.3.0")
def this(numTopFeatures: Int) {
this()
this.numTopFeatures = numTopFeatures
}

@Since("2.1.0")
def setNumTopFeatures(value: Int): this.type = {
numTopFeatures = value
selectorType = ChiSqSelectorType.KBest
this
}

@Since("2.1.0")
def setPercentile(value: Double): this.type = {
percentile = value
selectorType = ChiSqSelectorType.Percentile
this
}

@Since("2.1.0")
def setAlpha(value: Double): this.type = {
alpha = value
selectorType = ChiSqSelectorType.Fpr
this
}

@Since("2.1.0")
def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = {
selectorType = value
this
}

/**
* Returns a ChiSquared feature selector.
Expand All @@ -189,11 +228,35 @@ class ChiSqSelector @Since("1.3.0") (
*/
@Since("1.3.0")
def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
val indices = Statistics.chiSqTest(data)
.zipWithIndex.sortBy { case (res, _) => -res.statistic }
.take(numTopFeatures)
.map { case (_, indices) => indices }
.sorted
chiSqTestResult = Statistics.chiSqTest(data)
selectorType match {
case ChiSqSelectorType.KBest => selectKBest(numTopFeatures)
case ChiSqSelectorType.Percentile => selectPercentile(percentile)
case ChiSqSelectorType.Fpr => selectFpr(alpha)
case _ => throw new Exception("Unknown ChiSqSelector Type")
}
}

@Since("2.1.0")
def selectKBest(value: Int): ChiSqSelectorModel = {
val indices = chiSqTestResult.zipWithIndex.sortBy { case (res, _) => -res.statistic }
.take(numTopFeatures)
.map { case (_, indices) => indices }
new ChiSqSelectorModel(indices)
}

@Since("2.1.0")
def selectPercentile(value: Double): ChiSqSelectorModel = {
val indices = chiSqTestResult.zipWithIndex.sortBy { case (res, _) => -res.statistic }
.take((chiSqTestResult.length * percentile / 100).toInt)
.map { case (_, indices) => indices }
new ChiSqSelectorModel(indices)
}

@Since("2.1.0")
def selectFpr(value: Double): ChiSqSelectorModel = {
val indices = chiSqTestResult.zipWithIndex.filter{ case (res, _) => res.pValue < alpha }
.map { case (_, indices) => indices }
new ChiSqSelectorModel(indices)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,23 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
.map(x => (x._1.label, x._1.features, x._2))
.toDF("label", "data", "preFilteredData")

val model = new ChiSqSelector()
val selector = new ChiSqSelector()
.setNumTopFeatures(1)
.setFeaturesCol("data")
.setLabelCol("label")
.setOutputCol("filtered")

model.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach {
case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}

selector.selectPercentile(34).transform(df)
.select("filtered", "preFilteredData").collect().foreach {
case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 ~== vec2 absTol 1e-1)
}

}

test("ChiSqSelector read/write") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(filteredData == preFilteredData)
}

test("ChiSqSelector by FPR transform test (sparse & dense vector)") {
val labeledDiscreteData = sc.parallelize(
Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))),
LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))),
LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))),
LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2)
val preFilteredData =
Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
LabeledPoint(1.0, Vectors.dense(Array(4.0))),
LabeledPoint(2.0, Vectors.dense(Array(9.0))))
val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData)
val filteredData = labeledDiscreteData.map { lp =>
LabeledPoint(lp.label, model.transform(lp.features))
}.collect().toSet
assert(filteredData == preFilteredData)
}

test("model load / save") {
val model = ChiSqSelectorSuite.createModel()
val tempDir = Utils.createTempDir()
Expand Down