Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f3fa658
work
thunterdb Mar 3, 2017
7539835
work on the test suite
thunterdb Mar 6, 2017
673943f
last work
thunterdb Mar 7, 2017
202b672
work on using imperative aggregators
thunterdb Mar 13, 2017
be01981
Merge remote-tracking branch 'upstream/master' into 19634
thunterdb Mar 17, 2017
a983284
more work on summarizer
thunterdb Mar 18, 2017
647a4fe
work
thunterdb Mar 21, 2017
3c4bef7
changes
thunterdb Mar 21, 2017
56390cc
Merge remote-tracking branch 'upstream/master' into 19634
thunterdb Mar 21, 2017
c3f236c
cleanup
thunterdb Mar 21, 2017
ef955c0
debugging
thunterdb Mar 21, 2017
a04f923
work
thunterdb Mar 21, 2017
946d490
Merge remote-tracking branch 'upstream/master' into 19634
thunterdb Mar 22, 2017
201eb77
debug
thunterdb Mar 22, 2017
f4dec88
trying to debug serialization issue
thunterdb Mar 23, 2017
4af0f47
better tests
thunterdb Mar 23, 2017
9f29030
changes
thunterdb Mar 24, 2017
e9877dc
debugging
thunterdb Mar 24, 2017
3a11d02
more tests and debugging
thunterdb Mar 24, 2017
6d26c17
fixed tests
thunterdb Mar 24, 2017
35eaeb0
doc
thunterdb Mar 24, 2017
58b17dc
cleanups
thunterdb Mar 24, 2017
18078c1
cleanups
thunterdb Mar 24, 2017
ffe5cfe
Cleanups
thunterdb Mar 24, 2017
41f4be6
Cleanups
thunterdb Mar 24, 2017
ba200bb
Cleanups
thunterdb Mar 24, 2017
2f809ef
Merge remote-tracking branch 'upstream/master' into 19634
thunterdb Mar 27, 2017
662f62c
small test to find perf issues
thunterdb Mar 28, 2017
96be071
Current speed:
thunterdb Mar 30, 2017
a569dac
BLAS calls for dense interface
thunterdb Mar 30, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixed tests
  • Loading branch information
thunterdb committed Mar 24, 2017
commit 6d26c17d0bd4ab18d564ee7f37916780211702d5
23 changes: 13 additions & 10 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ object SummaryBuilderImpl extends Logging {
*/
@throws[SparkException]("When the buffers are not compatible")
def mergeBuffers(buffer: Buffer, other: Buffer): Buffer = {
println(s"mergeBuffers: buffer=$buffer other=$buffer")
println(s"mergeBuffers: buffer=$buffer other=$other")
if (buffer.n == -1) {
// buffer is not initialized.
if (other.n == -1) {
Expand Down Expand Up @@ -624,7 +624,7 @@ object SummaryBuilderImpl extends Logging {
// Mandatory scalar values
buffer.totalWeightSquareSum += other.totalWeightSquareSum
buffer.totalWeightSum += other.totalWeightSum
buffer.totalCount += buffer.totalCount
buffer.totalCount += other.totalCount
// Keep the original weight sums.
val weightSum1 = if (buffer.weightSum == null) null else { buffer.weightSum.clone() }
val weightSum2 = if (other.weightSum == null) null else { other.weightSum.clone() }
Expand All @@ -634,29 +634,32 @@ object SummaryBuilderImpl extends Logging {
// This is not going to change the value of the result since the numerator will also be zero.
val weightSum: BV[Double] = if (weightSum1 == null) null else {
require(weightSum2 != null, s"buffer=$buffer other=$other")
val b1 = b(weightSum1) :+ b(weightSum2)
la.max(b1, Double.MinPositiveValue)
val x = b(weightSum1) :+ b(weightSum2)
la.max(x, Double.MinPositiveValue)
}


// Since the operations are dense, we can directly use BLAS calls here.
val deltaMean: Array[Double] = if (buffer.mean != null) {
val deltaMean: BV[Double] = if (buffer.mean != null) {
require(other.mean != null)
val arr: Array[Double] = Array.ofDim(buffer.n)
b(arr) :+= b(other.mean) :- b(buffer.mean)
arr
b(other.mean) :- b(buffer.mean)
} else { null }

if (buffer.mean != null) {
require(other.mean != null)
require(weightSum != null)
b(buffer.mean) :+= b(deltaMean) :* (b(weightSum2) :/ weightSum)
b(buffer.mean) :+= deltaMean :* (b(weightSum2) :/ weightSum)
}

if (buffer.m2n != null) {
require(other.m2n != null)
val w = (b(weightSum1) :* b(weightSum2)) :/ weightSum
b(buffer.m2n) :+= b(other.m2n) :+ (b(deltaMean) :* b(deltaMean)) :* w
val z = (deltaMean :* deltaMean) :* w
println(s"weightSum1=$weightSum1 weightSum2=$weightSum2 weightSum=$weightSum z=$z")
println(s"mergeInitializedBuffers: buffer.m2n=${b(buffer.m2n)}")
println(s"mergeInitializedBuffers: other.m2n=${b(other.m2n)}")
b(buffer.m2n) :+= b(other.m2n) :+ z
println(s"mergeInitializedBuffers: buffer.m2n_2=${b(buffer.m2n)}")
}

if (buffer.m2 != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package org.apache.spark.mllib.stat
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// scalastyle:off println

/**
* :: DeveloperApi ::
* MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean,
Expand Down Expand Up @@ -57,6 +59,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _

override def toString: String = {
def v(x: Array[Double]) = if (x==null) "null" else x.toSeq.mkString("[", " ", "]")
def vl(x: Array[Long]) = if (x==null) "null" else x.toSeq.mkString("[", " ", "]")

s"MultivariateOnlineSummarizer(n=$n mean=${v(currMean)} m2n=${v(currM2n)} m2=${v(currM2)} " +
s"l1=${v(currL1)}" +
s" totalCount=$totalCnt totalWeightSum=$totalWeightSum" +
s" totalWeightSquareSum=$weightSquareSum weightSum=${v(weightSum)} nnz=${vl(nnz)}" +
s" max=${v(currMax)} min=${v(currMin)})"
}

/**
* Add a new sample to this summarizer, and update the statistical summary.
*
Expand Down Expand Up @@ -131,6 +144,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
def merge(other: MultivariateOnlineSummarizer): this.type = {
println(s"MultivariateOnlineSummarizer:merge: this=$this")
println(s"MultivariateOnlineSummarizer:merge: other=$other")
if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) {
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
Expand All @@ -148,7 +163,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
// merge mean together
currMean(i) += deltaMean * otherNnz / totalNnz
// merge m2n together
val z = deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
println(s"i=$i thisNnz=$thisNnz otherNnz=$otherNnz totalNnz=$totalNnz z=$z")
println(s"i=$i currM2n=${currM2n(i)} other.currM2n=${other.currM2n(i)}")
currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
println(s"i=$i >> currM2n=${currM2n(i)} other.currM2n=${other.currM2n(i)}")
// merge m2 together
currM2(i) += other.currM2(i)
// merge l1 together
Expand All @@ -175,6 +194,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
}
println(s"MultivariateOnlineSummarizer:merge(2): this=$this")
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {

val s = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename s -> summarizer

val s2 = new MultivariateOnlineSummarizer
inputVec.foreach(v => s2OldVectors.fromML(v)))
inputVec.foreach(v => s2.add(OldVectors.fromML(v)))
s2
}

Expand Down Expand Up @@ -289,8 +289,8 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
numNonZeros = Seq(2, 1, 1),
max = Seq(3.0, 0.0, 6.0),
min = Seq(-1.0, -3, 0.0),
normL1 = Seq(4.0, 4.0, 6.0),
normL2 = Seq(0.0, 0.0, 0.0)
normL1 = Seq(4.0, 3.0, 6.0),
normL2 = Seq(math.sqrt(10), 3, 6.0)
))

test("mixing dense and sparse vector input") {
Expand Down Expand Up @@ -324,14 +324,25 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
val summarizer1 = makeBuffer(Seq(
Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
Vectors.dense(0.0, -1.0, -3.0)))
// val s1 = new MultivariateOnlineSummarizer
// Seq(
// Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
// Vectors.dense(0.0, -1.0, -3.0)).foreach(v => s1.add(OldVectors.fromML(v)))

val summarizer2 = makeBuffer(Seq(
Vectors.sparse(3, Seq((1, -5.1))),
Vectors.dense(3.8, 0.0, 1.9),
Vectors.dense(1.7, -0.6, 0.0),
Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))))
// val s2 = new MultivariateOnlineSummarizer
// Seq(
// Vectors.sparse(3, Seq((1, -5.1))),
// Vectors.dense(3.8, 0.0, 1.9),
// Vectors.dense(1.7, -0.6, 0.0),
// Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))).foreach(v => s2.add(OldVectors.fromML(v)))

val summarizer = Buffer.mergeBuffers(summarizer1, summarizer2)
// val s = s1.merge(s2)

assert(b(Buffer.mean(summarizer)) ~==
Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333) absTol 1E-5, "mean mismatch")
Expand All @@ -343,12 +354,12 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(l(Buffer.nnz(summarizer)) ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch")

assert(b(Buffer.variance(summarizer)) ~==

Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5,
"variance mismatch")

assert(Buffer.totalCount(summarizer) === 6)
}



}