[SPARK-17187][SQL] Supports using arbitrary Java object as internal aggregation buffer object #14753
@@ -389,3 +389,146 @@ abstract class DeclarativeAggregate
    def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
  }
}

/**
 * Aggregation function which allows an **arbitrary** user-defined Java object to be used as the
 * internal aggregation buffer object.
 *
 * {{{
 *   aggregation buffer for normal aggregation function `avg`
 *            |
 *            v
 *   +--------------+---------------+-----------------------------------+
 *   | sum1 (Long)  | count1 (Long) | generic user-defined java objects |
 *   +--------------+---------------+-----------------------------------+
 *                                                    ^
 *                                                    |
 *     aggregation buffer object for `TypedImperativeAggregate` aggregation function
 * }}}
 *
 * Workflow (Partial mode aggregate at the Mapper side, Final mode aggregate at the Reducer side):
 *
 * Stage 1: Partial aggregate at the Mapper side:
 *
 *   1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
 *      buffer object.
 *   2. Upon each input row, the framework calls
 *      `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T.
 *   3. After processing all rows of the current group (group by key), the framework serializes
 *      the aggregation buffer object T to the storage format (Array[Byte]) and persists the
 *      Array[Byte] to disk if needed.
 *   4. The framework moves on to the next group, until all groups have been processed.
 *
 * Shuffle exchanges data to Reducer tasks...
 *
 * Stage 2: Final mode aggregate at the Reducer side:
 *
 *   1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation
 *      buffer object (of type T) for merging.
 *   2. For each aggregation output of Stage 1, the framework de-serializes the storage
 *      format (Array[Byte]) and produces one input aggregation object (of type T).
 *   3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit`
 *      to merge the input aggregation object into the aggregation buffer object.
 *   4. After processing all input aggregation objects of the current group (group by key), the
 *      framework calls `eval(buffer: T)` to generate the final output for this group.
 *   5. The framework moves on to the next group, until all groups have been processed.
 */
abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
|
Contributor:
Isn't this the wrong way around? I know this has been done for engineering purposes, but I still wonder if we shouldn't reverse the hierarchy here.
  /**
   * Creates an empty aggregation buffer object. This is called before processing each key group
   * (group by key).
   *
   * @return an aggregation buffer object
   */
  def createAggregationBuffer(): T

  /**
   * Updates the aggregation buffer object in place with an input row: buffer = buffer + input.
   * This is typically called when doing Partial or Complete mode aggregation.
   *
   * @param buffer the aggregation buffer object
   * @param input an input row
   */
  def update(buffer: T, input: InternalRow): Unit

  /**
   * Merges an input aggregation object into the aggregation buffer object:
   * buffer = buffer + input. This is typically called when doing PartialMerge or Final mode
   * aggregation.
   *
   * @param buffer the aggregation buffer object used to store the aggregation result
   * @param input an input aggregation object, which can be produced by de-serializing the
   *              partial aggregate's output from the Mapper side
   */
  def merge(buffer: T, input: T): Unit

  /**
   * Generates the final aggregation result value for the current key group from the aggregation
   * buffer object.
   *
   * @param buffer the aggregation buffer object
   * @return the aggregation result of the current key group
   */
  def eval(buffer: T): Any

  /** Returns the class of the aggregation buffer object. */
  def aggregationBufferClass: Class[T]

  /** Serializes the aggregation buffer object T to Array[Byte]. */
  def serialize(buffer: T): Array[Byte]
Contributor (Author):
Here we limit the serializable format to Array[Byte]. The reason is that SpecialMutableRow will do a type check for atomic types for each ...

Contributor:
This detail deserves a comment in the code.
  /** De-serializes the storage format Array[Byte], and produces the aggregation buffer object T. */
  def deserialize(storageFormat: Array[Byte]): T

  final override def initialize(buffer: MutableRow): Unit = {
    val bufferObject = createAggregationBuffer()
    buffer.update(mutableAggBufferOffset, bufferObject)
  }

  final override def update(buffer: MutableRow, input: InternalRow): Unit = {
    val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T]
    update(bufferObject, input)
  }

  final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = {
    val bufferObject = field[T](buffer, mutableAggBufferOffset)
    val inputObject = deserialize(field[Array[Byte]](inputBuffer, inputAggBufferOffset))
    merge(bufferObject, inputObject)
  }

  final override def eval(buffer: InternalRow): Any = {
    val bufferObject = field[AnyRef](buffer, mutableAggBufferOffset)
    if (bufferObject.getClass == aggregationBufferClass) {
      // When used in Window frame aggregation, eval(buffer: InternalRow) is called directly
      // on the object aggregation buffer without intermediate serializing/de-serializing.
      eval(bufferObject.asInstanceOf[T])
    } else {
      eval(deserialize(bufferObject.asInstanceOf[Array[Byte]]))
    }
  }

  private def field[U](input: InternalRow, fieldIndex: Int): U = {
    input.get(fieldIndex, null).asInstanceOf[U]
  }

  final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
    // Underlying storage type for the aggregation buffer object
    Seq(AttributeReference("buf", BinaryType)())
  }

  final override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
    aggBufferAttributes.map(_.newInstance())

  final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

  /**
   * In-place replaces the aggregation buffer object stored at the buffer's index
   * `mutableAggBufferOffset` with a storage format that Spark SQL supports internally.
   *
   * The framework calls this method every time after updating/merging one group (group by key).
   */
  final def serializeAggregateBufferInPlace(buffer: MutableRow): Unit = {
    val bufferObject = field(buffer, mutableAggBufferOffset).asInstanceOf[T]
    buffer(mutableAggBufferOffset) = serialize(bufferObject)
  }
}
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.aggregate

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, TypedImperativeAggregate}
 import org.apache.spark.sql.execution.metric.SQLMetric
 /**
@@ -54,7 +54,15 @@ class SortBasedAggregationIterator(
   val bufferRowSize: Int = bufferSchema.length

   val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
-  val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+  val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
+
+  val hasTypedImperativeAggregate = aggregateFunctions.exists {
+    case agg: TypedImperativeAggregate[_] => true
+    case _ => false
+  }
+
+  val useUnsafeBuffer = allFieldsMutable && !hasTypedImperativeAggregate

   val buffer = if (useUnsafeBuffer) {
     val unsafeProjection =
@@ -90,6 +98,24 @@ class SortBasedAggregationIterator(
   // compared to MutableRow (aggregation buffer) directly.
   private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))

+  // Aggregation functions which use a generic aggregation buffer object.
+  // @see [[TypedImperativeAggregate]] for more information
+  private val typedImperativeAggregates: Array[TypedImperativeAggregate[_]] = {
+    aggregateFunctions.collect {
+      case (ag: TypedImperativeAggregate[_]) => ag
+    }
+  }
+
+  // For TypedImperativeAggregate with a generic aggregation buffer object, we need to call
+  // serializeAggregateBufferInPlace(...) explicitly to convert the aggregation buffer object
+  // to a serializable storage format that Spark SQL supports internally.
+  private def serializeTypedAggregateBuffer(aggregationBuffer: MutableRow): Unit = {
+    typedImperativeAggregates.foreach { agg =>
+      // In-place serialization
+      agg.serializeAggregateBufferInPlace(aggregationBuffer)
+    }
+  }
+
   protected def initialize(): Unit = {
     if (inputIterator.hasNext) {
       initializeBuffer(sortBasedAggregationBuffer)
@@ -131,6 +157,11 @@ class SortBasedAggregationIterator(
         firstRowInNextGroup = currentRow.copy()
       }
     }

+    // Serializes the generic objects stored in the aggregation buffer for
+    // TypedImperativeAggregate aggregation functions.
+    serializeTypedAggregateBuffer(sortBasedAggregationBuffer)
+
     // We have not seen a new group. It means that there is no new row in the input
     // iter. The current group is the last group of the iter.
     if (!findNextPartition) {
@@ -162,6 +193,9 @@ class SortBasedAggregationIterator(

   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     initializeBuffer(sortBasedAggregationBuffer)
+    // Serializes the generic objects stored in the aggregation buffer for
+    // TypedImperativeAggregate aggregation functions.
+    serializeTypedAggregateBuffer(sortBasedAggregationBuffer)
     generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
   }
 }
Contributor:
Let's also add a normal agg buffer after the generic one, so readers will not assume that generic ones will always be put at the end.