-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19634][ML] Multivariate summarizer - dataframes API #17419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
f3fa658
7539835
673943f
202b672
be01981
a983284
647a4fe
3c4bef7
56390cc
c3f236c
ef955c0
a04f923
946d490
201eb77
f4dec88
4af0f47
9f29030
e9877dc
3a11d02
6d26c17
35eaeb0
58b17dc
18078c1
ffe5cfe
41f4be6
ba200bb
2f809ef
662f62c
96be071
a569dac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,9 +17,14 @@ | |
|
|
||
| package org.apache.spark.ml.stat | ||
|
|
||
| import org.scalatest.exceptions.TestFailedException | ||
|
|
||
| import org.apache.spark.{SparkException, SparkFunSuite} | ||
| import org.apache.spark.ml.linalg.{Vector, Vectors} | ||
| import org.apache.spark.ml.stat.SummaryBuilderImpl.Buffer | ||
| import org.apache.spark.ml.util.TestingUtils._ | ||
| import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} | ||
| import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer | ||
| import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
| import org.apache.spark.sql.{DataFrame, Dataset, Row} | ||
| import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema | ||
|
|
@@ -60,15 +65,22 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| case x => throw new Exception(x.toString) | ||
| } | ||
|
|
||
| // Reference result: feed every input vector to the legacy mllib online summarizer. | ||
| val s = { | ||
|   val s2 = new MultivariateOnlineSummarizer | ||
|   // The mllib summarizer works on the old vector type, so convert each ml vector before adding it. | ||
|   inputVec.foreach(v => s2.add(OldVectors.fromML(v))) | ||
|   s2 | ||
| } | ||
|
|
||
| // Because the Spark context is reset between tests, we cannot hold on to a reference to it. | ||
| def wrapped() = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Rename `wrapped` -> `wrappedInit`. |
||
| val df = sc.parallelize(inputVec).map(Tuple1.apply).toDF("features") | ||
| val c = df.col("features") | ||
| df -> c | ||
| (df, c) | ||
| } | ||
|
|
||
| registerTest(s"$name - mean only") { | ||
| val (df, c) = wrapped() | ||
| compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), exp.mean)) | ||
| compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), s.mean)) | ||
| } | ||
|
|
||
| registerTest(s"$name - mean only (direct)") { | ||
|
|
@@ -79,7 +91,12 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| registerTest(s"$name - variance only") { | ||
| val (df, c) = wrapped() | ||
| compare(df.select(metrics("variance").summary(c), variance(c)), | ||
| Seq(Row(exp.variance), exp.variance)) | ||
| Seq(Row(exp.variance), s.variance)) | ||
| } | ||
|
|
||
| registerTest(s"$name - variance only (direct)") { | ||
| val (df, c) = wrapped() | ||
| compare(df.select(variance(c)), Seq(s.variance)) | ||
| } | ||
|
|
||
| registerTest(s"$name - count only") { | ||
|
|
@@ -88,6 +105,12 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| Seq(Row(exp.count), exp.count)) | ||
| } | ||
|
|
||
| registerTest(s"$name - count only (direct)") { | ||
| val (df, c) = wrapped() | ||
| compare(df.select(count(c)), | ||
| Seq(exp.count)) | ||
| } | ||
|
|
||
| registerTest(s"$name - numNonZeros only") { | ||
| val (df, c) = wrapped() | ||
| compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)), | ||
|
|
@@ -147,6 +170,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| val res = row.toSeq | ||
| println(s"compare: ress=${res.map(_.getClass)}") | ||
| println(s"compare: row=${row}") | ||
| println(s"compare: exp=${exp}") | ||
| val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)" } | ||
| assert(res.size === exp.size, (res.size, exp.size)) | ||
| for (((x1, x2), name) <- res.zip(exp).zip(names)) { | ||
|
|
@@ -156,8 +180,10 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
|
|
||
| // Compares structured content. | ||
| private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are there bugs in direct comparison of Rows? I'm wondering if we can avoid this here. Or does Spark SQL already have this implemented? |
||
| case (y1: Seq[Double], v1: OldVector) => | ||
| compareStructures(y1, v1.toArray.toSeq, name) | ||
| case (d1: Double, d2: Double) => | ||
| assert(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name) | ||
| assert2(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name) | ||
| case (r1: GenericRowWithSchema, r2: Row) => | ||
| assert(r1.size === r2.size, (r1, r2)) | ||
| for (((fname, x1), x2) <- r1.schema.fieldNames.zip(r1.toSeq).zip(r2.toSeq)) { | ||
|
|
@@ -167,7 +193,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| assert(r1.size === r2.size, (r1, r2)) | ||
| for ((x1, x2) <- r1.toSeq.zip(r2.toSeq)) { compareStructures(x1, x2, name) } | ||
| case (v1: Vector, v2: Vector) => | ||
| assert(v1 ~== v2 absTol 1e-4, name) | ||
| assert2(v1 ~== v2 absTol 1e-4, name) | ||
| case (l1: Long, l2: Long) => assert(l1 === l2) | ||
| case (s1: Seq[_], s2: Seq[_]) => | ||
| assert(s1.size === s2.size, s"$name ${(s1, s2)}") | ||
|
|
@@ -179,7 +205,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| case _ => throw new Exception(s"$name: ${x1.getClass} ${x2.getClass} $x1 $x2") | ||
| } | ||
|
|
||
| private def assert2(x: => Boolean, hint: String): Unit = { | ||
| try { | ||
| assert(x, hint) | ||
| } catch { | ||
| case tfe: TestFailedException => | ||
| throw new TestFailedException(Some(s"Failure with hint $hint"), Some(tfe), 1) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do you need to catch and re-throw it as a TestFailedException? |
||
| } | ||
| } | ||
|
|
||
| // Creates a buffer via Buffer.allMetrics() and updates it in place with every vector in `vecs`, | ||
| // passing 1.0 as the third argument (presumably the sample weight). Returns the updated buffer. | ||
| private def makeBuffer(vecs: Seq[Vector]): Buffer = { | ||
|   val buffer = Buffer.allMetrics() | ||
|   vecs.foreach { vec => Buffer.updateInPlace(buffer, vec, 1.0) } | ||
|   buffer | ||
| } | ||
|
|
||
| // Test helper: wraps an array of doubles as a dense ml.linalg Vector so it can be compared with ~==. | ||
| private def b(x: Array[Double]): Vector = Vectors.dense(x) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Rename these to something clearer.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. +1 |
||
|
|
||
| // Test helper: converts each Long to a Double, then wraps the result as a dense Vector via b. | ||
| private def l(x: Array[Long]): Vector = b(x.map(_.toDouble)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. And this one. It would be better to rename it. |
||
|
|
||
| test("debugging test") { | ||
| val df = denseData(Nil) | ||
|
|
@@ -227,7 +270,8 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
|
|
||
| testExample("two elements", Seq(Seq(0.0, 1.0, 2.0), Seq(0.0, -1.0, -2.0)), ExpectedMetrics( | ||
| mean = Seq(0.0, 0.0, 0.0), | ||
| variance = Seq(0.0, 1.0, 2.0), | ||
| // TODO: double-check these values. They are the unbiased sample variances (sum of squares / (n - 1)), not normalized by n. | ||
| variance = Seq(0.0, 2.0, 8.0), | ||
| count = 2, | ||
| numNonZeros = Seq(0, 2, 2), | ||
| max = Seq(0.0, 1.0, 2.0), | ||
|
|
@@ -236,6 +280,75 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext { | |
| normL2 = Seq(0.0, math.sqrt(2.0), math.sqrt(2.0) * 2.0) | ||
| )) | ||
|
|
||
| // Two dense rows, (-1, 0, 6) and (3, -3, 0); every expectation below is per column. | ||
| testExample("dense vector input", | ||
|   Seq(Seq(-1.0, 0.0, 6.0), Seq(3.0, -3.0, 0.0)), | ||
|   ExpectedMetrics( | ||
|     mean = Seq(1.0, -1.5, 3.0), | ||
|     variance = Seq(8.0, 4.5, 18.0), | ||
|     count = 2, | ||
|     numNonZeros = Seq(2, 1, 1), | ||
|     max = Seq(3.0, 0.0, 6.0), | ||
|     min = Seq(-1.0, -3, 0.0), | ||
|     // L1 norm per column: |-1| + |3|, |0| + |-3|, |6| + |0|. | ||
|     normL1 = Seq(4.0, 3.0, 6.0), | ||
|     // L2 norm per column: sqrt(1 + 9), sqrt(0 + 9), sqrt(36 + 0). | ||
|     normL2 = Seq(math.sqrt(10.0), 3.0, 6.0) | ||
|   )) | ||
|
|
||
| // Checks the Buffer statistics when dense and sparse vectors are fed into the same buffer. | ||
| test("mixing dense and sparse vector input") { | ||
|   val inputs = Seq( | ||
|     Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), | ||
|     Vectors.dense(0.0, -1.0, -3.0), | ||
|     Vectors.sparse(3, Seq((1, -5.1))), | ||
|     Vectors.dense(3.8, 0.0, 1.9), | ||
|     Vectors.dense(1.7, -0.6, 0.0), | ||
|     Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) | ||
|   val summarizer = makeBuffer(inputs) | ||
|   // Expected per-column statistics over the six input vectors. | ||
|   val expectedMean = Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) | ||
|   val expectedMin = Vectors.dense(-2.0, -5.1, -3.0) | ||
|   val expectedMax = Vectors.dense(3.8, 2.3, 1.9) | ||
|   val expectedNnz = Vectors.dense(3.0, 5.0, 2.0) | ||
|   val expectedVariance = Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) | ||
|   assert(b(Buffer.mean(summarizer)) ~== expectedMean absTol 1E-5, "mean mismatch") | ||
|   assert(b(Buffer.min(summarizer)) ~== expectedMin absTol 1E-5, "min mismatch") | ||
|   assert(b(Buffer.max(summarizer)) ~== expectedMax absTol 1E-5, "max mismatch") | ||
|   assert(l(Buffer.nnz(summarizer)) ~== expectedNnz absTol 1E-5, "numNonzeros mismatch") | ||
|   assert(b(Buffer.variance(summarizer)) ~== expectedVariance absTol 1E-5, "variance mismatch") | ||
|   assert(Buffer.totalCount(summarizer) === 6) | ||
| } | ||
|
|
||
|
|
||
| // Checks that merging two partial buffers gives the statistics of the combined six-vector input. | ||
| test("merging two summarizers") { | ||
|   val firstPart = Seq( | ||
|     Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), | ||
|     Vectors.dense(0.0, -1.0, -3.0)) | ||
|   val summarizer1 = makeBuffer(firstPart) | ||
|   val secondPart = Seq( | ||
|     Vectors.sparse(3, Seq((1, -5.1))), | ||
|     Vectors.dense(3.8, 0.0, 1.9), | ||
|     Vectors.dense(1.7, -0.6, 0.0), | ||
|     Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) | ||
|   val summarizer2 = makeBuffer(secondPart) | ||
|   val summarizer = Buffer.mergeBuffers(summarizer1, summarizer2) | ||
|   // Expected per-column statistics over all six vectors. | ||
|   val expectedMean = Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) | ||
|   val expectedMin = Vectors.dense(-2.0, -5.1, -3.0) | ||
|   val expectedMax = Vectors.dense(3.8, 2.3, 1.9) | ||
|   val expectedNnz = Vectors.dense(3.0, 5.0, 2.0) | ||
|   val expectedVariance = Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) | ||
|   assert(b(Buffer.mean(summarizer)) ~== expectedMean absTol 1E-5, "mean mismatch") | ||
|   assert(b(Buffer.min(summarizer)) ~== expectedMin absTol 1E-5, "min mismatch") | ||
|   assert(b(Buffer.max(summarizer)) ~== expectedMax absTol 1E-5, "max mismatch") | ||
|   assert(l(Buffer.nnz(summarizer)) ~== expectedNnz absTol 1E-5, "numNonzeros mismatch") | ||
|   assert(b(Buffer.variance(summarizer)) ~== expectedVariance absTol 1E-5, "variance mismatch") | ||
|   assert(Buffer.totalCount(summarizer) === 6) | ||
| } | ||
|
|
||
|
|
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rename s -> summarizer