Skip to content
Closed
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f3fa658
work
thunterdb Mar 3, 2017
7539835
work on the test suite
thunterdb Mar 6, 2017
673943f
last work
thunterdb Mar 7, 2017
202b672
work on using imperative aggregators
thunterdb Mar 13, 2017
be01981
Merge remote-tracking branch 'upstream/master' into 19634
thunterdb Mar 17, 2017
a983284
more work on summarizer
thunterdb Mar 18, 2017
647a4fe
work
thunterdb Mar 21, 2017
3c4bef7
changes
thunterdb Mar 21, 2017
56390cc
Merge remote-tracking branch 'upstream/master' into 19634
thunterdb Mar 21, 2017
c3f236c
cleanup
thunterdb Mar 21, 2017
ef955c0
debugging
thunterdb Mar 21, 2017
a04f923
work
thunterdb Mar 21, 2017
946d490
Merge remote-tracking branch 'upstream/master' into 19634
thunterdb Mar 22, 2017
201eb77
debug
thunterdb Mar 22, 2017
f4dec88
trying to debug serialization issue
thunterdb Mar 23, 2017
4af0f47
better tests
thunterdb Mar 23, 2017
9f29030
changes
thunterdb Mar 24, 2017
e9877dc
debugging
thunterdb Mar 24, 2017
3a11d02
more tests and debugging
thunterdb Mar 24, 2017
6d26c17
fixed tests
thunterdb Mar 24, 2017
35eaeb0
doc
thunterdb Mar 24, 2017
58b17dc
cleanups
thunterdb Mar 24, 2017
18078c1
cleanups
thunterdb Mar 24, 2017
ffe5cfe
Cleanups
thunterdb Mar 24, 2017
41f4be6
Cleanups
thunterdb Mar 24, 2017
ba200bb
Cleanups
thunterdb Mar 24, 2017
2f809ef
Merge remote-tracking branch 'upstream/master' into 19634
thunterdb Mar 27, 2017
662f62c
small test to find perf issues
thunterdb Mar 28, 2017
96be071
Current speed:
thunterdb Mar 30, 2017
a569dac
BLAS calls for dense interface
thunterdb Mar 30, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
last work
  • Loading branch information
thunterdb committed Mar 7, 2017
commit 673943f334b94e5d1ecd8874cb82bbc875d739c6
51 changes: 33 additions & 18 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,35 +331,38 @@ object SummaryBuilderImpl extends Logging {
// instead of failing silently?
// Using exceptions is not great for performance in this loop.
val thisNnz = if (weightSum == null) { 0.0 } else { weightSum(i) }
val otherNnz = if (otherWeightSum == null) { 0.0 } else { otherWeightSum(i) }
val otherNnz = if (otherWeightSum == null) { 0.0 } else { otherWeightSum.getDouble(i) }
val totalNnz = thisNnz + otherNnz
val totalCnnz = if (numNonzeros == null || numNonzeros == null) { 0.0 } else {
numNonzeros(i) + otherNumNonzeros(i)
numNonzeros(i) + otherNumNonzeros.getDouble(i)
}
if (totalNnz != 0.0) {
val deltaMean = if (computeMean >= 0) { otherMean(i) - currMean(i) } else { 0.0 }
val deltaMean = if (computeMean >= 0) {
otherMean.getDouble(i) - currMean(i)
} else { 0.0 }
// merge mean together
if (computeMean >= 0) {
currMean(i) += deltaMean * otherNnz / totalNnz
}
// merge m2n together
if (computeM2n >= 0) {
currM2n(i) += otherM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
val incr = deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
currM2n(i) += otherM2n.getDouble(i) + incr
}
// merge m2 together
if (computeM2 >= 0) {
currM2(i) += otherM2(i)
currM2(i) += otherM2.getDouble(i)
}
// merge l1 together
if (computeL1 >= 0) {
currL1(i) += otherL1(i)
currL1(i) += otherL1.getDouble(i)
}
// merge max and min
if (computeCurrMax >= 0) {
currMax(i) = math.max(currMax(i), otherMax(i))
currMax(i) = math.max(currMax(i), otherMax.getDouble(i))
}
if (computeCurrMin >= 0) {
currMin(i) = math.min(currMin(i), otherMin(i))
currMin(i) = math.min(currMin(i), otherMin.getDouble(i))
}
}
if (computeWeightSum >= 0) {
Expand Down Expand Up @@ -409,7 +412,7 @@ object SummaryBuilderImpl extends Logging {
require(totalWeightSum > 0, "Data has zero weight")
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
realMean(i) = currMean.getDouble(i) * (weightSum.getDouble(i) / totalWeightSum)
i += 1
}
Vectors.dense(realMean)
Expand Down Expand Up @@ -442,8 +445,10 @@ object SummaryBuilderImpl extends Logging {
var i = 0
val len = currM2n.length
while (i < len) {
realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
(totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
val m = deltaMean.getDouble(i)
val num = currM2n.getDouble(i) + m * m * weightSum.getDouble(i)
realVariance(i) = (num *
(totalWeightSum - weightSum.getDouble(i)) / totalWeightSum) / denominator
i += 1
}
}
Expand Down Expand Up @@ -499,7 +504,7 @@ object SummaryBuilderImpl extends Logging {
}

// All these may be null pointers.
val localCurrMean = exposeDArray(buffer, computeMean)
val localCurrMean = exposeDArray3(buffer, computeMean)
val localCurrM2n = exposeDArray(buffer, computeM2n)
val localCurrM2 = exposeDArray(buffer, computeM2)
val localCurrL1 = exposeDArray(buffer, computeL1)
Expand All @@ -520,12 +525,13 @@ object SummaryBuilderImpl extends Logging {
}

if (localCurrMean != null) {
val prevMean = localCurrMean(index)
val prevMean = localCurrMean.getDouble(index)
val diff = value - prevMean
localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
localCurrMean.update(index, prevMean + weight * diff / (localWeightSum(index) + weight))
// localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight)
if (localCurrM2n != null) {
// require: localCurrMean != null.
localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
localCurrM2n(index) += weight * (value - localCurrMean.getDouble(index)) * diff
}
}

Expand All @@ -550,23 +556,32 @@ object SummaryBuilderImpl extends Logging {
buffer.getInt(computeN)
}


private[this] def exposeDArray2(buffer: Row, index: Int): Array[Double] = {
private[this] def exposeDArray2(buffer: Row, index: Int): Row = {
if (index == -1) {
null
} else {
buffer.getAs[Array[Double]](index)
buffer.getStruct(index)
}
}

private[this] def exposeDArray(buffer: MutableAggregationBuffer, index: Int): Array[Double] = {
if (index == -1) {
null
} else {
buffer.getAs[MutableAggregationBuffer](index)
buffer.getAs[mutable.WrappedArray[Double]](index).array
}
}

private[this] def exposeDArray3(
buffer: MutableAggregationBuffer, index: Int): MutableAggregationBuffer = {
if (index == -1) {
null
} else {
buffer.getAs[MutableAggregationBuffer](index)
}
}

// Performs the initialization of the requested fields, if required.
private def initialize(buffer: MutableAggregationBuffer, n: Int): Unit = {
require(n > 0, s"Cannot set n to $n: must be positive")
Expand Down