1717
1818package org .apache .spark .mllib .stat
1919
20+ import java .util .Random
21+
2022import org .scalatest .FunSuite
2123
24+ import org .apache .spark .SparkException
2225import org .apache .spark .mllib .linalg .{DenseVector , Matrices , Vectors }
2326import org .apache .spark .mllib .regression .LabeledPoint
2427import org .apache .spark .mllib .stat .test .ChiSqTest
@@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
107110 // labels: 1.0 (2 / 6), 0.0 (4 / 6)
108111 // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6)
109112 // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6)
110- val data = Array (new LabeledPoint (0.0 , Vectors .dense(0.5 , 10.0 )),
111- new LabeledPoint (0.0 , Vectors .dense(1.5 , 20.0 )),
112- new LabeledPoint (1.0 , Vectors .dense(1.5 , 30.0 )),
113- new LabeledPoint (0.0 , Vectors .dense(3.5 , 30.0 )),
114- new LabeledPoint (0.0 , Vectors .dense(3.5 , 40.0 )),
115- new LabeledPoint (1.0 , Vectors .dense(3.5 , 40.0 )))
113+ val data = Seq (
114+ LabeledPoint (0.0 , Vectors .dense(0.5 , 10.0 )),
115+ LabeledPoint (0.0 , Vectors .dense(1.5 , 20.0 )),
116+ LabeledPoint (1.0 , Vectors .dense(1.5 , 30.0 )),
117+ LabeledPoint (0.0 , Vectors .dense(3.5 , 30.0 )),
118+ LabeledPoint (0.0 , Vectors .dense(3.5 , 40.0 )),
119+ LabeledPoint (1.0 , Vectors .dense(3.5 , 40.0 )))
116120 for (numParts <- List (2 , 4 , 6 , 8 )) {
117121 val chi = Statistics .chiSqTest(sc.parallelize(data, numParts))
118122 val feature1 = chi(0 )
@@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
130134 }
131135
132136 // Test that the right number of results is returned
133- val numCols = 321
134- val sparseData = Array (new LabeledPoint (0.0 , Vectors .sparse(numCols, Seq ((100 , 2.0 )))),
135- new LabeledPoint (0.0 , Vectors .sparse(numCols, Seq ((200 , 1.0 )))))
137+ val numCols = 1001
138+ val sparseData = Array (
139+ new LabeledPoint (0.0 , Vectors .sparse(numCols, Seq ((100 , 2.0 )))),
140+ new LabeledPoint (0.1 , Vectors .sparse(numCols, Seq ((200 , 1.0 )))))
136141 val chi = Statistics .chiSqTest(sc.parallelize(sparseData))
137142 assert(chi.size === numCols)
143+ assert(chi(1000 ) != null ) // SPARK-3087
144+
145+ // Detect continuous features or labels
146+ val random = new Random (11L )
147+ val continuousLabel =
148+ Seq .fill(100000 )(LabeledPoint (random.nextDouble(), Vectors .dense(random.nextInt(2 ))))
149+ intercept[SparkException ] {
150+ Statistics .chiSqTest(sc.parallelize(continuousLabel, 2 ))
151+ }
152+ val continuousFeature =
153+ Seq .fill(100000 )(LabeledPoint (random.nextInt(2 ), Vectors .dense(random.nextDouble())))
154+ intercept[SparkException ] {
155+ Statistics .chiSqTest(sc.parallelize(continuousFeature, 2 ))
156+ }
138157 }
139158}
0 commit comments