Closed

Commits (30)
f3fa658  work (thunterdb, Mar 3, 2017)
7539835  work on the test suite (thunterdb, Mar 6, 2017)
673943f  last work (thunterdb, Mar 7, 2017)
202b672  work on using imperative aggregators (thunterdb, Mar 13, 2017)
be01981  Merge remote-tracking branch 'upstream/master' into 19634 (thunterdb, Mar 17, 2017)
a983284  more work on summarizer (thunterdb, Mar 18, 2017)
647a4fe  work (thunterdb, Mar 21, 2017)
3c4bef7  changes (thunterdb, Mar 21, 2017)
56390cc  Merge remote-tracking branch 'upstream/master' into 19634 (thunterdb, Mar 21, 2017)
c3f236c  cleanup (thunterdb, Mar 21, 2017)
ef955c0  debugging (thunterdb, Mar 21, 2017)
a04f923  work (thunterdb, Mar 21, 2017)
946d490  Merge remote-tracking branch 'upstream/master' into 19634 (thunterdb, Mar 22, 2017)
201eb77  debug (thunterdb, Mar 22, 2017)
f4dec88  trying to debug serialization issue (thunterdb, Mar 23, 2017)
4af0f47  better tests (thunterdb, Mar 23, 2017)
9f29030  changes (thunterdb, Mar 24, 2017)
e9877dc  debugging (thunterdb, Mar 24, 2017)
3a11d02  more tests and debugging (thunterdb, Mar 24, 2017)
6d26c17  fixed tests (thunterdb, Mar 24, 2017)
35eaeb0  doc (thunterdb, Mar 24, 2017)
58b17dc  cleanups (thunterdb, Mar 24, 2017)
18078c1  cleanups (thunterdb, Mar 24, 2017)
ffe5cfe  Cleanups (thunterdb, Mar 24, 2017)
41f4be6  Cleanups (thunterdb, Mar 24, 2017)
ba200bb  Cleanups (thunterdb, Mar 24, 2017)
2f809ef  Merge remote-tracking branch 'upstream/master' into 19634 (thunterdb, Mar 27, 2017)
662f62c  small test to find perf issues (thunterdb, Mar 28, 2017)
96be071  Current speed: (thunterdb, Mar 30, 2017)
a569dac  BLAS calls for dense interface (thunterdb, Mar 30, 2017)
more tests and debugging
thunterdb committed Mar 24, 2017
commit 3a11d0265ef665a63cd070eeb1ae4ac25bc89908
41 changes: 16 additions & 25 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.stat

import breeze.{linalg => la}
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
import breeze.linalg.{Vector => BV}
import breeze.numerics

import org.apache.spark.SparkException
@@ -217,8 +217,8 @@ object SummaryBuilderImpl extends Logging {
case object ComputeM2n extends ComputeMetrics
case object ComputeM2 extends ComputeMetrics
case object ComputeL1 extends ComputeMetrics
case object ComputeCount extends ComputeMetrics // Always computed
case object ComputeTotalWeightSum extends ComputeMetrics // Always computed
case object ComputeCount extends ComputeMetrics // Always computed -> TODO: remove
case object ComputeTotalWeightSum extends ComputeMetrics // Always computed -> TODO: remove
case object ComputeWeightSquareSum extends ComputeMetrics
case object ComputeWeightSum extends ComputeMetrics
case object ComputeNNZ extends ComputeMetrics
@@ -256,6 +256,7 @@ object SummaryBuilderImpl extends Logging {
s" max=${v(max)} min=${v(min)})"
}
}

object Buffer extends Logging {
// Recursive function, but the number of cases is really small.
def fromMetrics(requested: Seq[ComputeMetrics]): Buffer = {
@@ -277,6 +278,15 @@
}
}

/**
* (testing only). Makes a buffer with all the metrics enabled.
*/
def allMetrics(): Buffer = {
fromMetrics(Seq(ComputeMean, ComputeM2n, ComputeM2, ComputeL1, ComputeCount,
ComputeTotalWeightSum, ComputeWeightSquareSum, ComputeWeightSum, ComputeNNZ, ComputeMax,
ComputeMin))
}
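Note: a minimal usage sketch, mirroring SummarizerSuite.makeBuffer below (updateInPlace and Vectors.dense are the APIs already used in this patch):

// Build a buffer with every metric enabled, then push two unit-weight rows.
val buf = Buffer.allMetrics()
Buffer.updateInPlace(buf, Vectors.dense(1.0, 2.0), 1.0)
Buffer.updateInPlace(buf, Vectors.dense(3.0, 4.0), 1.0)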

val bufferSchema: StructType = {
val fields = Seq(
"n" -> IntegerType,
@@ -443,6 +453,7 @@ object SummaryBuilderImpl extends Logging {
require(buffer.mean != null)
require(m2n != null)
require(weightSum != null)

val denom = totalWeightSum - (totalWeightSquareSum / totalWeightSum)
if (denom > 0.0) {
val normWs = b(weightSum) :/ totalWeightSum
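Note: this denominator is the standard unbiased correction for reliability weights; with all weights equal to 1 it reduces to the familiar n - 1. In formula form:

$$ \mathrm{Var}_w(x) \;=\; \frac{\sum_i w_i\,(x_i - \bar{x}_w)^2}{W - W_2/W}, \qquad W = \sum_i w_i, \quad W_2 = \sum_i w_i^2 $$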
@@ -500,8 +511,6 @@
}
}

// private[this] lazy val projection = UnsafeProjection.create(bufferSchema)

// Returns the array at a given index, or null if the array is null.
private def nullableArrayD(row: UnsafeRow, ordinal: Int): Array[Double] = {
if (row.isNullAt(ordinal)) {
@@ -524,24 +533,6 @@

private def bl(x: Array[Long]): BV[Long] = BV.apply(x)

private def maxInPlace(x: Array[Double], y: Array[Double]): Unit = {
var i = 0
while(i < x.length) {
// Note: do not use conditions, it is wrong when handling NaNs.
x(i) = Math.max(x(i), y(i))
i += 1
}
}

private def minInPlace(x: Array[Double], y: Array[Double]): Unit = {
var i = 0
while(i < x.length) {
// Note: do not use conditions, it is wrong when handling NaNs.
x(i) = Math.min(x(i), y(i))
i += 1
}
}
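Note on the NaN comment in these removed helpers, as a minimal sketch of the semantics involved:

// Math.max/Math.min propagate NaN, while a hand-rolled conditional
// silently drops it: every comparison involving NaN evaluates to false.
val nan = Double.NaN
Math.max(nan, 1.0)          // NaN
if (nan > 1.0) nan else 1.0 // 1.0 -- the NaN is lost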

/**
* Sets the content of a buffer based on a single row (initialization).
*
@@ -586,14 +577,14 @@ object SummaryBuilderImpl extends Logging {
if (buffer.m2 != null) {
buffer.m2 = Array.ofDim(n)
v.foreachActive { (index, value) =>
buffer.weightSum(index) = w * value * value
buffer.m2(index) = w * value * value
}
}

if (buffer.l1 != null) {
buffer.l1 = Array.ofDim(n)
v.foreachActive { (index, value) =>
buffer.weightSum(index) = w * math.abs(value)
buffer.l1(index) = w * math.abs(value)
}
}
}
125 changes: 119 additions & 6 deletions mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
@@ -17,9 +17,14 @@

package org.apache.spark.ml.stat

import org.scalatest.exceptions.TestFailedException

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.stat.SummaryBuilderImpl.Buffer
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
@@ -60,15 +65,22 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
case x => throw new Exception(x.toString)
}

val s = {
Review comment (Member): rename s -> summarizer
val s2 = new MultivariateOnlineSummarizer
inputVec.foreach(v => s2.add(OldVectors.fromML(v)))
s2
}
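Note: MultivariateOnlineSummarizer is the existing spark.mllib implementation that these tests use as the reference for expected values; a self-contained sketch of the same pattern:

import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

// Feed the reference summarizer with old-style mllib vectors, then query it.
val ref = new MultivariateOnlineSummarizer
ref.add(OldVectors.dense(1.0, 2.0))
ref.add(OldVectors.dense(3.0, 4.0))
ref.mean // = [2.0, 3.0]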

// Because the Spark context is reset between tests, we cannot hold a reference to it.
def wrapped() = {
Review comment (Member): rename wrapped -> wrappedInit
val df = sc.parallelize(inputVec).map(Tuple1.apply).toDF("features")
val c = df.col("features")
df -> c
(df, c)
}

registerTest(s"$name - mean only") {
val (df, c) = wrapped()
compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), exp.mean))
compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), s.mean))
}

registerTest(s"$name - mean only (direct)") {
@@ -79,7 +91,12 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
registerTest(s"$name - variance only") {
val (df, c) = wrapped()
compare(df.select(metrics("variance").summary(c), variance(c)),
Seq(Row(exp.variance), exp.variance))
Seq(Row(exp.variance), s.variance))
}

registerTest(s"$name - variance only (direct)") {
val (df, c) = wrapped()
compare(df.select(variance(c)), Seq(s.variance))
}

registerTest(s"$name - count only") {
@@ -88,6 +105,12 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
Seq(Row(exp.count), exp.count))
}

registerTest(s"$name - count only (direct)") {
val (df, c) = wrapped()
compare(df.select(count(c)),
Seq(exp.count))
}

registerTest(s"$name - numNonZeros only") {
val (df, c) = wrapped()
compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)),
@@ -147,6 +170,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
val res = row.toSeq
println(s"compare: ress=${res.map(_.getClass)}")
println(s"compare: row=${row}")
println(s"compare: exp=${exp}")
val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)" }
assert(res.size === exp.size, (res.size, exp.size))
for (((x1, x2), name) <- res.zip(exp).zip(names)) {
@@ -156,8 +180,10 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {

// Compares structured content.
private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match {
Review comment (Member): Are there bugs in direct comparison of Rows? I'm wondering if we can avoid this here. Or, does Spark SQL already have this implemented?
case (y1: Seq[Double], v1: OldVector) =>
compareStructures(y1, v1.toArray.toSeq, name)
case (d1: Double, d2: Double) =>
assert(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name)
assert2(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name)
case (r1: GenericRowWithSchema, r2: Row) =>
assert(r1.size === r2.size, (r1, r2))
for (((fname, x1), x2) <- r1.schema.fieldNames.zip(r1.toSeq).zip(r2.toSeq)) {
@@ -167,7 +193,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(r1.size === r2.size, (r1, r2))
for ((x1, x2) <- r1.toSeq.zip(r2.toSeq)) { compareStructures(x1, x2, name) }
case (v1: Vector, v2: Vector) =>
assert(v1 ~== v2 absTol 1e-4, name)
assert2(v1 ~== v2 absTol 1e-4, name)
case (l1: Long, l2: Long) => assert(l1 === l2)
case (s1: Seq[_], s2: Seq[_]) =>
assert(s1.size === s2.size, s"$name ${(s1, s2)}")
@@ -179,7 +205,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
case _ => throw new Exception(s"$name: ${x1.getClass} ${x2.getClass} $x1 $x2")
}

private def assert2(x: => Boolean, hint: String): Unit = {
try {
assert(x, hint)
} catch {
case tfe: TestFailedException =>
throw new TestFailedException(Some(s"Failure with hint $hint"), Some(tfe), 1)
Review comment (Member): Why do you need to catch and re-throw it as a TestFailedException?
}
}

private def makeBuffer(vecs: Seq[Vector]): Buffer = {
val b = Buffer.allMetrics()
for (v <- vecs) { Buffer.updateInPlace(b, v, 1.0) }
b
}

private def b(x: Array[Double]): Vector = Vectors.dense(x)
Review comment (Member): Rename these to something clearer
Review comment (Member): +1

private def l(x: Array[Long]): Vector = b(x.map(_.toDouble))
Review comment (Member): And this. It is better to rename it.

test("debugging test") {
val df = denseData(Nil)
@@ -227,7 +270,8 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {

testExample("two elements", Seq(Seq(0.0, 1.0, 2.0), Seq(0.0, -1.0, -2.0)), ExpectedMetrics(
mean = Seq(0.0, 0.0, 0.0),
variance = Seq(0.0, 1.0, 2.0),
// TODO: I have a doubt about these values, they are not normalized.
variance = Seq(0.0, 2.0, 8.0),
count = 2,
numNonZeros = Seq(0, 2, 2),
max = Seq(0.0, 1.0, 2.0),
Expand All @@ -236,6 +280,75 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
normL2 = Seq(0.0, math.sqrt(2.0), math.sqrt(2.0) * 2.0)
))
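Note on the TODO above: the updated values match the unbiased (n-1) sample variance computed per column. A quick arithmetic check, under that assumption:

// Columns of the two rows (0.0, 1.0, 2.0) and (0.0, -1.0, -2.0); each column has mean 0.
val columns = Seq(Seq(0.0, 0.0), Seq(1.0, -1.0), Seq(2.0, -2.0))
val variances = columns.map { xs =>
  val mean = xs.sum / xs.size
  xs.map(x => (x - mean) * (x - mean)).sum / (xs.size - 1)
}
// variances == List(0.0, 2.0, 8.0), matching the expected metrics.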

testExample("dense vector input",
Seq(Seq(-1.0, 0.0, 6.0), Seq(3.0, -3.0, 0.0)),
ExpectedMetrics(
mean = Seq(1.0, -1.5, 3.0),
variance = Seq(8.0, 4.5, 18.0),
count = 2,
numNonZeros = Seq(2, 1, 1),
max = Seq(3.0, 0.0, 6.0),
min = Seq(-1.0, -3, 0.0),
normL1 = Seq(4.0, 4.0, 6.0),
normL2 = Seq(0.0, 0.0, 0.0)
))

test("mixing dense and sparse vector input") {
val summarizer = makeBuffer(Seq(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
Vectors.dense(0.0, -1.0, -3.0),
Vectors.sparse(3, Seq((1, -5.1))),
Vectors.dense(3.8, 0.0, 1.9),
Vectors.dense(1.7, -0.6, 0.0),
Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))))

assert(b(Buffer.mean(summarizer)) ~==
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")

assert(b(Buffer.min(summarizer)) ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min " +
"mismatch")

assert(b(Buffer.max(summarizer)) ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch")

assert(l(Buffer.nnz(summarizer)) ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")

assert(b(Buffer.variance(summarizer)) ~==
Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5,
"variance mismatch")

assert(Buffer.totalCount(summarizer) === 6)
}
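Note: the expected values above treat inactive entries of the sparse vectors as zeros; for instance, the first mean component is the column sum over all six rows divided by the total count:

// Column 0 across all six vectors, with sparse gaps read as 0.0.
Seq(-2.0, 0.0, 0.0, 3.8, 1.7, 0.0).sum / 6 // = 0.5833333...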


test("merging two summarizers") {
val summarizer1 = makeBuffer(Seq(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
Vectors.dense(0.0, -1.0, -3.0)))

val summarizer2 = makeBuffer(Seq(
Vectors.sparse(3, Seq((1, -5.1))),
Vectors.dense(3.8, 0.0, 1.9),
Vectors.dense(1.7, -0.6, 0.0),
Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))))

val summarizer = Buffer.mergeBuffers(summarizer1, summarizer2)

assert(b(Buffer.mean(summarizer)) ~==
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")

assert(b(Buffer.min(summarizer)) ~== Vectors.dense(-2.0, -5.1, -3) absTol 1E-5, "min mismatch")

assert(b(Buffer.max(summarizer)) ~== Vectors.dense(3.8, 2.3, 1.9) absTol 1E-5, "max mismatch")

assert(l(Buffer.nnz(summarizer)) ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")

assert(b(Buffer.variance(summarizer)) ~==
Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5,
"variance mismatch")

assert(Buffer.totalCount(summarizer) === 6)
}


}