Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ class ALSModel private[ml] (
import srcFactors.sparkSession.implicits._
import scala.collection.JavaConverters._

val ratingColumn = "rating"
val recommendColumn = "recommendations"
val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
Expand Down Expand Up @@ -496,18 +498,20 @@ class ALSModel private[ml] (
.iterator.map { j => (srcId, dstIds(j), scores(j)) }
}
}
}
// We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
.toDF("id", "recommendations")
}.toDF(srcOutputColumn, dstOutputColumn, ratingColumn)

val arrayType = ArrayType(
new StructType()
.add(dstOutputColumn, IntegerType)
.add("rating", FloatType)
.add(ratingColumn, FloatType)
)
recs.select($"id".as(srcOutputColumn), $"recommendations".cast(arrayType))

ratings.groupBy(srcOutputColumn)
.agg(collect_top_k(struct(ratingColumn, dstOutputColumn), num, false))
.as[(Int, Seq[(Float, Int)])]
.map(t => (t._1, t._2.map(p => (p._2, p._1))))
.toDF(srcOutputColumn, recommendColumn)
.withColumn(recommendColumn, col(recommendColumn).cast(arrayType))
}

/**
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.recommendation

import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, collect_top_k, struct}

class CollectTopKSuite extends MLTest {

import testImplicits._

@transient var dataFrame: DataFrame = _

override def beforeAll(): Unit = {
super.beforeAll()
val sqlContext = spark.sqlContext
import sqlContext.implicits._
dataFrame = Seq(
(0, 3, 54f),
(0, 4, 44f),
(0, 5, 42f),
(0, 6, 28f),
(1, 3, 39f),
(2, 3, 51f),
(2, 5, 45f),
(2, 6, 18f)
).toDF("user", "item", "score")
}

test("k smallest with k < #items") {
val k = 2
val topK = dataFrame.groupBy("user")
.agg(collect_top_k(col("score"), k, true))
.as[(Int, Seq[Float])]
.collect()

val expected = Map(
0 -> Array(28f, 42f),
1 -> Array(39f),
2 -> Array(18f, 45f)
)
assert(topK.size === expected.size)
topK.foreach { case (k, v) => assert(v === expected(k)) }
}

test("k smallest with k > #items") {
val k = 5
val topK = dataFrame.groupBy("user")
.agg(collect_top_k(col("score"), k, true))
.as[(Int, Seq[Float])]
.collect()

val expected = Map(
0 -> Array(28f, 42f, 44f, 54f),
1 -> Array(39f),
2 -> Array(18f, 45f, 51f)
)
assert(topK.size === expected.size)
topK.foreach { case (k, v) => assert(v === expected(k)) }
}

test("k largest with k < #items") {
val k = 2
val topK = dataFrame.groupBy("user")
.agg(collect_top_k(struct("score", "item"), k, false))
.as[(Int, Seq[(Float, Int)])]
.map(t => (t._1, t._2.map(p => (p._2, p._1))))
.collect()

val expected = Map(
0 -> Array((3, 54f), (4, 44f)),
1 -> Array((3, 39f)),
2 -> Array((3, 51f), (5, 45f))
)
assert(topK.size === expected.size)
topK.foreach { case (k, v) => assert(v === expected(k)) }
}

test("k largest with k > #items") {
val k = 5
val topK = dataFrame.groupBy("user")
.agg(collect_top_k(struct("score", "item"), k, false))
.as[(Int, Seq[(Float, Int)])]
.map(t => (t._1, t._2.map(p => (p._2, p._1))))
.collect()

val expected = Map(
0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
1 -> Array((3, 39f)),
2 -> Array((3, 51f), (5, 45f), (6, 18f))
)
assert(topK.size === expected.size)
topK.foreach { case (k, v) => assert(v === expected(k)) }
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.util.BoundedPriorityQueue

/**
* A base class for collect_list and collect_set aggregate functions.
Expand Down Expand Up @@ -194,3 +194,45 @@ case class CollectSet(
override protected def withNewChildInternal(newChild: Expression): CollectSet =
copy(child = newChild)
}

/**
* Collect the top-k elements. This expression is dedicated only for Spark-ML.
* @param reverse when true, returns the smallest k elements.
*/
case class CollectTopK(
child: Expression,
num: Int,
reverse: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[BoundedPriorityQueue[Any]] {
assert(num > 0)

def this(child: Expression, num: Int) = this(child, num, false, 0, 0)
def this(child: Expression, num: Int, reverse: Boolean) = this(child, num, reverse, 0, 0)

override protected lazy val bufferElementType: DataType = child.dataType
override protected def convertToBufferElement(value: Any): Any = InternalRow.copyValue(value)

private def ordering: Ordering[Any] = if (reverse) {
TypeUtils.getInterpretedOrdering(child.dataType).reverse
} else {
TypeUtils.getInterpretedOrdering(child.dataType)
}

override def createAggregationBuffer(): BoundedPriorityQueue[Any] =
new BoundedPriorityQueue[Any](num)(ordering)

override def eval(buffer: BoundedPriorityQueue[Any]): Any =
new GenericArrayData(buffer.toArray.sorted(ordering.reverse))

override def prettyName: String = "collect_top_k"

override protected def withNewChildInternal(newChild: Expression): CollectTopK =
copy(child = newChild)

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CollectTopK =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CollectTopK =
copy(inputAggBufferOffset = newInputAggBufferOffset)
}
3 changes: 3 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,9 @@ object functions {
*/
def collect_set(columnName: String): Column = collect_set(Column(columnName))

private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shall we make it public ? It might be a useful function.

We don't need to do it in this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, I also think it's useful and may further use it in Pandas-API-on-Spark.
But I don't know whether it is suitable to be public @cloud-fan @HyukjinKwon

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep it private for now.

withAggregateFunction { CollectTopK(e.expr, num, reverse) }

/**
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
*
Expand Down