address comments
zhengruifeng committed Sep 20, 2022
commit 01ad8cb324ba8543dfcf37fc93e8b58dc45b3110

@@ -45,7 +45,6 @@ import org.apache.spark.mllib.linalg.CholeskyDecomposition
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.{DeterministicLevel, RDD}
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.expressions.aggregate.CollectOrdered
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -501,17 +500,14 @@ class ALSModel private[ml] (
         }
       }.toDF(srcOutputColumn, dstOutputColumn, ratingColumn)

-      val aggFunc = CollectOrdered(struct(ratingColumn, dstOutputColumn).expr, num, true)
-        .toAggregateExpression(false)
-
       val arrayType = ArrayType(
         new StructType()
           .add(dstOutputColumn, IntegerType)
          .add(ratingColumn, FloatType)
       )

       ratings.groupBy(srcOutputColumn)
-        .agg(new Column(aggFunc))
+        .agg(collect_top_k(struct(ratingColumn, dstOutputColumn), num, false))
         .as[(Int, Seq[(Float, Int)])]
         .map(t => (t._1, t._2.map(p => (p._2, p._1))))
         .toDF(srcOutputColumn, recommendColumn)
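One thing to note when reading this hunk: the meaning of the boolean flipped in this commit. The removed CollectOrdered call passed true to keep the largest k elements, while the new collect_top_k takes reverse = true to return the smallest k (see the updated doc comment further down), so ALSModel now passes false to keep the highest-rated items. Below is a minimal sketch of the resulting aggregation with illustrative data; collect_top_k is private[spark], so a call like this only compiles from Spark-internal code such as the ml package:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_top_k, struct}

val spark = SparkSession.builder().master("local[2]").getOrCreate()
import spark.implicits._

// Illustrative (user, item, rating) triples, not the PR's data.
val ratings = Seq((0, 3, 54.0f), (0, 4, 44.0f), (0, 5, 42.0f), (1, 3, 39.0f), (1, 5, 33.0f))
  .toDF("user", "item", "rating")

// Structs compare field by field, so putting the rating first ranks pairs by
// rating; reverse = false keeps the k largest, as in the ALSModel change above.
val top2 = ratings.groupBy("user")
  .agg(collect_top_k(struct($"rating", $"item"), 2, false).as("recs"))
top2.show(false)
// expected: user 0 -> [(54.0, 3), (44.0, 4)], user 1 -> [(39.0, 3), (33.0, 5)]
```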
@@ -18,11 +18,10 @@
 package org.apache.spark.ml.recommendation

 import org.apache.spark.ml.util.MLTest
-import org.apache.spark.sql.{Column, DataFrame}
-import org.apache.spark.sql.catalyst.expressions.aggregate.CollectOrdered
-import org.apache.spark.sql.functions.{col, struct}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{col, collect_top_k, struct}

-class CollectOrderedSuite extends MLTest {
+class CollectTopKSuite extends MLTest {

   import testImplicits._

@@ -44,16 +43,10 @@ class CollectOrderedSuite extends MLTest {
     ).toDF("user", "item", "score")
   }

-  private def collect_ordered(e: Column, num: Int, reverse: Boolean): Column = {
-    new Column(CollectOrdered(e.expr, num, reverse)
-      .toAggregateExpression(false)
-    )
-  }
-
   test("k smallest with k < #items") {
     val k = 2
     val topK = dataFrame.groupBy("user")
-      .agg(collect_ordered(col("score"), k, false))
+      .agg(collect_top_k(col("score"), k, true))
       .as[(Int, Seq[Float])]
       .collect()

@@ -69,7 +62,7 @@ class CollectOrderedSuite extends MLTest {
   test("k smallest with k > #items") {
     val k = 5
     val topK = dataFrame.groupBy("user")
-      .agg(collect_ordered(col("score"), k, false))
+      .agg(collect_top_k(col("score"), k, true))
       .as[(Int, Seq[Float])]
       .collect()

@@ -85,7 +78,7 @@ class CollectOrderedSuite extends MLTest {
   test("k largest with k < #items") {
     val k = 2
     val topK = dataFrame.groupBy("user")
-      .agg(collect_ordered(struct("score", "item"), k, true))
+      .agg(collect_top_k(struct("score", "item"), k, false))
       .as[(Int, Seq[(Float, Int)])]
       .map(t => (t._1, t._2.map(p => (p._2, p._1))))
       .collect()
@@ -102,7 +95,7 @@ class CollectOrderedSuite extends MLTest {
   test("k largest with k > #items") {
     val k = 5
     val topK = dataFrame.groupBy("user")
-      .agg(collect_ordered(struct("score", "item"), k, true))
+      .agg(collect_top_k(struct("score", "item"), k, false))
       .as[(Int, Seq[(Float, Int)])]
       .map(t => (t._1, t._2.map(p => (p._2, p._1))))
       .collect()
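Taken together, the four tests pin down the new flag: reverse = true collects the k smallest values, reverse = false the k largest, and k > #items simply returns the whole group. A plain-Scala model of that contract (a sketch only; it assumes, consistent with the eval change below, that results come back best element first, and it sorts the whole group rather than keeping a bounded heap):

```scala
// A sketch of the contract only; the real aggregate keeps a bounded heap and
// never sorts the whole group (see the CollectTopK changes below).
def collectTopKModel[T](values: Seq[T], num: Int, reverse: Boolean)(
    implicit ord: Ordering[T]): Seq[T] = {
  val effective = if (reverse) ord.reverse else ord // reverse = true -> smallest k
  values.sorted(effective.reverse).take(num)        // best element first
}

collectTopKModel(Seq(54.0f, 44.0f, 42.0f, 39.0f), 2, reverse = false) // Seq(54.0, 44.0)
collectTopKModel(Seq(54.0f, 44.0f, 42.0f, 39.0f), 2, reverse = true)  // Seq(39.0, 42.0)
collectTopKModel(Seq(54.0f, 44.0f), 5, reverse = false)               // whole group when k > #items
```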
@@ -196,15 +196,16 @@ case class CollectSet(
 }

 /**
- * Collect the top-k elements. This expression is dedicated only for MLLIB.
+ * Collect the top-k elements. This expression is dedicated only for Spark-ML.
+ * @param reverse when true, returns the smallest k elements.
  */
-case class CollectOrdered(
+case class CollectTopK(
     child: Expression,
     num: Int,
     reverse: Boolean = false,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0) extends Collect[BoundedPriorityQueue[Any]] {
-  require(num > 0)
+  assert(num > 0)

   def this(child: Expression, num: Int) = this(child, num, false, 0, 0)
   def this(child: Expression, num: Int, reverse: Boolean) = this(child, num, reverse, 0, 0)
@@ -219,19 +220,19 @@ case class CollectOrdered(
   }

   override def createAggregationBuffer(): BoundedPriorityQueue[Any] =
-    new BoundedPriorityQueue[Any](num)(ordering.reverse)
+    new BoundedPriorityQueue[Any](num)(ordering)

   override def eval(buffer: BoundedPriorityQueue[Any]): Any =
-    new GenericArrayData(buffer.toArray.sorted(ordering))
+    new GenericArrayData(buffer.toArray.sorted(ordering.reverse))

-  override def prettyName: String = "collect_ordered"
+  override def prettyName: String = "collect_top_k"

-  override protected def withNewChildInternal(newChild: Expression): CollectOrdered =
+  override protected def withNewChildInternal(newChild: Expression): CollectTopK =
     copy(child = newChild)

-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CollectOrdered =
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CollectTopK =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)

-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CollectOrdered =
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CollectTopK =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 }
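The createAggregationBuffer/eval swap is what implements the flipped flag: Spark's internal BoundedPriorityQueue retains the largest elements under whatever ordering it is given, so the buffer now receives ordering directly and eval sorts the survivors with ordering.reverse to emit them best first. A toy stand-in for that buffer, built on the standard library (assuming that retain-the-largest behavior, since BoundedPriorityQueue itself is a private Spark util class):

```scala
import scala.collection.mutable

// A size-bounded "keep the num largest" buffer: once full, a new element is
// admitted only if it beats the current minimum, which is then evicted.
final class BoundedTopK[T](num: Int)(implicit ord: Ordering[T]) {
  require(num > 0)
  // mutable.PriorityQueue dequeues the maximum under its ordering, so passing
  // ord.reverse makes head the smallest retained element under ord.
  private val minHeap = mutable.PriorityQueue.empty[T](ord.reverse)
  def add(x: T): Unit =
    if (minHeap.size < num) minHeap.enqueue(x)
    else if (ord.gt(x, minHeap.head)) { minHeap.dequeue(); minHeap.enqueue(x) }
  // Mirrors eval(): sort the survivors descending, best first.
  def result(): Seq[T] = minHeap.toSeq.sorted(ord.reverse)
}

val buf = new BoundedTopK[Int](3)
Seq(5, 1, 9, 7, 3).foreach(buf.add)
buf.result() // Seq(9, 7, 5)
```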

sql/core/src/main/scala/org/apache/spark/sql/functions.scala (3 additions, 0 deletions)

@@ -367,6 +367,9 @@ object functions {
    */
   def collect_set(columnName: String): Column = collect_set(Column(columnName))

+  private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
+    withAggregateFunction { CollectTopK(e.expr, num, reverse) }
+
   /**
    * Aggregate function: returns the Pearson Correlation Coefficient for two columns.
    *

Review thread on the new collect_top_k helper:

Reviewer (Contributor): nit: shall we make it public? It might be a useful function. We don't need to do it in this PR.

Author (zhengruifeng): I also think it's useful and may use it further in Pandas-API-on-Spark, but I don't know whether it is suitable to make public. @cloud-fan @HyukjinKwon

Reviewer (Member): Let's keep it private for now.
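As a follow-up to the thread: until the helper is made public, the same result is expressible with public functions alone, at the cost of collecting and sorting each whole group instead of keeping a heap of size k. A sketch of that stand-in (the topKScores helper is hypothetical, not part of the PR):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, collect_list, slice, sort_array}

// Public-API stand-in for collect_top_k(col("score"), k, reverse = false):
// collect the whole group, sort it descending, keep the first k entries.
def topKScores(df: DataFrame, k: Int): DataFrame =
  df.groupBy("user")
    .agg(slice(sort_array(collect_list(col("score")), asc = false), 1, k).as("topk"))
```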