diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala
index 8f78bcc15347f..8016258f054a9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala
@@ -20,8 +20,10 @@ package org.apache.spark.mllib.rdd
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
+import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.BoundedPriorityQueue
+import org.apache.spark.util.collection.Utils
 
 /**
  * Machine learning specific Pair RDD functions.
@@ -37,14 +39,31 @@ class MLPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Se
    * @return an RDD that contains the top k values for each key
    */
   def topByKey(num: Int)(implicit ord: Ordering[V]): RDD[(K, Array[V])] = {
-    self.aggregateByKey(new BoundedPriorityQueue[V](num)(ord))(
-      seqOp = (queue, item) => {
-        queue += item
-      },
-      combOp = (queue1, queue2) => {
-        queue1 ++= queue2
-      }
-    ).mapValues(_.toArray.sorted(ord.reverse)) // This is a min-heap, so we reverse the order.
+    // Per-key combiner functions over a BoundedPriorityQueue created with `num` and `ord`.
+    val newQueue = (value: V) => new BoundedPriorityQueue[V](num)(ord) += value
+    val addValue = (queue: BoundedPriorityQueue[V], value: V) => queue += value
+    val mergeQueues =
+      (left: BoundedPriorityQueue[V], right: BoundedPriorityQueue[V]) => left ++= right
+
+    val sc = self.context
+    val topAggregator = new Aggregator[K, V, BoundedPriorityQueue[V]](
+      sc.clean(newQueue), sc.clean(addValue), sc.clean(mergeQueues))
+
+    // Combine values per key inside each partition, then emit every key's queue as an array
+    // sorted by `ord.reverse`.
+    val partitionTops = self.mapPartitions({ records =>
+      val taskContext = TaskContext.get()
+      val sortedPerKey = topAggregator
+        .combineValuesByKey(records, taskContext)
+        .map { case (key, queue) => (key, queue.toArray.sorted(ord.reverse)) }
+      new InterruptibleIterator(taskContext, sortedPerKey)
+    }, preservesPartitioning = true)
+
+    // Merge two arrays of the same key under `ord.reverse`, copying at most `num` elements.
+    partitionTops.reduceByKey { (left, right) =>
+      val merged = Array.ofDim[V](math.min(num, left.length + right.length))
+      Utils.mergeOrdered[V](Seq(left, right))(ord.reverse).copyToArray(merged, 0, merged.length)
+      merged
+    }
   }
 }
 