-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-2515][mllib] Chi Squared test #1733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
ff17423
WIP
dorx 6598379
API and code structure.
dorx 706d436
Added API for RDD[Vector]
dorx 3d61582
input names
dorx e6b83f3
reviewer comments
dorx 4e4e361
WIP
dorx 50703a5
merge master
dorx bc7eb2e
unit passed; still need docs and some refactoring
dorx 5686082
facelift
dorx d64c2fb
Merge branch 'master' into chisquare
dorx 7eea80b
WIP
dorx e90d90a
Merge branch 'master' into chisquare
dorx c39eeb5
units passed with updated API
dorx 80d03e2
Reviewer comments.
dorx 7dde711
ChiSqTestResult renaming and changed to Class
dorx e95e485
reviewer comments.
dorx d286783
Merge branch 'master' into chisquare
dorx cafb3a7
fixed p-value for extreme case.
dorx File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
221 changes: 221 additions & 0 deletions
221
mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,221 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.stat.test | ||
|
|
||
| import breeze.linalg.{DenseMatrix => BDM} | ||
| import cern.jet.stat.Probability.chiSquareComplemented | ||
|
|
||
| import org.apache.spark.Logging | ||
| import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} | ||
| import org.apache.spark.mllib.regression.LabeledPoint | ||
| import org.apache.spark.rdd.RDD | ||
|
|
||
| /** | ||
| * Conduct the chi-squared test for the input RDDs using the specified method. | ||
| * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted | ||
| * on an input of type `Matrix` in which independence between columns is assessed. | ||
| * We also provide a method for computing the chi-squared statistic between each feature and the | ||
| * label for an input `RDD[LabeledPoint]`, return an `Array[ChiSquaredTestResult]` of size = | ||
| * number of features in the inpuy RDD. | ||
| * | ||
| * Supported methods for goodness of fit: `pearson` (default) | ||
| * Supported methods for independence: `pearson` (default) | ||
| * | ||
| * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test | ||
| */ | ||
| private[stat] object ChiSqTest extends Logging { | ||
|
|
||
| /** | ||
| * @param name String name for the method. | ||
| * @param chiSqFunc Function for computing the statistic given the observed and expected counts. | ||
| */ | ||
| case class Method(name: String, chiSqFunc: (Double, Double) => Double) | ||
|
|
||
| // Pearson's chi-squared test: http://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test | ||
| val PEARSON = new Method("pearson", (observed: Double, expected: Double) => { | ||
| val dev = observed - expected | ||
| dev * dev / expected | ||
| }) | ||
|
|
||
| // Null hypothesis for the two different types of chi-squared tests to be included in the result. | ||
| object NullHypothesis extends Enumeration { | ||
| type NullHypothesis = Value | ||
| val goodnessOfFit = Value("observed follows the same distribution as expected.") | ||
| val independence = Value("observations in each column are statistically independent.") | ||
| } | ||
|
|
||
| // Method identification based on input methodName string | ||
| private def methodFromString(methodName: String): Method = { | ||
| methodName match { | ||
| case PEARSON.name => PEARSON | ||
| case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Conduct Pearson's independence test for each feature against the label across the input RDD. | ||
| * The contingency table is constructed from the raw (feature, label) pairs and used to conduct | ||
| * the independence test. | ||
| * Returns an array containing the ChiSquaredTestResult for every feature against the label. | ||
| */ | ||
| def chiSquaredFeatures(data: RDD[LabeledPoint], | ||
| methodName: String = PEARSON.name): Array[ChiSqTestResult] = { | ||
| val numCols = data.first().features.size | ||
| val results = new Array[ChiSqTestResult](numCols) | ||
| var labels: Map[Double, Int] = null | ||
| // At most 100 columns at a time | ||
| val batchSize = 100 | ||
| var batch = 0 | ||
| while (batch * batchSize < numCols) { | ||
| // The following block of code can be cleaned up and made public as | ||
| // chiSquared(data: RDD[(V1, V2)]) | ||
| val startCol = batch * batchSize | ||
| val endCol = startCol + math.min(batchSize, numCols - startCol) | ||
| val pairCounts = data.flatMap { p => | ||
| // assume dense vectors | ||
| p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) => | ||
| (col, feature, p.label) | ||
| } | ||
| }.countByValue() | ||
|
|
||
| if (labels == null) { | ||
| // Do this only once for the first column since labels are invariant across features. | ||
| labels = | ||
| pairCounts.keys.filter(_._1 == startCol).map(_._3).toArray.distinct.zipWithIndex.toMap | ||
| } | ||
| val numLabels = labels.size | ||
| pairCounts.keys.groupBy(_._1).map { case (col, keys) => | ||
| val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap | ||
| val numRows = features.size | ||
| val contingency = new BDM(numRows, numLabels, new Array[Double](numRows * numLabels)) | ||
| keys.foreach { case (_, feature, label) => | ||
| val i = features(feature) | ||
| val j = labels(label) | ||
| contingency(i, j) += pairCounts((col, feature, label)) | ||
| } | ||
| results(col) = chiSquaredMatrix(Matrices.fromBreeze(contingency), methodName) | ||
| } | ||
| batch += 1 | ||
| } | ||
| results | ||
| } | ||
|
|
||
| /* | ||
| * Pearon's goodness of fit test on the input observed and expected counts/relative frequencies. | ||
| * Uniform distribution is assumed when `expected` is not passed in. | ||
| */ | ||
| def chiSquared(observed: Vector, | ||
| expected: Vector = Vectors.dense(Array[Double]()), | ||
| methodName: String = PEARSON.name): ChiSqTestResult = { | ||
|
|
||
| // Validate input arguments | ||
| val method = methodFromString(methodName) | ||
| if (expected.size != 0 && observed.size != expected.size) { | ||
| throw new IllegalArgumentException("observed and expected must be of the same size.") | ||
| } | ||
| val size = observed.size | ||
| if (size > 1000) { | ||
| logWarning("Chi-squared approximation may not be accurate due to low expected frequencies " | ||
| + s" as a result of a large number of categories: $size.") | ||
| } | ||
| val obsArr = observed.toArray | ||
| val expArr = if (expected.size == 0) Array.tabulate(size)(_ => 1.0 / size) else expected.toArray | ||
| if (!obsArr.forall(_ >= 0.0)) { | ||
| throw new IllegalArgumentException("Negative entries disallowed in the observed vector.") | ||
| } | ||
| if (expected.size != 0 && ! expArr.forall(_ >= 0.0)) { | ||
| throw new IllegalArgumentException("Negative entries disallowed in the expected vector.") | ||
| } | ||
|
|
||
| // Determine the scaling factor for expected | ||
| val obsSum = obsArr.sum | ||
| val expSum = if (expected.size == 0.0) 1.0 else expArr.sum | ||
| val scale = if (math.abs(obsSum - expSum) < 1e-7) 1.0 else obsSum / expSum | ||
|
|
||
| // compute chi-squared statistic | ||
| val statistic = obsArr.zip(expArr).foldLeft(0.0) { case (stat, (obs, exp)) => | ||
| if (exp == 0.0) { | ||
| if (obs == 0.0) { | ||
| throw new IllegalArgumentException("Chi-squared statistic undefined for input vectors due" | ||
| + " to 0.0 values in both observed and expected.") | ||
| } else { | ||
| return new ChiSqTestResult(0.0, size - 1, Double.PositiveInfinity, PEARSON.name, | ||
| NullHypothesis.goodnessOfFit.toString) | ||
| } | ||
| } | ||
| if (scale == 1.0) { | ||
| stat + method.chiSqFunc(obs, exp) | ||
| } else { | ||
| stat + method.chiSqFunc(obs, exp * scale) | ||
| } | ||
| } | ||
| val df = size - 1 | ||
| val pValue = chiSquareComplemented(df, statistic) | ||
| new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString) | ||
| } | ||
|
|
||
| /* | ||
| * Pearon's independence test on the input contingency matrix. | ||
| * TODO: optimize for SparseMatrix when it becomes supported. | ||
| */ | ||
| def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { | ||
| val method = methodFromString(methodName) | ||
| val numRows = counts.numRows | ||
| val numCols = counts.numCols | ||
|
|
||
| // get row and column sums | ||
| val colSums = new Array[Double](numCols) | ||
| val rowSums = new Array[Double](numRows) | ||
| val colMajorArr = counts.toArray | ||
| var i = 0 | ||
| while (i < colMajorArr.size) { | ||
| val elem = colMajorArr(i) | ||
| if (elem < 0.0) { | ||
| throw new IllegalArgumentException("Contingency table cannot contain negative entries.") | ||
| } | ||
| colSums(i / numRows) += elem | ||
| rowSums(i % numRows) += elem | ||
| i += 1 | ||
| } | ||
| val total = colSums.sum | ||
|
|
||
| // second pass to collect statistic | ||
| var statistic = 0.0 | ||
| var j = 0 | ||
| while (j < colMajorArr.size) { | ||
| val col = j / numRows | ||
| val colSum = colSums(col) | ||
| if (colSum == 0.0) { | ||
| throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to" | ||
| + s"0 sum in column [$col].") | ||
| } | ||
| val row = j % numRows | ||
| val rowSum = rowSums(row) | ||
| if (rowSum == 0.0) { | ||
| throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to" | ||
| + s"0 sum in row [$row].") | ||
| } | ||
| val expected = colSum * rowSum / total | ||
| statistic += method.chiSqFunc(colMajorArr(j), expected) | ||
| j += 1 | ||
| } | ||
| val df = (numCols - 1) * (numRows - 1) | ||
| val pValue = chiSquareComplemented(df, statistic) | ||
| new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) | ||
| } | ||
| } |
88 changes: 88 additions & 0 deletions
88
mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.stat.test | ||
|
|
||
| import org.apache.spark.annotation.Experimental | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Trait for hypothesis test results. | ||
| * @tparam DF Return type of `degreesOfFreedom`. | ||
| */ | ||
| @Experimental | ||
| trait TestResult[DF] { | ||
|
|
||
| /** | ||
| * The probability of obtaining a test statistic result at least as extreme as the one that was | ||
| * actually observed, assuming that the null hypothesis is true. | ||
| */ | ||
| def pValue: Double | ||
|
|
||
| /** | ||
| * Returns the degree(s) of freedom of the hypothesis test. | ||
| * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. | ||
| */ | ||
| def degreesOfFreedom: DF | ||
|
|
||
| /** | ||
| * Test statistic. | ||
| */ | ||
| def statistic: Double | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto: doc |
||
|
|
||
| /** | ||
| * String explaining the hypothesis test result. | ||
| * Specific classes implementing this trait should override this method to output test-specific | ||
| * information. | ||
| */ | ||
| override def toString: String = { | ||
|
|
||
| // String explaining what the p-value indicates. | ||
| val pValueExplain = if (pValue <= 0.01) { | ||
| "Very strong presumption against null hypothesis." | ||
| } else if (0.01 < pValue && pValue <= 0.05) { | ||
| "Strong presumption against null hypothesis." | ||
| } else if (0.05 < pValue && pValue <= 0.01) { | ||
| "Low presumption against null hypothesis." | ||
| } else { | ||
| "No presumption against null hypothesis." | ||
| } | ||
|
|
||
| s"degrees of freedom = ${degreesOfFreedom.toString} \n" + | ||
| s"statistic = $statistic \n" + | ||
| s"pValue = $pValue \n" + pValueExplain | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Object containing the test results for the chi squared hypothesis test. | ||
| */ | ||
| @Experimental | ||
| class ChiSqTestResult(override val pValue: Double, | ||
| override val degreesOfFreedom: Int, | ||
| override val statistic: Double, | ||
| val method: String, | ||
| val nullHypothesis: String) extends TestResult[Int] { | ||
|
|
||
| override def toString: String = { | ||
| "Chi squared test summary: \n" + | ||
| s"method: $method \n" + | ||
| s"null hypothesis: $nullHypothesis \n" + | ||
| super.toString | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
documentation