diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 8dec0dd219567..d5a7d58480bba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -465,6 +465,8 @@ class ALSModel private[ml] (
     import srcFactors.sparkSession.implicits._
     import scala.collection.JavaConverters._
 
+    val ratingColumn = "rating"
+    val recommendColumn = "recommendations"
     val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
     val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
     val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
@@ -496,18 +498,20 @@
           .iterator.map { j => (srcId, dstIds(j), scores(j)) }
         }
       }
-    }
-    // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
-    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
-    val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
-      .toDF("id", "recommendations")
+    }.toDF(srcOutputColumn, dstOutputColumn, ratingColumn)
 
     val arrayType = ArrayType(
       new StructType()
         .add(dstOutputColumn, IntegerType)
-        .add("rating", FloatType)
+        .add(ratingColumn, FloatType)
     )
-    recs.select($"id".as(srcOutputColumn), $"recommendations".cast(arrayType))
+
+    ratings.groupBy(srcOutputColumn)
+      .agg(collect_top_k(struct(ratingColumn, dstOutputColumn), num, false))
+      .as[(Int, Seq[(Float, Int)])]
+      .map(t => (t._1, t._2.map(p => (p._2, p._1))))
+      .toDF(srcOutputColumn, recommendColumn)
+      .withColumn(recommendColumn, col(recommendColumn).cast(arrayType))
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
deleted file mode 100644
index ed41169070c59..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.recommendation
-
-import scala.reflect.runtime.universe.TypeTag
-
-import org.apache.spark.sql.{Encoder, Encoders}
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.expressions.Aggregator
-import org.apache.spark.util.BoundedPriorityQueue
-
-
-/**
- * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds
- * the top `num` K2 items based on the given Ordering.
- */
-private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
-  (num: Int, ord: Ordering[(K2, V)])
-  extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] {
-
-  override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord)
-
-  override def reduce(
-      q: BoundedPriorityQueue[(K2, V)],
-      a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = {
-    q += {(a._2, a._3)}
-  }
-
-  override def merge(
-      q1: BoundedPriorityQueue[(K2, V)],
-      q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = {
-    q1 ++= q2
-  }
-
-  override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = {
-    r.toArray.sorted(ord.reverse)
-  }
-
-  override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = {
-    Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
-  }
-
-  override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
-}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala
new file mode 100644
index 0000000000000..366cce1715900
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import org.apache.spark.ml.util.MLTest
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, collect_top_k, struct}
+
+class CollectTopKSuite extends MLTest {
+
+  import testImplicits._
+
+  @transient var dataFrame: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val sqlContext = spark.sqlContext
+    import sqlContext.implicits._
+    dataFrame = Seq(
+      (0, 3, 54f),
+      (0, 4, 44f),
+      (0, 5, 42f),
+      (0, 6, 28f),
+      (1, 3, 39f),
+      (2, 3, 51f),
+      (2, 5, 45f),
+      (2, 6, 18f)
+    ).toDF("user", "item", "score")
+  }
+
+  test("k smallest with k < #items") {
+    val k = 2
+    val topK = dataFrame.groupBy("user")
+      .agg(collect_top_k(col("score"), k, true))
+      .as[(Int, Seq[Float])]
+      .collect()
+
+    val expected = Map(
+      0 -> Array(28f, 42f),
+      1 -> Array(39f),
+      2 -> Array(18f, 45f)
+    )
+    assert(topK.size === expected.size)
+    topK.foreach { case (k, v) => assert(v === expected(k)) }
+  }
+
+  test("k smallest with k > #items") {
+    val k = 5
+    val topK = dataFrame.groupBy("user")
+      .agg(collect_top_k(col("score"), k, true))
+      .as[(Int, Seq[Float])]
+      .collect()
+
+    val expected = Map(
+      0 -> Array(28f, 42f, 44f, 54f),
+      1 -> Array(39f),
+      2 -> Array(18f, 45f, 51f)
+    )
+    assert(topK.size === expected.size)
+    topK.foreach { case (k, v) => assert(v === expected(k)) }
+  }
+
+  test("k largest with k < #items") {
+    val k = 2
+    val topK = dataFrame.groupBy("user")
+      .agg(collect_top_k(struct("score", "item"), k, false))
+      .as[(Int, Seq[(Float, Int)])]
+      .map(t => (t._1, t._2.map(p => (p._2, p._1))))
+      .collect()
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f)),
+      1 -> Array((3, 39f)),
+      2 -> Array((3, 51f), (5, 45f))
+    )
+    assert(topK.size === expected.size)
+    topK.foreach { case (k, v) => assert(v === expected(k)) }
+  }
+
+  test("k largest with k > #items") {
+    val k = 5
+    val topK = dataFrame.groupBy("user")
+      .agg(collect_top_k(struct("score", "item"), k, false))
+      .as[(Int, Seq[(Float, Int)])]
+      .map(t => (t._1, t._2.map(p => (p._2, p._1))))
+      .collect()
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+      1 -> Array((3, 39f)),
+      2 -> Array((3, 51f), (5, 45f), (6, 18f))
+    )
+    assert(topK.size === expected.size)
+    topK.foreach { case (k, v) => assert(v === expected(k)) }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
deleted file mode 100644
index 5e763a8e908b8..0000000000000
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.recommendation
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Dataset
-
-
-class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
-
-  private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = {
-    val sqlContext = spark.sqlContext
-    import sqlContext.implicits._
-
-    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2))
-    Seq(
-      (0, 3, 54f),
-      (0, 4, 44f),
-      (0, 5, 42f),
-      (0, 6, 28f),
-      (1, 3, 39f),
-      (2, 3, 51f),
-      (2, 5, 45f),
-      (2, 6, 18f)
-    ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn)
-  }
-
-  test("topByKey with k < #items") {
-    val topK = getTopK(2)
-    assert(topK.count() === 3)
-
-    val expected = Map(
-      0 -> Array((3, 54f), (4, 44f)),
-      1 -> Array((3, 39f)),
-      2 -> Array((3, 51f), (5, 45f))
-    )
-    checkTopK(topK, expected)
-  }
-
-  test("topByKey with k > #items") {
-    val topK = getTopK(5)
-    assert(topK.count() === 3)
-
-    val expected = Map(
-      0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
-      1 -> Array((3, 39f)),
-      2 -> Array((3, 51f), (5, 45f), (6, 18f))
-    )
-    checkTopK(topK, expected)
-  }
-
-  private def checkTopK(
-      topK: Dataset[(Int, Array[(Int, Float)])],
-      expected: Map[Int, Array[(Int, Float)]]): Unit = {
-    topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) }
-  }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 2514461d4c057..89255fac13167 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
+import org.apache.spark.util.BoundedPriorityQueue
 
 /**
  * A base class for collect_list and collect_set aggregate functions.
@@ -194,3 +194,45 @@ case class CollectSet(
   override protected def withNewChildInternal(newChild: Expression): CollectSet =
     copy(child = newChild)
 }
+
+/**
+ * Collect the top-k elements. This expression is dedicated only for Spark-ML.
+ * @param reverse when true, returns the smallest k elements.
+ */
+case class CollectTopK(
+    child: Expression,
+    num: Int,
+    reverse: Boolean = false,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0) extends Collect[BoundedPriorityQueue[Any]] {
+  assert(num > 0)
+
+  def this(child: Expression, num: Int) = this(child, num, false, 0, 0)
+  def this(child: Expression, num: Int, reverse: Boolean) = this(child, num, reverse, 0, 0)
+
+  override protected lazy val bufferElementType: DataType = child.dataType
+  override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)
+
+  private def ordering: Ordering[Any] = if (reverse) {
+    TypeUtils.getInterpretedOrdering(child.dataType).reverse
+  } else {
+    TypeUtils.getInterpretedOrdering(child.dataType)
+  }
+
+  override def createAggregationBuffer(): BoundedPriorityQueue[Any] =
+    new BoundedPriorityQueue[Any](num)(ordering)
+
+  override def eval(buffer: BoundedPriorityQueue[Any]): Any =
+    new GenericArrayData(buffer.toArray.sorted(ordering.reverse))
+
+  override def prettyName: String = "collect_top_k"
+
+  override protected def withNewChildInternal(newChild: Expression): CollectTopK =
+    copy(child = newChild)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CollectTopK =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CollectTopK =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 69da277d5e604..0fdc0038a2947 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -367,6 +367,9 @@ object functions {
    */
   def collect_set(columnName: String): Column = collect_set(Column(columnName))
 
+  private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
+    withAggregateFunction { CollectTopK(e.expr, num, reverse) }
+
   /**
    * Aggregate function: returns the Pearson Correlation Coefficient for two columns.
    *
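Note: the snippet below is an illustrative sketch, not part of the patch, of how the new private[spark] collect_top_k helper added in functions.scala is meant to be driven from ML-side code, mirroring the ALS change above. The object name, column names, and local SparkSession setup are assumptions for the example; because the helper is private[spark], such code has to live in a package under org.apache.spark, as the new CollectTopKSuite does.

// Illustrative sketch only (assumed names); mirrors ALSModel.recommendForAll's use of
// the private[spark] collect_top_k aggregate introduced in this patch.
package org.apache.spark.ml.recommendation

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_top_k, struct}

object CollectTopKSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("collect_top_k sketch")
      .getOrCreate()
    import spark.implicits._

    // Toy (user, item, rating) triples, the same shape as the ratings built in recommendForAll.
    val ratings = Seq(
      (0, 3, 54f), (0, 4, 44f), (0, 5, 42f),
      (1, 3, 39f), (2, 3, 51f), (2, 5, 45f)
    ).toDF("user", "item", "rating")

    // reverse = false keeps the k largest values; the ordering on struct(rating, item)
    // compares by rating first, so this keeps each user's 2 highest-rated items.
    // The result is an array of (rating, item) structs sorted in descending order,
    // which is why ALS then re-maps each pair to (item, rating).
    val top2 = ratings
      .groupBy("user")
      .agg(collect_top_k(struct(col("rating"), col("item")), 2, false).as("recommendations"))

    top2.show(truncate = false)
    spark.stop()
  }
}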