Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,34 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]

/**
* Returns the distances to all clusters for a given point.
*/
@Since("1.5.0")
def distanceToCenters(point: Vector): Iterable[(Int, Double)] = {
val pointWithNorm = new VectorWithNorm(point)
clusterCentersWithNorm.zipWithIndex.map {
case (c, i) =>
(i, KMeans.fastSquaredDistance(c, pointWithNorm))
}.toList
}

/**
* Maps given points to their distances to all clusters.
*/
@Since("1.5.0")
def distanceToCenters(points: RDD[Vector]): RDD[(Vector, Iterable[(Int, Double)])] = {
val centersWithNorm = clusterCentersWithNorm
val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
points.map(p => {
val pointWithNorm = new VectorWithNorm(p)
(p, bcCentersWithNorm.value.zipWithIndex.map {
case (c, i) =>
(i, KMeans.fastSquaredDistance(c, pointWithNorm))
}.toList)
})
}

/**
* Return the K-means cost (sum of squared distances of points to their nearest center) for this
* model on the given data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {

for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
// Two iterations are sufficient no matter where the initial centers are.
val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1, initMode)
val k = 2
val model = KMeans.train(rdd, k = k, maxIterations = 2, runs = 1, initMode)

val predicts = model.predict(rdd).collect()

Expand All @@ -259,6 +260,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(predicts(3) === predicts(4))
assert(predicts(3) === predicts(5))
assert(predicts(0) != predicts(3))

assert(model.distanceToCenters(rdd).flatMap(_._2).count === points.size * k)
}
}

Expand Down Expand Up @@ -341,6 +344,7 @@ class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
val model = KMeans.train(points, 2, 2, 1, initMode)
val predictions = model.predict(points).collect()
val cost = model.computeCost(points)
val dToCenters = model.distanceToCenters(points.first())
}
}
}