73 changes: 73 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.stat
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.stat.correlation.Correlations
import org.apache.spark.mllib.stat.test.{ChiSquaredTest, ChiSquaredTestResult}
import org.apache.spark.rdd.RDD

/**
@@ -89,4 +90,76 @@ object Statistics {
*/
@Experimental
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method)

/**
* :: Experimental ::
* Conduct the Chi-squared goodness of fit test of the observed data against the
Contributor

Chi -> chi

* expected distribution.
*
* Note: the two input RDDs need to have the same number of partitions and the same number of
* elements in each partition.
*
* @param observed RDD[Double] containing the observed counts.
* @param expected RDD[Double] containing the expected counts. If the observed total differs from
* the expected total, this RDD is rescaled to sum up to the observed total.
* @param method String specifying the method to use for the Chi-squared test.
* Supported: `pearson` (default)
* @return ChiSquaredTestResult object containing the test statistic, degrees of freedom, p-value,
* the method used, and the null hypothesis.
*/
@Experimental
def chiSquared(observed: RDD[Double],
Contributor

Shall we call it chiSqTest (following R's)? We need "test" in the method name because χ² is also a distribution. I feel chiSqTest may be better than chiSquaredTest because it is also called the chi-square test, without the d.

Contributor Author

chiSqTest sounds good.

expected: RDD[Double],
method: String): ChiSquaredTestResult = {
ChiSquaredTest.chiSquared(observed, expected, method)
}

/**
* :: Experimental ::
* Conduct the Chi-squared goodness of fit test of the observed data against the
* expected distribution.
Contributor

mention pearson here?

minor: I think it should be fine to remove the rest of the doc and point users to the method with the full set of parameters, so we only maintain one copy.

*
* Note: the two input RDDs need to have the same number of partitions and the same number of
* elements in each partition.
*
* @param observed RDD[Double] containing the observed counts.
* @param expected RDD[Double] containing the expected counts. If the observed total differs from
* the expected total, this RDD is rescaled to sum up to the observed total.
* @return ChiSquaredTestResult object containing the test statistic, degrees of freedom, p-value,
* the method used, and the null hypothesis.
*/
@Experimental
def chiSquared(observed: RDD[Double], expected: RDD[Double]): ChiSquaredTestResult = {
ChiSquaredTest.chiSquared(observed, expected)
}

/**
* :: Experimental ::
* Conduct the Chi-squared independence test between the columns in the input matrix.
*
* @param counts RDD[Vector] containing observations with rows representing categories and columns
* representing separate trials for which independence between trials is assessed.
* @param method String specifying the method to use for the Chi-squared test.
* Supported: `pearson` (default)
* @return ChiSquaredTestResult object containing the test statistic, degrees of freedom, p-value,
* the method used, and the null hypothesis.
*/
@Experimental
def chiSquared(counts: RDD[Vector], method: String): ChiSquaredTestResult = {
ChiSquaredTest.chiSquaredMatrix(counts, method)
}

/**
* :: Experimental ::
* Conduct the Chi-squared independence test between the columns in the input matrix.
*
* @param counts RDD[Vector] containing observations with rows representing categories and columns
* representing separate trials for which independence between trials is assessed.
* @return ChiSquaredTestResult object containing the test statistic, degrees of freedom, p-value,
* the method used, and the null hypothesis.
*/
@Experimental
def chiSquared(counts: RDD[Vector]): ChiSquaredTestResult = {
ChiSquaredTest.chiSquaredMatrix(counts)
}
}
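
The goodness-of-fit scaladoc above says that expected is rescaled to sum to the observed total before the statistic is computed. A minimal local sketch of that arithmetic, without Spark and without the p-value step, might look like the following. The names GoodnessOfFitSketch and pearsonGoodnessOfFit are hypothetical and not part of this PR; plain Seq[Double] stands in for RDD[Double].

```scala
// Local, Spark-free sketch of Pearson's goodness-of-fit statistic.
// GoodnessOfFitSketch and pearsonGoodnessOfFit are hypothetical names.
object GoodnessOfFitSketch {

  /** Returns (statistic, degreesOfFreedom) for observed vs. expected counts. */
  def pearsonGoodnessOfFit(observed: Seq[Double], expected: Seq[Double]): (Double, Long) = {
    require(observed.size == expected.size, "observed and expected must be of the same size.")
    require(observed.forall(_ >= 0.0), "Values in observed must be nonnegative.")
    require(expected.forall(_ > 0.0), "Values in expected must be positive.")

    // Rescale expected so that it sums to the observed total.
    val scale = observed.sum / expected.sum

    // Sum of (observed - expected)^2 / expected over all categories.
    val statistic = observed.zip(expected).map { case (obs, exp) =>
      val scaled = exp * scale
      val dev = obs - scaled
      dev * dev / scaled
    }.sum

    (statistic, observed.size - 1L)
  }
}
```

For example, observed counts of 10, 20, 30 against a uniform expected distribution of 1, 1, 1 rescale to expected counts of 20 each, giving a statistic of 100/20 + 0 + 100/20 = 10 with 2 degrees of freedom.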
@@ -0,0 +1,170 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.stat.test

import cern.jet.stat.Probability.chiSquareComplemented

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

/**
* Conduct the Chi-squared test for the input RDDs using the specified method.
Contributor

Chi -> chi

* Goodness-of-fit test is conducted on two RDD[Double]s, whereas test of independence is conducted
* on an input of type RDD[Vector] in which independence between columns is assessed.
*
* Supported methods for goodness of fit: `pearson` (default)
* Supported methods for independence: `pearson` (default)
*
* More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test
* More information on Pearson's chi-squared test:
* http://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test
*
*/
private[stat] object ChiSquaredTest extends Logging {
Contributor

minor: ChiSquaredTest -> ChiSqTest (to match the public method names)


val PEARSON = "pearson"

object NullHypothesis extends Enumeration {
type NullHypothesis = Value
val goodnessOfFit = Value("observed follows the same distribution as expected.")
val independence = Value("observations in each column are statistically independent.")
}

val zeroExpectedError = new IllegalArgumentException("Chi square statistic cannot be computed"
+ " for input RDD due to nonpositive entries in the expected contingency table.")

// delegator method for goodness of fit test
def chiSquared(observed: RDD[Double],
expected: RDD[Double],
method: String = PEARSON): ChiSquaredTestResult = {
method match {
case PEARSON => chiSquaredPearson(observed, expected)
case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.")
}
}

// delegator method for independence test
def chiSquaredMatrix(counts: RDD[Vector], method: String = PEARSON): ChiSquaredTestResult = {
method match {
// Yates' correction doesn't really apply here
case PEARSON => chiSquaredPearson(counts)
case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.")
}
}

// Equation for computing Pearson's chi-squared statistic
private def pearson = (observed: Double, expected: Double) => {
val dev = observed - expected
dev * dev / expected
}

/*
* Pearson's goodness of fit test. This can be easily made abstract to support other methods.
* Makes two passes over both input RDDs.
*/
private def chiSquaredPearson(observed: RDD[Double],
expected: RDD[Double]): ChiSquaredTestResult = {

// compute the scaling factor and count for the input RDDs and check positivity in one pass
val observedStats = observed.stats()
if (observedStats.min < 0.0) {
throw new IllegalArgumentException("Values in observed must be nonnegative.")
}
val expectedStats = expected.stats()
if (expectedStats.min <= 0.0) {
throw new IllegalArgumentException("Values in expected must be positive.")
}
if (observedStats.count != expectedStats.count) {
throw new IllegalArgumentException("observed and expected must be of the same size.")
}

val expScaled = if (math.abs(observedStats.sum - expectedStats.sum) < 1e-7) {
// No scaling needed since both RDDs have the same total
expected
} else {
expected.map(_ * observedStats.sum / expectedStats.sum)
}

// Second pass to compute chi-squared statistic
val statistic = observed.zip(expScaled).aggregate(0.0)({ case (sum, (obs, exp)) => {
sum + pearson(obs, exp)
}}, _ + _)
val df = observedStats.count - 1
val pValue = chiSquareComplemented(df, statistic)
new ChiSquaredTestResult(pValue, Array(df), statistic, PEARSON,
NullHypothesis.goodnessOfFit.toString)
}

/*
* Pearson's independence test. This can be easily made abstract to support other methods.
* Makes two passes over the input RDD.
*/
private def chiSquaredPearson(counts: RDD[Vector]): ChiSquaredTestResult = {

val numCols = counts.first.size

// first pass for collecting column sums
case class SumNCount(colSums: Array[Double], numRows: Long)

val result = counts.aggregate(new SumNCount(new Array[Double](numCols), 0L))(
(sumNCount, vector) => {
val arr = vector.toArray
// check that the counts are all non-negative and finite in this pass
if (!arr.forall(i => !i.isNaN && !i.isInfinite && i >= 0.0)) {
throw new IllegalArgumentException("Values in the input RDD must be nonnegative and finite.")
}
new SumNCount((sumNCount.colSums, arr).zipped.map(_ + _), sumNCount.numRows + 1)
}, (sums1, sums2) => {
new SumNCount((sums1.colSums, sums2.colSums).zipped.map(_ + _),
sums1.numRows + sums2.numRows)
})

val colSums = result.colSums
if (!colSums.forall(_ > 0.0)) {
throw zeroExpectedError
}
val total = colSums.sum

// Second pass to compute chi-squared statistic
val statistic = counts.aggregate(0.0)(rowStatistic(colSums, total, pearson), _ + _)
val df = (numCols - 1) * (result.numRows - 1)
val pValue = chiSquareComplemented(df, statistic)
new ChiSquaredTestResult(pValue, Array(df), statistic, PEARSON,
NullHypothesis.independence.toString)
}

// returns function to be used as seqOp in the aggregate operation to collect statistic
private def rowStatistic(colSums: Array[Double],
total: Double,
chiSquared: (Double, Double) => Double) = {
(statistic: Double, vector: Vector) => {
val arr = vector.toArray
val rowSum = arr.sum
if (rowSum == 0.0) { // rowSum >= 0.0 as ensured by the nonnegative check
throw zeroExpectedError
}
(arr, colSums).zipped.foldLeft(statistic) { case (stat, (observed, colSum)) =>
val expected = rowSum * colSum / total
val r = stat + chiSquared(observed, expected)
r
}
}
}
}
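
The independence test above derives each expected cell count as rowSum * colSum / total and sums the Pearson term over all cells. A minimal local sketch of the same computation on an in-memory table, under the assumption that each Array[Double] is one row of the contingency table, could look like this. IndependenceSketch and pearsonIndependence are hypothetical names, and the p-value lookup is left out.

```scala
// Local, Spark-free sketch of Pearson's independence statistic.
// IndependenceSketch and pearsonIndependence are hypothetical names.
object IndependenceSketch {

  /** Returns (statistic, degreesOfFreedom) for a contingency table given as rows. */
  def pearsonIndependence(rows: Seq[Array[Double]]): (Double, Long) = {
    require(rows.nonEmpty, "The table must have at least one row.")
    val numCols = rows.head.length
    require(rows.forall(_.length == numCols), "All rows must have the same length.")
    require(rows.forall(_.forall(v => !v.isNaN && !v.isInfinite && v >= 0.0)),
      "Values must be nonnegative and finite.")

    // First pass: column sums and row sums.
    val colSums = rows.foldLeft(new Array[Double](numCols)) { (acc, row) =>
      acc.zip(row).map { case (a, b) => a + b }
    }
    val rowSums = rows.map(_.sum)
    require(colSums.forall(_ > 0.0) && rowSums.forall(_ > 0.0),
      "Every row and column must have a positive total.")
    val total = colSums.sum

    // Second pass: expected = rowSum * colSum / total for each cell.
    val statistic = rows.zip(rowSums).map { case (row, rowSum) =>
      row.zip(colSums).map { case (obs, colSum) =>
        val exp = rowSum * colSum / total
        val dev = obs - exp
        dev * dev / exp
      }.sum
    }.sum

    (statistic, (numCols - 1L) * (rows.size - 1L))
  }
}
```

For the 2x2 table with rows (10, 20) and (20, 10), every expected count is 15, so the statistic is 4 * 25 / 15 = 20/3 with 1 degree of freedom. A table whose rows are proportional, such as (1, 2) and (2, 4), gives a statistic of 0.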
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.stat.test

import org.apache.spark.annotation.Experimental

/**
* :: Experimental ::
* Trait for hypothesis test results.
*/
@Experimental
trait TestResult {

def pValue: Double
Contributor

documentation


def degreesOfFreedom: Array[Long]
Contributor

ditto: doc

Contributor

df should be an array of double or we can make it a generic type. In t-test and f-test, df are not integers.


def statistic: Double
Contributor

ditto: doc


/**
* String explaining the hypothesis test result.
* Specific classes implementing this trait should override this method to output test-specific
* information.
*/
override def toString: String = {

val pValueExplain = if (pValue <= 0.01) {
"Very strong presumption against null hypothesis."
} else if (0.01 < pValue && pValue <= 0.05) {
"Strong presumption against null hypothesis."
} else if (0.05 < pValue && pValue <= 0.1) {
"Low presumption against null hypothesis."
} else {
"No presumption against null hypothesis."
}

s"degrees of freedom = ${degreesOfFreedom.mkString} \n" +
Contributor

mkString("[", ",", "]")

s"statistic = $statistic \n" +
s"pValue = $pValue \n" + pValueExplain
}
}
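
The review suggestion on the degrees-of-freedom line in toString is easier to see with a concrete array. The sketch below only illustrates how the two mkString overloads differ for an array with more than one element; MkStringSketch is a hypothetical name.

```scala
// Illustrates why the bare mkString is ambiguous for multi-element arrays.
// MkStringSketch is a hypothetical name used only for this example.
object MkStringSketch {
  val df: Array[Long] = Array(1L, 2L)

  // No separator: the elements run together and read like the number twelve.
  val bare: String = df.mkString

  // Prefix, separator, suffix: the array structure stays visible.
  val bracketed: String = df.mkString("[", ",", "]")
}
```

Here bare is "12" and bracketed is "[1,2]".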

/**
* :: Experimental ::
* Object containing the test results for the chi squared hypothesis test.
*/
@Experimental
case class ChiSquaredTestResult(override val pValue: Double,
Contributor

Does it need to be a case class? Scala compiler will add many methods to a case class and make it very hard to extend.

Contributor Author

A case class is a logical choice here since it's essentially an immutable object holding a bunch of invariant fields and doesn't do any stateful computation inside the class. Is there a development plan for extending this class in the future?

Contributor

No case class features are used, especially pattern matching. This case class will extend Product5 and make it impossible to add a field, for example, whether correction is used or not. Also, with a case class, it is very hard to add a static method. We might want to write the test result to JSON and later parse it back. A natural choice would be ChiSquaredTestResult.fromJSON(json: String) but it is very complicated to match the type signature generated by Scala's compiler. We had this problem with LabeledPoint in MLlib, which is a public case class.

Contributor

Btw, shall we rename it to ChiSqTestResult? So chiSqTest() returns ChiSqTestResult.

Contributor Author

Whether correction is used or not can actually be reflected in the method name (pearson vs. yates). I doubt there are many use cases for parsing the result back from JSON, so let's not worry about it for now (also, don't static methods usually go in the companion object anyway?). The way I see case classes is that they're like data structs that encapsulate immutable fields (the list of fields can be modified in later releases, given that this is all experimental), and we get free getter methods for constructor arguments with a case class. But if there are compiler optimization complications, I can change it to a regular class.

override val degreesOfFreedom: Array[Long],
override val statistic: Double,
val method: String,
val nullHypothesis: String) extends TestResult {

override def toString: String = {
"Chi squared test summary: \n" +
s"method: $method \n" +
s"null hypothesis: $nullHypothesis \n" +
super.toString
}
}
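
For comparison with the case-class thread above, here is a minimal sketch of what a regular class with an explicit companion object could look like. It is an illustration only, not the design the PR settled on: ChiSqTestResultSketch and its apply factory are hypothetical names, and the default method and null-hypothesis strings are copied from the diff.

```scala
// Hypothetical sketch: a plain (non-case) result class with a companion object.
// Public vals still give callers read-only accessors for every field.
class ChiSqTestResultSketch(
    val pValue: Double,
    val degreesOfFreedom: Array[Long],
    val statistic: Double,
    val method: String,
    val nullHypothesis: String) {

  override def toString: String =
    "Chi squared test summary: \n" +
      s"method: $method \n" +
      s"null hypothesis: $nullHypothesis \n" +
      s"degrees of freedom = ${degreesOfFreedom.mkString("[", ",", "]")} \n" +
      s"statistic = $statistic \n" +
      s"pValue = $pValue"
}

object ChiSqTestResultSketch {
  // Factory ("static") methods would live in the companion object.
  def apply(pValue: Double, df: Long, statistic: Double): ChiSqTestResultSketch =
    new ChiSqTestResultSketch(pValue, Array(df), statistic, "pearson",
      "observed follows the same distribution as expected.")
}
```

With a regular class, equality, hashing, and pattern-matching support are not generated, so anything relying on those would need to be written by hand.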