Skip to content

Commit f7a91a1

Browse files
Jon McLean authored and srowen committed
[SPARK-20615][ML][TEST] SparseVector.argmax throws IndexOutOfBoundsException
## What changes were proposed in this pull request? Added a check for the number of defined values. Previously the argmax function assumed that at least one value was defined if the vector size was greater than zero. ## How was this patch tested? Tests were added to the existing VectorsSuite to cover this case. Author: Jon McLean <[email protected]> Closes #17877 from jonmclean/vectorArgmaxIndexBug. (cherry picked from commit be53a78) Signed-off-by: Sean Owen <[email protected]>
1 parent a1112c6 commit f7a91a1

File tree

4 files changed

+18
-0
lines changed

4 files changed

+18
-0
lines changed

mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,8 @@ class SparseVector @Since("2.0.0") (
657657
override def argmax: Int = {
658658
if (size == 0) {
659659
-1
660+
} else if (numActives == 0) {
661+
0
660662
} else {
661663
// Find the max active entry.
662664
var maxIdx = indices(0)

mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ class VectorsSuite extends SparkMLFunSuite {
125125

126126
val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
127127
assert(vec8.argmax === 0)
128+
129+
// Check for case when sparse vector is non-empty but the values are empty
130+
val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
131+
assert(vec9.argmax === 0)
132+
133+
val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
134+
assert(vec10.argmax === 0)
128135
}
129136

130137
test("vector equals") {

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,8 @@ class SparseVector @Since("1.0.0") (
846846
override def argmax: Int = {
847847
if (size == 0) {
848848
-1
849+
} else if (numActives == 0) {
850+
0
849851
} else {
850852
// Find the max active entry.
851853
var maxIdx = indices(0)

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ class VectorsSuite extends SparkFunSuite with Logging {
122122

123123
val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
124124
assert(vec8.argmax === 0)
125+
126+
// Check for case when sparse vector is non-empty but the values are empty
127+
val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
128+
assert(vec9.argmax === 0)
129+
130+
val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
131+
assert(vec10.argmax === 0)
125132
}
126133

127134
test("vector equals") {

0 commit comments

Comments
 (0)