Skip to content

Commit e9aa0ea

Browse files
committed
moved findMax to DenseVector and renamed to argmax. fixed bug for vector of length 0
1 parent 15b9957 commit e9aa0ea

File tree

3 files changed

+24
-22
lines changed

3 files changed

+24
-22
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,5 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
150150
* This may be overridden to support thresholds which favor particular labels.
151151
* @return predicted label
152152
*/
153-
protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.findMax
153+
protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax
154154
}

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,5 +165,5 @@ private[spark] abstract class ProbabilisticClassificationModel[
165165
* This may be overridden to support thresholds which favor particular labels.
166166
* @return predicted label
167167
*/
168-
protected def probability2prediction(probability: Vector): Double = probability.findMax
168+
protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax
169169
}

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -150,26 +150,6 @@ sealed trait Vector extends Serializable {
150150
toDense
151151
}
152152
}
153-
154-
/**
155-
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
156-
* Returns -1 if vector has length 0.
157-
*/
158-
private[spark] def findMax: Int = {
159-
if (size == 0) {
160-
0
161-
} else {
162-
var maxIdx = 0
163-
var maxValue = apply(0)
164-
foreachActive { (idx, value) =>
165-
if (value > maxValue) {
166-
maxIdx = idx
167-
maxValue = value
168-
}
169-
}
170-
maxIdx
171-
}
172-
}
173153
}
174154

175155
/**
@@ -607,6 +587,28 @@ class DenseVector(val values: Array[Double]) extends Vector {
607587
}
608588
new SparseVector(size, ii, vv)
609589
}
590+
591+
/**
592+
* Find the index of a maximal element. Returns the first maximal element in case of a tie.
593+
* Returns -1 if vector has length 0.
594+
*/
595+
private[spark] def argmax: Int = {
596+
if (size == 0) {
597+
-1
598+
} else {
599+
var maxIdx = 0
600+
var maxValue = values(0)
601+
var i = 1
602+
while (i < size) {
603+
if (values(i) > maxValue) {
604+
maxIdx = i
605+
maxValue = values(i)
606+
}
607+
i += 1
608+
}
609+
maxIdx
610+
}
611+
}
610612
}
611613

612614
object DenseVector {

0 commit comments

Comments
 (0)