Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ import org.apache.spark.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

Expand All @@ -46,14 +45,12 @@ object TypedAggregateExpression {
/**
* This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
* the following limitations:
* - It assumes the aggregator reduces and returns a single column of type `long`.
* - It might only work when there is a single aggregator in the first column.
* - It assumes the aggregator has a zero, `0`.
*/
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
bEncoder: ExpressionEncoder[Any], // Should be bound.
unresolvedBEncoder: ExpressionEncoder[Any],
cEncoder: ExpressionEncoder[Any],
children: Seq[Attribute],
mutableAggBufferOffset: Int,
Expand All @@ -80,10 +77,14 @@ case class TypedAggregateExpression(

override lazy val inputTypes: Seq[DataType] = Nil

override val aggBufferSchema: StructType = bEncoder.schema
override val aggBufferSchema: StructType = unresolvedBEncoder.schema

override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes

val bEncoder = unresolvedBEncoder
.resolve(aggBufferAttributes, OuterScopes.outerScopes)
.bind(aggBufferAttributes)

// Note: although this simply copies aggBufferAttributes, this common code can not be placed
// in the superclass because that will lead to initialization ordering issues.
override val inputAggBufferAttributes: Seq[AttributeReference] =
Expand All @@ -93,12 +94,18 @@ case class TypedAggregateExpression(
lazy val boundA = aEncoder.get

private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
// TODO: find a neater way to assign the value.
var i = 0
while (i < aggBufferAttributes.length) {
val offset = mutableAggBufferOffset + i
aggBufferSchema(i).dataType match {
case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
case ByteType => buffer.setByte(offset, value.getByte(i))
case ShortType => buffer.setShort(offset, value.getShort(i))
case IntegerType => buffer.setInt(offset, value.getInt(i))
case LongType => buffer.setLong(offset, value.getLong(i))
case FloatType => buffer.setFloat(offset, value.getFloat(i))
case DoubleType => buffer.setDouble(offset, value.getDouble(i))
case other => buffer.update(offset, value.get(i, other))
}
i += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L
}

case class AggData(a: Int, b: String)
object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
object ClassInputAgg extends Aggregator[AggData, Int, Int] {
/** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
override def zero: Int = 0

Expand All @@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
override def merge(b1: Int, b2: Int): Int = b1 + b2
}

object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
  /** Initial buffer: a counter of zero paired with the placeholder `AggData(0, "0")`. */
  override def zero: (Int, AggData) = (0, AggData(0, "0"))

  /**
   * Folds one input into the buffer: the counter grows by one and the input itself becomes
   * the buffer's `AggData` component.
   */
  override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = {
    val nextCount = b._1 + 1
    (nextCount, a)
  }

  /**
   * Produces the final result, which is the counter stored in the buffer.
   */
  override def finish(reduction: (Int, AggData)): Int = {
    val count = reduction._1
    count
  }

  /**
   * Combines two partial buffers: counters are summed, and the `AggData` component of the
   * first buffer is the one that is kept.
   */
  override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = {
    val combinedCount = b1._1 + b2._1
    (combinedCount, b1._2)
  }
}

class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {

import testImplicits._
Expand Down Expand Up @@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
("one", 1))
}

// Exercises ComplexBufferAgg, whose intermediate buffer is an (Int, AggData) tuple, in three
// query shapes: a plain select, a select next to a SQL expression, and a typed groupBy.
// NOTE(review): the name says "complex input", but what sets ComplexBufferAgg apart is its
// buffer type -- consider whether "complex buffer" would describe this test better.
test("typed aggregation: complex input") {
val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()

// The aggregator on its own: its counter reaches 2 for the two input rows.
checkAnswer(
ds.select(ComplexBufferAgg.toColumn),
2
)

// Side by side with a SQL aggregate expression: avg(a) over 1 and 2 is 1.5.
checkAnswer(
ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
(1.5, 2))

// Grouped by `b`: each key ("one", "two") has exactly one row, so each count is 1.
checkAnswer(
ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn),
("one", 1), ("two", 1))
}
}