
Commit 8fc2ab2

fix col indexing bug and add a check for number of distinct values
1 parent 2fc8aca commit 8fc2ab2

3 files changed: 59 additions, 17 deletions

mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ object Statistics {
    * :: Experimental ::
    * Conduct Pearson's independence test for every feature against the label across the input RDD.
    * For each feature, the (feature, label) pairs are converted into a contingency matrix for which
-   * the chi-squared statistic is computed.
+   * the chi-squared statistic is computed. All label and feature values must be categorical.
    *
    * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features.
    *             Real-valued features will be treated as categorical for each distinct value.
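
For reference, a minimal usage sketch of the per-feature test this doc comment describes (not part of the commit; it assumes an existing SparkContext `sc`, and the data values are made up for illustration):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.stat.Statistics

    // Labels and feature values are categorical codes.
    val labeled = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 1.0))))

    // One ChiSqTestResult per feature column, each tested against the label.
    val results = Statistics.chiSqTest(labeled)
    results.zipWithIndex.foreach { case (result, col) =>
      println(s"column $col: p-value = ${result.pValue}")
    }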

mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala

Lines changed: 30 additions & 7 deletions
@@ -20,11 +20,13 @@ package org.apache.spark.mllib.stat.test
 import breeze.linalg.{DenseMatrix => BDM}
 import cern.jet.stat.Probability.chiSquareComplemented
 
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
 import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 
+import scala.collection.mutable
+
 /**
  * Conduct the chi-squared test for the input RDDs using the specified method.
  * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted
@@ -75,21 +77,42 @@ private[stat] object ChiSqTest extends Logging {
    */
   def chiSquaredFeatures(data: RDD[LabeledPoint],
       methodName: String = PEARSON.name): Array[ChiSqTestResult] = {
+    val maxCategories = 10000
     val numCols = data.first().features.size
     val results = new Array[ChiSqTestResult](numCols)
     var labels: Map[Double, Int] = null
-    // At most 100 columns at a time
-    val batchSize = 100
+    // at most 1000 columns at a time
+    val batchSize = 1000
     var batch = 0
     while (batch * batchSize < numCols) {
       // The following block of code can be cleaned up and made public as
       // chiSquared(data: RDD[(V1, V2)])
       val startCol = batch * batchSize
       val endCol = startCol + math.min(batchSize, numCols - startCol)
-      val pairCounts = data.flatMap { p =>
-        // assume dense vectors
-        p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) =>
-          (col, feature, p.label)
+      val pairCounts = data.mapPartitions { iter =>
+        val distinctLabels = mutable.HashSet.empty[Double]
+        val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] =
+          Map((startCol until endCol).map(col => (col, mutable.HashSet.empty[Double])): _*)
+        var i = 1
+        iter.flatMap { case LabeledPoint(label, features) =>
+          if (i % 1000 == 0) {
+            if (distinctLabels.size > maxCategories) {
+              throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+                + s"found more than $maxCategories distinct label values.")
+            }
+            allDistinctFeatures.foreach { case (col, distinctFeatures) =>
+              if (distinctFeatures.size > maxCategories) {
+                throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+                  + s"found more than $maxCategories distinct values in column $col.")
+              }
+            }
+          }
+          i += 1
+          distinctLabels += label
+          features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) =>
+            allDistinctFeatures(col) += feature
+            (col, feature, label)
+          }
         }
       }.countByValue()
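
The column-indexing fix above comes down to the order of `zipWithIndex` and `slice`: the old code re-numbered each batch's columns from 0, so every batch after the first reported results against the wrong column indices. A small plain-Scala sketch of the difference (the five-element array and batch bounds are made up for illustration):

    val features = Array(10.0, 11.0, 12.0, 13.0, 14.0)
    val (startCol, endCol) = (3, 5)

    // Old order: slice first, then index -- column indices restart at 0.
    val buggy = features.slice(startCol, endCol).zipWithIndex
    // Array((13.0, 0), (14.0, 1))

    // New order: index first, then slice -- original column positions survive.
    val fixed = features.zipWithIndex.slice(startCol, endCol)
    // Array((13.0, 3), (14.0, 4))

The committed code also goes through `.view` so the indexing and slicing do not materialize an intermediate array per record. The other half of the change is the `maxCategories` guard: every 1000 records per partition, the running sets of distinct label and feature values are checked, and the job aborts with a SparkException once any column or the label exceeds 10000 distinct values.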

mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala

Lines changed: 28 additions & 9 deletions
@@ -17,8 +17,11 @@
 
 package org.apache.spark.mllib.stat
 
+import java.util.Random
+
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.test.ChiSqTest
@@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
     // labels: 1.0 (2 / 6), 0.0 (4 / 6)
     // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6)
     // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6)
-    val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
-      new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
-      new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
-      new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
-      new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
-      new LabeledPoint(1.0, Vectors.dense(3.5, 40.0)))
+    val data = Seq(
+      LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
+      LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
+      LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
+      LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
+      LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
+      LabeledPoint(1.0, Vectors.dense(3.5, 40.0)))
     for (numParts <- List(2, 4, 6, 8)) {
       val chi = Statistics.chiSqTest(sc.parallelize(data, numParts))
       val feature1 = chi(0)
@@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
     }
 
     // Test that the right number of results is returned
-    val numCols = 321
-    val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))),
-      new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0)))))
+    val numCols = 1001
+    val sparseData = Array(
+      new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))),
+      new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0)))))
     val chi = Statistics.chiSqTest(sc.parallelize(sparseData))
     assert(chi.size === numCols)
+    assert(chi(1000) != null) // SPARK-3087
+
+    // Detect continous features or labels
+    val random = new Random(11L)
+    val continuousLabel =
+      Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2))))
+    intercept[SparkException] {
+      Statistics.chiSqTest(sc.parallelize(continuousLabel, 2))
+    }
+    val continuousFeature =
+      Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble())))
+    intercept[SparkException] {
+      Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
+    }
   }
 }
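
With the new guard in place, callers whose labels or features really are continuous need to bucket them before running the test. A hedged sketch of one way to do that (not part of this commit; `rawData` is a hypothetical `RDD[LabeledPoint]` and the 0.1 bucket width is arbitrary):

    // Map each continuous feature value to a categorical bucket code,
    // e.g. 0.0 <= v < 0.1 -> 0.0, 0.1 <= v < 0.2 -> 1.0, and so on.
    val discretized = rawData.map { case LabeledPoint(label, features) =>
      val bucketed = features.toArray.map(v => math.floor(v / 0.1))
      LabeledPoint(label, Vectors.dense(bucketed))
    }
    val results = Statistics.chiSqTest(discretized)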
