Skip to content

Commit 1907ae1

Browse files
author
DB Tsai
committed
address feedback
1 parent 98448bb commit 1907ae1

File tree

3 files changed

+60
-130
lines changed

3 files changed

+60
-130
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 34 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,15 @@ sealed trait Vector extends Serializable {
7878
}
7979

8080
/**
81-
* It will return the iterator for the active elements of dense and sparse vector as
82-
* (index, value) pair. Note that foreach method can be overridden for better performance
83-
* in different vector implementation.
81+
* Applies a function `f` to all the active elements of dense and sparse vector.
8482
*
85-
* @param skippingZeros Skipping zero elements explicitly if true. It will be useful when we
86-
* iterator through dense vector having lots of zero elements which
87-
* we want to skip. Default is false.
88-
* @return Iterator[(Int, Double)] where the first element in the tuple is the index,
89-
* and the second element is the corresponding value.
83+
* @param f the function takes (Int, Double) as input where the first element
84+
* in the tuple is the index, and the second element is the corresponding value.
85+
* @param skippingZeros if true, skipping zero elements explicitly. It will be useful when
86+
* iterating through dense vector which has lots of zero elements to be
87+
* skipped. Default is false.
9088
*/
91-
private[spark] def activeIterator(skippingZeros: Boolean): Iterator[(Int, Double)]
92-
93-
private[spark] def activeIterator: Iterator[(Int, Double)] = activeIterator(false)
94-
89+
private[spark] def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit)
9590
}
9691

9792
/**
@@ -290,48 +285,25 @@ class DenseVector(val values: Array[Double]) extends Vector {
290285
new DenseVector(values.clone())
291286
}
292287

293-
private[spark] override def activeIterator(skippingZeros: Boolean) = new Iterator[(Int, Double)] {
294-
private var i = 0
295-
296-
// If zeros are asked to be explicitly skipped, the parent `size` method is called to count
297-
// the number of nonzero elements using `hasNext` and `next` methods.
298-
final override lazy val size: Int = if (skippingZeros) super.size else values.size
299-
300-
final override def hasNext = {
301-
if (skippingZeros) {
302-
var found = false
303-
while (!found && i < values.size) if (values(i) != 0.0) found = true else i += 1
304-
}
305-
i < values.size
306-
}
307-
308-
final override def next = {
309-
val result = (i, values(i))
310-
i += 1
311-
result
312-
}
288+
private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
289+
var i = 0
290+
val localValuesSize = values.size
291+
val localValues = values
313292

314-
final override def foreach[@specialized(Unit) U](f: ((Int, Double)) => U) {
315-
var i = 0
316-
val localValuesSize = values.size
317-
val localValues = values
318-
319-
if (skippingZeros) {
320-
while (i < localValuesSize) {
321-
if (localValues(i) != 0.0) {
322-
f(i, localValues(i))
323-
}
324-
i += 1
325-
}
326-
} else {
327-
while (i < localValuesSize) {
293+
if (skippingZeros) {
294+
while (i < localValuesSize) {
295+
if (localValues(i) != 0.0) {
328296
f(i, localValues(i))
329-
i += 1
330297
}
298+
i += 1
299+
}
300+
} else {
301+
while (i < localValuesSize) {
302+
f(i, localValues(i))
303+
i += 1
331304
}
332305
}
333306
}
334-
335307
}
336308

337309
/**
@@ -369,47 +341,24 @@ class SparseVector(
369341

370342
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
371343

372-
private[spark] override def activeIterator(skippingZeros: Boolean) = new Iterator[(Int, Double)] {
373-
private var i = 0
374-
375-
// If zeros are asked to be explicitly skipped, the parent `size` method is called to count
376-
// the number of nonzero elements using `hasNext` and `next` methods.
377-
final override lazy val size: Int = if (skippingZeros) super.size else values.size
378-
379-
final override def hasNext = {
380-
if (skippingZeros) {
381-
var found = false
382-
while (!found && i < values.size) if (values(i) != 0.0) found = true else i += 1
383-
}
384-
i < values.size
385-
}
386-
387-
final override def next = {
388-
val result = (indices(i), values(i))
389-
i += 1
390-
result
391-
}
344+
private[spark] override def foreach(skippingZeros: Boolean = false)(f: ((Int, Double)) => Unit) {
345+
var i = 0
346+
val localValuesSize = values.size
347+
val localIndices = indices
348+
val localValues = values
392349

393-
final override def foreach[@specialized(Unit) U](f: ((Int, Double)) => U) {
394-
var i = 0
395-
val localValuesSize = values.size
396-
val localIndices = indices
397-
val localValues = values
398-
399-
if (skippingZeros) {
400-
while (i < localValuesSize) {
401-
if (localValues(i) != 0.0) {
402-
f(localIndices(i), localValues(i))
403-
}
404-
i += 1
405-
}
406-
} else {
407-
while (i < localValuesSize) {
350+
if (skippingZeros) {
351+
while (i < localValuesSize) {
352+
if (localValues(i) != 0.0) {
408353
f(localIndices(i), localValues(i))
409-
i += 1
410354
}
355+
i += 1
356+
}
357+
} else {
358+
while (i < localValuesSize) {
359+
f(localIndices(i), localValues(i))
360+
i += 1
411361
}
412362
}
413363
}
414-
415364
}

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
9393
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
9494
s" Expecting $n but got ${sample.size}.")
9595

96-
sample.activeIterator(true).foreach {
97-
case (index, value) => add(index, value)
98-
}
96+
sample.foreach(true)(x => add(x._1, x._2))
9997

10098
totalCnt += 1
10199
this

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -174,47 +174,22 @@ class VectorsSuite extends FunSuite {
174174
assert(v.size === x.rows)
175175
}
176176

177-
test("activeIterator") {
177+
test("foreach") {
178178
val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
179179
val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))
180180

181-
// Testing if the size of iterator is correct when the zeros are explicitly skipped.
182-
// The default setting will not skip any zero explicitly.
183-
assert(dv.activeIterator.size === 4)
184-
assert(dv.activeIterator(false).size === 4)
185-
assert(dv.activeIterator(true).size === 2)
186-
187-
assert(sv.activeIterator.size === 3)
188-
assert(sv.activeIterator(false).size === 3)
189-
assert(sv.activeIterator(true).size === 2)
190-
191-
// Testing `hasNext` and `next`
192-
val dvIter1 = dv.activeIterator(false)
193-
assert(dvIter1.hasNext === true && dvIter1.next === (0, 0.0))
194-
assert(dvIter1.hasNext === true && dvIter1.next === (1, 1.2))
195-
assert(dvIter1.hasNext === true && dvIter1.next === (2, 3.1))
196-
assert(dvIter1.hasNext === true && dvIter1.next === (3, 0.0))
197-
assert(dvIter1.hasNext === false)
198-
199-
val dvIter2 = dv.activeIterator(true)
200-
assert(dvIter2.hasNext === true && dvIter2.next === (1, 1.2))
201-
assert(dvIter2.hasNext === true && dvIter2.next === (2, 3.1))
202-
assert(dvIter2.hasNext === false)
203-
204-
val svIter1 = sv.activeIterator(false)
205-
assert(svIter1.hasNext === true && svIter1.next === (1, 1.2))
206-
assert(svIter1.hasNext === true && svIter1.next === (2, 3.1))
207-
assert(svIter1.hasNext === true && svIter1.next === (3, 0.0))
208-
assert(svIter1.hasNext === false)
209-
210-
val svIter2 = sv.activeIterator(true)
211-
assert(svIter2.hasNext === true && svIter2.next === (1, 1.2))
212-
assert(svIter2.hasNext === true && svIter2.next === (2, 3.1))
213-
assert(svIter2.hasNext === false)
214-
215-
// Testing `foreach`
181+
val dvMap0 = scala.collection.mutable.Map[Int, Double]()
182+
dv.foreach() {
183+
case (index: Int, value: Double) => dvMap0.put(index, value)
184+
}
185+
assert(dvMap0.size === 4)
186+
assert(dvMap0.get(0) === Some(0.0))
187+
assert(dvMap0.get(1) === Some(1.2))
188+
assert(dvMap0.get(2) === Some(3.1))
189+
assert(dvMap0.get(3) === Some(0.0))
190+
216191
val dvMap1 = scala.collection.mutable.Map[Int, Double]()
217-
dvIter1.foreach{
192+
dv.foreach(false) {
218193
case (index, value) => dvMap1.put(index, value)
219194
}
220195
assert(dvMap1.size === 4)
@@ -223,16 +198,25 @@ class VectorsSuite extends FunSuite {
223198
assert(dvMap1.get(2) === Some(3.1))
224199
assert(dvMap1.get(3) === Some(0.0))
225200

226-
val dvMap2 = scala.collection.mutable.Map[Int, Double]()
227-
dvIter2.foreach{
201+
val dvMap2 = scala.collection .mutable.Map[Int, Double]()
202+
dv.foreach(true) {
228203
case (index, value) => dvMap2.put(index, value)
229204
}
230205
assert(dvMap2.size === 2)
231206
assert(dvMap2.get(1) === Some(1.2))
232207
assert(dvMap2.get(2) === Some(3.1))
233208

209+
val svMap0 = scala.collection.mutable.Map[Int, Double]()
210+
sv.foreach() {
211+
case (index, value) => svMap0.put(index, value)
212+
}
213+
assert(svMap0.size === 3)
214+
assert(svMap0.get(1) === Some(1.2))
215+
assert(svMap0.get(2) === Some(3.1))
216+
assert(svMap0.get(3) === Some(0.0))
217+
234218
val svMap1 = scala.collection.mutable.Map[Int, Double]()
235-
svIter1.foreach{
219+
sv.foreach(false) {
236220
case (index, value) => svMap1.put(index, value)
237221
}
238222
assert(svMap1.size === 3)
@@ -241,12 +225,11 @@ class VectorsSuite extends FunSuite {
241225
assert(svMap1.get(3) === Some(0.0))
242226

243227
val svMap2 = scala.collection.mutable.Map[Int, Double]()
244-
svIter2.foreach{
228+
sv.foreach(true) {
245229
case (index, value) => svMap2.put(index, value)
246230
}
247231
assert(svMap2.size === 2)
248232
assert(svMap2.get(1) === Some(1.2))
249233
assert(svMap2.get(2) === Some(3.1))
250-
251234
}
252235
}

0 commit comments

Comments
 (0)