Skip to content

Commit 03dd693

Browse files
author
DB Tsai
committed
further performance tuning.
1 parent 1907ae1 commit 03dd693

File tree

3 files changed

+78
-143
lines changed

3 files changed

+78
-143
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,11 @@ sealed trait Vector extends Serializable {
8080
/**
8181
* Applies a function `f` to all the active elements of dense and sparse vector.
8282
*
83-
* @param f the function takes (Int, Double) as input where the first element
84-
* in the tuple is the index, and the second element is the corresponding value.
85-
* @param skippingZeros if true, skipping zero elements explicitly. It will be useful when
86-
* iterating through dense vector which has lots of zero elements to be
87-
* skipped. Default is false.
83+
* @param f the function takes two parameters where the first parameter is the index of
84+
* the vector with type `Int`, and the second parameter is the corresponding value
85+
* with type `Double`.
8886
*/
89-
private[spark] def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit)
87+
private[spark] def foreachActive(f: (Int, Double) => Unit)
9088
}
9189

9290
/**
@@ -285,23 +283,14 @@ class DenseVector(val values: Array[Double]) extends Vector {
285283
new DenseVector(values.clone())
286284
}
287285

288-
private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
286+
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
289287
var i = 0
290288
val localValuesSize = values.size
291289
val localValues = values
292290

293-
if (skippingZeros) {
294-
while (i < localValuesSize) {
295-
if (localValues(i) != 0.0) {
296-
f(i, localValues(i))
297-
}
298-
i += 1
299-
}
300-
} else {
301-
while (i < localValuesSize) {
302-
f(i, localValues(i))
303-
i += 1
304-
}
291+
while (i < localValuesSize) {
292+
f(i, localValues(i))
293+
i += 1
305294
}
306295
}
307296
}
@@ -341,24 +330,15 @@ class SparseVector(
341330

342331
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
343332

344-
private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
333+
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
345334
var i = 0
346335
val localValuesSize = values.size
347336
val localIndices = indices
348337
val localValues = values
349338

350-
if (skippingZeros) {
351-
while (i < localValuesSize) {
352-
if (localValues(i) != 0.0) {
353-
f(localIndices(i), localValues(i))
354-
}
355-
i += 1
356-
}
357-
} else {
358-
while (i < localValuesSize) {
359-
f(localIndices(i), localValues(i))
360-
i += 1
361-
}
339+
while (i < localValuesSize) {
340+
f(localIndices(i), localValues(i))
341+
i += 1
362342
}
363343
}
364344
}

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 52 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717

1818
package org.apache.spark.mllib.stat
1919

20-
import breeze.linalg.{DenseVector => BDV}
21-
2220
import org.apache.spark.annotation.DeveloperApi
23-
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
21+
import org.apache.spark.mllib.linalg.{Vectors, Vector}
2422

2523
/**
2624
* :: DeveloperApi ::
@@ -40,35 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector
4038
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
4139

4240
private var n = 0
43-
private var currMean: BDV[Double] = _
44-
private var currM2n: BDV[Double] = _
45-
private var currM2: BDV[Double] = _
46-
private var currL1: BDV[Double] = _
41+
private var currMean: Array[Double] = _
42+
private var currM2n: Array[Double] = _
43+
private var currM2: Array[Double] = _
44+
private var currL1: Array[Double] = _
4745
private var totalCnt: Long = 0
48-
private var nnz: BDV[Double] = _
49-
private var currMax: BDV[Double] = _
50-
private var currMin: BDV[Double] = _
51-
52-
/**
53-
* Adds input value to position i.
54-
*/
55-
private[this] def add(i: Int, value: Double) = {
56-
if (currMax(i) < value) {
57-
currMax(i) = value
58-
}
59-
if (currMin(i) > value) {
60-
currMin(i) = value
61-
}
62-
63-
val prevMean = currMean(i)
64-
val diff = value - prevMean
65-
currMean(i) = prevMean + diff / (nnz(i) + 1.0)
66-
currM2n(i) += (value - currMean(i)) * diff
67-
currM2(i) += value * value
68-
currL1(i) += math.abs(value)
69-
70-
nnz(i) += 1.0
71-
}
46+
private var nnz: Array[Double] = _
47+
private var currMax: Array[Double] = _
48+
private var currMin: Array[Double] = _
7249

7350
/**
7451
* Add a new sample to this summarizer, and update the statistical summary.
@@ -81,19 +58,37 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
8158
require(sample.size > 0, s"Vector should have dimension larger than zero.")
8259
n = sample.size
8360

84-
currMean = BDV.zeros[Double](n)
85-
currM2n = BDV.zeros[Double](n)
86-
currM2 = BDV.zeros[Double](n)
87-
currL1 = BDV.zeros[Double](n)
88-
nnz = BDV.zeros[Double](n)
89-
currMax = BDV.fill(n)(Double.MinValue)
90-
currMin = BDV.fill(n)(Double.MaxValue)
61+
currMean = Array.ofDim[Double](n)
62+
currM2n = Array.ofDim[Double](n)
63+
currM2 = Array.ofDim[Double](n)
64+
currL1 = Array.ofDim[Double](n)
65+
nnz = Array.ofDim[Double](n)
66+
currMax = Array.fill[Double](n)(Double.MinValue)
67+
currMin = Array.fill[Double](n)(Double.MaxValue)
9168
}
9269

9370
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
9471
s" Expecting $n but got ${sample.size}.")
9572

96-
sample.foreach(true)(x => add(x._1, x._2))
73+
sample.foreachActive((index, value) => {
74+
if(value != 0.0){
75+
if (currMax(index) < value) {
76+
currMax(index) = value
77+
}
78+
if (currMin(index) > value) {
79+
currMin(index) = value
80+
}
81+
82+
val prevMean = currMean(index)
83+
val diff = value - prevMean
84+
currMean(index) = prevMean + diff / (nnz(index) + 1.0)
85+
currM2n(index) += (value - currMean(index)) * diff
86+
currM2(index) += value * value
87+
currL1(index) += math.abs(value)
88+
89+
nnz(index) += 1.0
90+
}
91+
})
9792

9893
totalCnt += 1
9994
this
@@ -135,34 +130,34 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
135130
}
136131
} else if (totalCnt == 0 && other.totalCnt != 0) {
137132
this.n = other.n
138-
this.currMean = other.currMean.copy
139-
this.currM2n = other.currM2n.copy
140-
this.currM2 = other.currM2.copy
141-
this.currL1 = other.currL1.copy
133+
this.currMean = other.currMean.clone
134+
this.currM2n = other.currM2n.clone
135+
this.currM2 = other.currM2.clone
136+
this.currL1 = other.currL1.clone
142137
this.totalCnt = other.totalCnt
143-
this.nnz = other.nnz.copy
144-
this.currMax = other.currMax.copy
145-
this.currMin = other.currMin.copy
138+
this.nnz = other.nnz.clone
139+
this.currMax = other.currMax.clone
140+
this.currMin = other.currMin.clone
146141
}
147142
this
148143
}
149144

150145
override def mean: Vector = {
151146
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
152147

153-
val realMean = BDV.zeros[Double](n)
148+
val realMean = Array.ofDim[Double](n)
154149
var i = 0
155150
while (i < n) {
156151
realMean(i) = currMean(i) * (nnz(i) / totalCnt)
157152
i += 1
158153
}
159-
Vectors.fromBreeze(realMean)
154+
Vectors.dense(realMean)
160155
}
161156

162157
override def variance: Vector = {
163158
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
164159

165-
val realVariance = BDV.zeros[Double](n)
160+
val realVariance = Array.ofDim[Double](n)
166161

167162
val denominator = totalCnt - 1.0
168163

@@ -177,16 +172,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
177172
i += 1
178173
}
179174
}
180-
181-
Vectors.fromBreeze(realVariance)
175+
Vectors.dense(realVariance)
182176
}
183177

184178
override def count: Long = totalCnt
185179

186180
override def numNonzeros: Vector = {
187181
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
188182

189-
Vectors.fromBreeze(nnz)
183+
Vectors.dense(nnz)
190184
}
191185

192186
override def max: Vector = {
@@ -197,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
197191
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
198192
i += 1
199193
}
200-
Vectors.fromBreeze(currMax)
194+
Vectors.dense(currMax)
201195
}
202196

203197
override def min: Vector = {
@@ -208,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
208202
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
209203
i += 1
210204
}
211-
Vectors.fromBreeze(currMin)
205+
Vectors.dense(currMin)
212206
}
213207

214208
override def normL2: Vector = {
215209
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
216210

217-
val realMagnitude = BDV.zeros[Double](n)
211+
val realMagnitude = Array.ofDim[Double](n)
218212

219213
var i = 0
220214
while (i < currM2.size) {
221215
realMagnitude(i) = math.sqrt(currM2(i))
222216
i += 1
223217
}
224-
225-
Vectors.fromBreeze(realMagnitude)
218+
Vectors.dense(realMagnitude)
226219
}
227220

228221
override def normL1: Vector = {
229222
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
230-
Vectors.fromBreeze(currL1)
223+
224+
Vectors.dense(currL1)
231225
}
232226
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -178,58 +178,19 @@ class VectorsSuite extends FunSuite {
178178
val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
179179
val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))
180180

181-
val dvMap0 = scala.collection.mutable.Map[Int, Double]()
182-
dv.foreach() {
183-
case (index: Int, value: Double) => dvMap0.put(index, value)
184-
}
185-
assert(dvMap0.size === 4)
186-
assert(dvMap0.get(0) === Some(0.0))
187-
assert(dvMap0.get(1) === Some(1.2))
188-
assert(dvMap0.get(2) === Some(3.1))
189-
assert(dvMap0.get(3) === Some(0.0))
190-
191-
val dvMap1 = scala.collection.mutable.Map[Int, Double]()
192-
dv.foreach(false) {
193-
case (index, value) => dvMap1.put(index, value)
194-
}
195-
assert(dvMap1.size === 4)
196-
assert(dvMap1.get(0) === Some(0.0))
197-
assert(dvMap1.get(1) === Some(1.2))
198-
assert(dvMap1.get(2) === Some(3.1))
199-
assert(dvMap1.get(3) === Some(0.0))
200-
201-
val dvMap2 = scala.collection .mutable.Map[Int, Double]()
202-
dv.foreach(true) {
203-
case (index, value) => dvMap2.put(index, value)
204-
}
205-
assert(dvMap2.size === 2)
206-
assert(dvMap2.get(1) === Some(1.2))
207-
assert(dvMap2.get(2) === Some(3.1))
208-
209-
val svMap0 = scala.collection.mutable.Map[Int, Double]()
210-
sv.foreach() {
211-
case (index, value) => svMap0.put(index, value)
212-
}
213-
assert(svMap0.size === 3)
214-
assert(svMap0.get(1) === Some(1.2))
215-
assert(svMap0.get(2) === Some(3.1))
216-
assert(svMap0.get(3) === Some(0.0))
217-
218-
val svMap1 = scala.collection.mutable.Map[Int, Double]()
219-
sv.foreach(false) {
220-
case (index, value) => svMap1.put(index, value)
221-
}
222-
assert(svMap1.size === 3)
223-
assert(svMap1.get(1) === Some(1.2))
224-
assert(svMap1.get(2) === Some(3.1))
225-
assert(svMap1.get(3) === Some(0.0))
226-
227-
val svMap2 = scala.collection.mutable.Map[Int, Double]()
228-
sv.foreach(true) {
229-
case (index, value) => svMap2.put(index, value)
230-
}
231-
assert(svMap2.size === 2)
232-
assert(svMap2.get(1) === Some(1.2))
233-
assert(svMap2.get(2) === Some(3.1))
181+
val dvMap = scala.collection.mutable.Map[Int, Double]()
182+
dv.foreachActive((index, value) => dvMap.put(index, value))
183+
assert(dvMap.size === 4)
184+
assert(dvMap.get(0) === Some(0.0))
185+
assert(dvMap.get(1) === Some(1.2))
186+
assert(dvMap.get(2) === Some(3.1))
187+
assert(dvMap.get(3) === Some(0.0))
188+
189+
val svMap = scala.collection.mutable.Map[Int, Double]()
190+
sv.foreachActive((index, value) => svMap.put(index, value))
191+
assert(svMap.size === 3)
192+
assert(svMap.get(1) === Some(1.2))
193+
assert(svMap.get(2) === Some(3.1))
194+
assert(svMap.get(3) === Some(0.0))
234195
}
235196
}

0 commit comments

Comments
 (0)