Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
unit passed; still need docs and some refactoring
  • Loading branch information
dorx committed Aug 2, 2014
commit bc7eb2eeba4e2ccf10b891e4ce59db55823cea3b
18 changes: 11 additions & 7 deletions mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,22 +91,26 @@ object Statistics {
@Experimental
def corr(x: RDD[Double], y: RDD[Double], method: String): Double = {
  // Thin wrapper: the computation itself lives in Correlations; `method` is passed through as-is.
  Correlations.corr(x, y, method)
}

/**
 * Runs a chi-squared test on two RDDs of values using the named method.
 *
 * This is a thin wrapper that forwards to `ChiSquaredTest.chiSquared(x, y, method)`.
 *
 * NOTE(review): technically the inputs should be `RDD[Long]`, since the data are counts;
 * `RDD[Double]` is what the signature currently accepts.
 *
 * @param x first RDD of counts
 * @param y second RDD of counts
 * @param method name of the test method, passed through unchanged
 * @return the result produced by `ChiSquaredTest.chiSquared`
 */
@Experimental
def chiSquared(x: RDD[Double], y: RDD[Double], method: String): ChiSquaredTestResult = {
  ChiSquaredTest.chiSquared(x, y, method)
}

/**
 * Two-argument overload of the chi-squared test: forwards `expected` and `observed` to
 * `ChiSquaredTest.chiSquared` without naming a method (the callee picks it).
 */
@Experimental
def chiSquared(expected: RDD[Double], observed: RDD[Double]): ChiSquaredTestResult =
  ChiSquaredTest.chiSquared(expected, observed)

/**
 * Runs a chi-squared test on a matrix of counts, supplied as an RDD of row vectors, using
 * the named method.
 *
 * Forwards to `ChiSquaredTest.chiSquaredMatrix(counts, method)`.
 *
 * NOTE(review): as the data are counts, something like `RDD[Array[Long]]` would be more
 * precise; it is an open question whether the way a "matrix" is presented should stay
 * consistent with the rest of the API instead.
 *
 * @param counts the matrix of counts as an RDD of vectors
 * @param method name of the test method, passed through unchanged
 * @return the result produced by `ChiSquaredTest.chiSquaredMatrix`
 */
@Experimental
def chiSquared(counts: RDD[Vector], method: String): ChiSquaredTestResult = {
  ChiSquaredTest.chiSquaredMatrix(counts, method)
}

/**
 * Runs a chi-squared test on a matrix of counts, supplied as an RDD of row vectors.
 *
 * Forwards to `ChiSquaredTest.chiSquaredMatrix(counts)`, which applies its own default
 * method.
 *
 * @param counts the matrix of counts as an RDD of vectors
 * @return the result produced by `ChiSquaredTest.chiSquaredMatrix`
 */
@Experimental
def chiSquared(counts: RDD[Vector]): ChiSquaredTestResult = {
  ChiSquaredTest.chiSquaredMatrix(counts)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.mllib.stat.test

import cern.jet.stat.Probability.chiSquareComplemented
import cern.jet.stat.Probability.chiSquare

import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.{DenseVector, Vector}
Expand All @@ -35,7 +35,7 @@ private[stat] object ChiSquaredTest {
}
}

def chiSquared(counts: RDD[Vector], method: String = PEARSON): ChiSquaredTestResult = {
def chiSquaredMatrix(counts: RDD[Vector], method: String = PEARSON): ChiSquaredTestResult = {
method match {
case PEARSON => chiSquaredPearson(counts)
case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.")
Expand All @@ -48,23 +48,29 @@ private[stat] object ChiSquaredTest {
chiSquaredPearson(mat)
}

/**
 * Pearson's chi-squared test over a matrix of counts, one row per `Vector` in the RDD.
 *
 * Makes two passes over `counts`: the first validates every entry and accumulates the
 * per-column sums together with the number of rows; the second accumulates the test
 * statistic row by row via `rowStatistic`.
 *
 * An `IllegalArgumentException` is raised from inside the first aggregation if any entry is
 * NaN, infinite or negative.
 *
 * @param counts rows of the table; every entry must be finite and non-negative
 * @return a result carrying the p-value, the degrees of freedom `(cols - 1) * (rows - 1)`,
 *         the statistic and the method name
 */
private def chiSquaredPearson(counts: RDD[Vector]): ChiSquaredTestResult = {
  val numCols = counts.first.size

  // First pass: validate the entries and collect (column sums, number of rows).
  val (colSums, numRows) = counts.aggregate((new Array[Double](numCols), 0))(
    (acc, vector) => {
      val arr = vector.toArray
      // Reject anything that cannot be a count before it contaminates the sums.
      if (arr.exists(v => v.isNaN || v.isInfinite || v < 0.0)) {
        throw new IllegalArgumentException("All input entries must be nonnegative and finite.")
      }
      ((acc._1, arr).zipped.map(_ + _), acc._2 + 1)
    }, // seqOp
    (left, right) => ((left._1, right._1).zipped.map(_ + _), left._2 + right._2)) // combOp

  val total = colSums.sum

  // Second pass: accumulate the chi-squared statistic over all rows.
  val statistic = counts.aggregate(0.0)(rowStatistic(colSums, total), _ + _)
  val df = (numCols - 1) * (numRows - 1)

  // The p-value is the upper-tail probability P(X >= statistic). Colt's functions take the
  // degrees of freedom first and the statistic second; `chiSquare` would give the lower tail.
  val pValue = chiSquareComplemented(df, statistic)

  new ChiSquaredTestResult(pValue, Array(df), statistic, PEARSON)
}
Expand All @@ -76,7 +82,8 @@ private[stat] object ChiSquaredTest {
val rowSum = arr.sum
(arr, colSums).zipped.foldLeft(statistic) { case (stat, (observed, colSum)) =>
val expected = rowSum * colSum / total
stat + (observed - expected) * (observed - expected) / expected
val r = stat + (observed - expected) * (observed - expected) / expected
r
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ trait TestResult {
*/
override def toString: String = {
  // degreesOfFreedom holds several values, so render it with explicit brackets and commas:
  // interpolating it directly or using a bare mkString would not separate the elements.
  val dof = degreesOfFreedom.mkString("[", ",", "]")
  s"pValue = $pValue \n" + // TODO explain what pValue is
  s"degrees of freedom = $dof \n" +
  s"statistic = $statistic"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,16 @@ package org.apache.spark.mllib.stat
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.util.TestingUtils._

/** Exercises `Statistics.chiSquared` on small in-memory RDDs. */
class HypothesisTestSuite extends FunSuite with LocalSparkContext {

  test("chi squared") {
    val first = sc.parallelize(Array(2.0, 23.0, 53.0))
    val second = sc.parallelize(Array(53.0, 76.0, 1.0))

    // The statistic for these two samples is pinned to a reference value.
    val outcome = Statistics.chiSquared(first, second)
    assert(outcome.statistic ~= 120.2546 absTol 1e-3)

    // An input containing a negative entry is expected to fail.
    val withNegative = sc.parallelize(Array(2.0, -23.0, 53.0))
    intercept[Exception] {
      Statistics.chiSquared(withNegative, second)
    }
  }
}