-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-2515][mllib] Chi Squared test #1733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
ff17423
6598379
706d436
3d61582
e6b83f3
4e4e361
50703a5
bc7eb2e
5686082
d64c2fb
7eea80b
e90d90a
c39eeb5
80d03e2
7dde711
e95e485
d286783
cafb3a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.stat | |
| import org.apache.spark.annotation.Experimental | ||
| import org.apache.spark.mllib.linalg.{Matrix, Vector} | ||
| import org.apache.spark.mllib.stat.correlation.Correlations | ||
| import org.apache.spark.mllib.stat.test.{ChiSquaredTest, ChiSquaredTestResult} | ||
| import org.apache.spark.rdd.RDD | ||
|
|
||
| /** | ||
|
|
@@ -89,4 +90,76 @@ object Statistics { | |
| */ | ||
| @Experimental | ||
| def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Conduct the Chi-squared goodness of fit test of the observed data against the | ||
| * expected distribution. | ||
| * | ||
| * Note: the two input RDDs need to have the same number of partitions and the same number of | ||
| * elements in each partition. | ||
| * | ||
| * @param observed RDD[Double] containing the observed counts. | ||
| * @param expected RDD[Double] containing the expected counts. If the observed total differs from | ||
| * the expected total, this RDD is rescaled to sum up to the observed total. | ||
| * @param method String specifying the method to use for the Chi-squared test. | ||
| * Supported: `pearson` (default) | ||
| * @return ChiSquaredTestResult object containing the test statistic, degrees of freedom, | ||
| * p-value, the method used, and the null hypothesis. | ||
| */ | ||
| @Experimental | ||
| def chiSquared(observed: RDD[Double], | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we call it
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| expected: RDD[Double], | ||
| method: String): ChiSquaredTestResult = { | ||
| ChiSquaredTest.chiSquared(observed, expected, method) | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Conduct the Chi-squared goodness of fit test of the observed data against the | ||
| * expected distribution. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mention minor: I think it should be fine to remove the rest of the doc and point users to the method with the full set of parameters, so we only maintain one copy. |
||
| * | ||
| * Note: the two input RDDs need to have the same number of partitions and the same number of | ||
| * elements in each partition. | ||
| * | ||
| * @param observed RDD[Double] containing the observed counts. | ||
| * @param expected RDD[Double] containing the expected counts. If the observed total differs from | ||
| * the expected total, this RDD is rescaled to sum up to the observed total. | ||
| * @return ChiSquaredTestResult object containing the test statistic, degrees of freedom, | ||
| * p-value, the method used, and the null hypothesis. | ||
| */ | ||
| @Experimental | ||
| def chiSquared(observed: RDD[Double], expected: RDD[Double]): ChiSquaredTestResult = { | ||
| ChiSquaredTest.chiSquared(observed, expected) | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Conduct the Chi-squared independence test between the columns in the input matrix. | ||
| * | ||
| * @param counts RDD[Vector] containing observations with rows representing categories and columns | ||
| * representing separate trials for which independence between trials is assessed. | ||
| * @param method String specifying the method to use for the Chi-squared test. | ||
| * Supported: `pearson` (default) | ||
| * @return ChiSquaredTestResult object containing the test statistic, degrees of freedom, | ||
| * p-value, the method used, and the null hypothesis. | ||
| */ | ||
| @Experimental | ||
| def chiSquared(counts: RDD[Vector], method: String): ChiSquaredTestResult = { | ||
| ChiSquaredTest.chiSquaredMatrix(counts, method) | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Conduct the Chi-squared independence test between the columns in the input matrix. | ||
| * | ||
| * @param counts RDD[Vector] containing observations with rows representing categories and columns | ||
| * representing separate trials for which independence between trials is assessed. | ||
| * @return ChiSquaredTestResult object containing the test statistic, degrees of freedom, | ||
| * p-value, the method used, and the null hypothesis. | ||
| */ | ||
| @Experimental | ||
| def chiSquared(counts: RDD[Vector]): ChiSquaredTestResult = { | ||
| ChiSquaredTest.chiSquaredMatrix(counts) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,170 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.stat.test | ||
|
|
||
| import cern.jet.stat.Probability.chiSquareComplemented | ||
|
|
||
| import org.apache.spark.Logging | ||
| import org.apache.spark.SparkContext._ | ||
| import org.apache.spark.mllib.linalg.Vector | ||
| import org.apache.spark.rdd.RDD | ||
|
|
||
| /** | ||
| * Conduct the Chi-squared test for the input RDDs using the specified method. | ||
|
||
| * Goodness-of-fit test is conducted on two RDD[Double]s, whereas test of independence is conducted | ||
| * on an input of type RDD[Vector] in which independence between columns is assessed. | ||
| * | ||
| * Supported methods for goodness of fit: `pearson` (default) | ||
| * Supported methods for independence: `pearson` (default) | ||
| * | ||
| * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test | ||
| * More information on Pearson's chi-squared test: | ||
| * http://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test | ||
| * | ||
| */ | ||
| private[stat] object ChiSquaredTest extends Logging { | ||
|
||
|
|
||
| val PEARSON = "pearson" | ||
|
|
||
| object NullHypothesis extends Enumeration { | ||
| type NullHypothesis = Value | ||
| val goodnessOfFit = Value("observed follows the same distribution as expected.") | ||
| val independence = Value("observations in each column are statistically independent.") | ||
| } | ||
|
|
||
| val zeroExpectedError = new IllegalArgumentException("Chi square statistic cannot be computed" | ||
| + " for input RDD due to nonpositive entries in the expected contingency table.") | ||
|
|
||
| // delegator method for goodness of fit test | ||
| def chiSquared(observed: RDD[Double], | ||
| expected: RDD[Double], | ||
| method: String = PEARSON): ChiSquaredTestResult = { | ||
| method match { | ||
| case PEARSON => chiSquaredPearson(observed, expected) | ||
| case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.") | ||
| } | ||
| } | ||
|
|
||
| // delegator method for independence test | ||
| def chiSquaredMatrix(counts: RDD[Vector], method: String = PEARSON): ChiSquaredTestResult = { | ||
| method match { | ||
| // Yates' correction doesn't really apply here | ||
| case PEARSON => chiSquaredPearson(counts) | ||
| case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.") | ||
| } | ||
| } | ||
|
|
||
| // Equation for computing Pearson's chi-squared statistic | ||
| private def pearson = (observed: Double, expected: Double) => { | ||
| val dev = observed - expected | ||
| dev * dev / expected | ||
| } | ||
|
|
||
| /* | ||
| * Pearson's goodness of fit test. This can be easily made abstract to support other methods. | ||
| * Makes two passes over both input RDDs. | ||
| */ | ||
| private def chiSquaredPearson(observed: RDD[Double], | ||
| expected: RDD[Double]): ChiSquaredTestResult = { | ||
|
|
||
| // compute the scaling factor and count for the input RDDs and check positivity in one pass | ||
| val observedStats = observed.stats() | ||
| if (observedStats.min < 0.0) { | ||
| throw new IllegalArgumentException("Values in observed must be nonnegative.") | ||
| } | ||
| val expectedStats = expected.stats() | ||
| if (expectedStats.min <= 0.0) { | ||
| throw new IllegalArgumentException("Values in expected must be positive.") | ||
| } | ||
| if (observedStats.count != expectedStats.count) { | ||
| throw new IllegalArgumentException("observed and expected must be of the same size.") | ||
| } | ||
|
|
||
| val expScaled = if (math.abs(observedStats.sum - expectedStats.sum) < 1e-7) { | ||
| // No scaling needed since both RDDs have the same total | ||
| expected | ||
| } else { | ||
| expected.map(_ * observedStats.sum / expectedStats.sum) | ||
| } | ||
|
|
||
| // Second pass to compute chi-squared statistic | ||
| val statistic = observed.zip(expScaled).aggregate(0.0)({ case (sum, (obs, exp)) => { | ||
| sum + pearson(obs, exp) | ||
| }}, _ + _) | ||
| val df = observedStats.count - 1 | ||
| val pValue = chiSquareComplemented(df, statistic) | ||
| new ChiSquaredTestResult(pValue, Array(df), statistic, PEARSON, | ||
| NullHypothesis.goodnessOfFit.toString) | ||
| } | ||
|
|
||
| /* | ||
| * Pearson's independence test. This can be easily made abstract to support other methods. | ||
| * Makes two passes over the input RDD. | ||
| */ | ||
| private def chiSquaredPearson(counts: RDD[Vector]): ChiSquaredTestResult = { | ||
|
|
||
| val numCols = counts.first.size | ||
|
|
||
| // first pass for collecting column sums | ||
| case class SumNCount(colSums: Array[Double], numRows: Long) | ||
|
|
||
| val result = counts.aggregate(new SumNCount(new Array[Double](numCols), 0L))( | ||
| (sumNCount, vector) => { | ||
| val arr = vector.toArray | ||
| // check that the counts are all non-negative and finite in this pass | ||
| if (!arr.forall(i => !i.isNaN && !i.isInfinite && i >= 0.0)) { | ||
| throw new IllegalArgumentException("Values in the input RDD must be nonnegative.") | ||
| } | ||
| new SumNCount((sumNCount.colSums, arr).zipped.map(_ + _), sumNCount.numRows + 1) | ||
| }, (sums1, sums2) => { | ||
| new SumNCount((sums1.colSums, sums2.colSums).zipped.map(_ + _), | ||
| sums1.numRows + sums2.numRows) | ||
| }) | ||
|
|
||
| val colSums = result.colSums | ||
| if (!colSums.forall(_ > 0.0)) { | ||
| throw zeroExpectedError | ||
| } | ||
| val total = colSums.sum | ||
|
|
||
| // Second pass to compute chi-squared statistic | ||
| val statistic = counts.aggregate(0.0)(rowStatistic(colSums, total, pearson), _ + _) | ||
| val df = (numCols - 1) * (result.numRows - 1) | ||
| val pValue = chiSquareComplemented(df, statistic) | ||
| new ChiSquaredTestResult(pValue, Array(df), statistic, PEARSON, | ||
| NullHypothesis.independence.toString) | ||
| } | ||
|
|
||
| // returns function to be used as seqOp in the aggregate operation to collect statistic | ||
| private def rowStatistic(colSums: Array[Double], | ||
| total: Double, | ||
| chiSquared: (Double, Double) => Double) = { | ||
| (statistic: Double, vector: Vector) => { | ||
| val arr = vector.toArray | ||
| val rowSum = arr.sum | ||
| if (rowSum == 0.0) { // rowSum >= 0.0 as ensured by the nonnegative check | ||
| throw zeroExpectedError | ||
| } | ||
| (arr, colSums).zipped.foldLeft(statistic) { case (stat, (observed, colSum)) => | ||
| val expected = rowSum * colSum / total | ||
| val r = stat + chiSquared(observed, expected) | ||
| r | ||
| } | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.stat.test | ||
|
|
||
| import org.apache.spark.annotation.Experimental | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Trait for hypothesis test results. | ||
| */ | ||
| @Experimental | ||
| trait TestResult { | ||
|
|
||
| def pValue: Double | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. documentation |
||
|
|
||
| def degreesOfFreedom: Array[Long] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto: doc
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| def statistic: Double | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto: doc |
||
|
|
||
| /** | ||
| * String explaining the hypothesis test result. | ||
| * Specific classes implementing this trait should override this method to output test-specific | ||
| * information. | ||
| */ | ||
| override def toString: String = { | ||
|
|
||
| // Translate the p-value into a plain-language strength-of-evidence statement. | ||
| // Bands: p <= 0.01, 0.01 < p <= 0.05, 0.05 < p <= 0.1, and everything else. | ||
| val pValueExplain = if (pValue <= 0.01) { | ||
| "Very strong presumption against null hypothesis." | ||
| } else if (0.01 < pValue && pValue <= 0.05) { | ||
| "Strong presumption against null hypothesis." | ||
| } else if (0.05 < pValue && pValue <= 0.1) { | ||
| "Low presumption against null hypothesis." | ||
| } else { | ||
| "No presumption against null hypothesis." | ||
| } | ||
|
|
||
| s"degrees of freedom = ${degreesOfFreedom.mkString} \n" + | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| s"statistic = $statistic \n" + | ||
| s"pValue = $pValue \n" + pValueExplain | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Object containing the test results for the chi squared hypothesis test. | ||
| */ | ||
| @Experimental | ||
| case class ChiSquaredTestResult(override val pValue: Double, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it need to be a case class? Scala compiler will add many methods to a case class and make it very hard to extend.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Case class is a logical choice here since it's essentially an immutable object holding a bunch of invariant fields and doesn't do any stateful computations inside of the class. Is there development plan for extending this classes in the future?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No case class features are used, especially pattern matching. This case class will extend
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw, shall we rename it to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whether correction is used or not can actually be reflected in the method name ( |
||
| override val degreesOfFreedom: Array[Long], | ||
| override val statistic: Double, | ||
| val method: String, | ||
| val nullHypothesis: String) extends TestResult { | ||
|
|
||
| override def toString: String = { | ||
| "Chi squared test summary: \n" + | ||
| s"method: $method \n" + | ||
| s"null hypothesis: $nullHypothesis \n" + | ||
| super.toString | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chi->chi