Skip to content

Commit 7190eb0

Browse files
committed
fix review comments
1 parent 2873765 commit 7190eb0

File tree

3 files changed

+11
-18
lines changed

3 files changed

+11
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,12 @@ abstract class DeclarativeAggregate
432432
* 4. After processing all input aggregation objects of current group (group by key), the framework
433433
* calls method `eval(buffer: T)` to generate the final output for this group.
434434
* 5. The framework moves on to next group, until all groups have been processed.
435+
*
436+
* NOTE: SQL with TypedImperativeAggregate functions is planned in sort based aggregation,
437+
* instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation
438+
* buffer's storage format, which is not supported by hash based aggregation. Hash based
439+
* aggregation only support aggregation buffer of mutable types (like LongType, IntType that have
440+
* fixed length and can be mutated in place in UnsafeRow)
435441
*/
436442
abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
437443

@@ -507,8 +513,9 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
507513
}
508514
}
509515

516+
private[this] val anyObjectType = ObjectType(classOf[AnyRef])
510517
private def getField[U](input: InternalRow, fieldIndex: Int): U = {
511-
input.get(fieldIndex, null).asInstanceOf[U]
518+
input.get(fieldIndex, anyObjectType).asInstanceOf[U]
512519
}
513520

514521
final override lazy val aggBufferAttributes: Seq[AttributeReference] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,9 @@ object AggUtils {
5454
initialInputBufferOffset: Int = 0,
5555
resultExpressions: Seq[NamedExpression] = Nil,
5656
child: SparkPlan): SparkPlan = {
57-
58-
val hasTypedImperativeAggregate: Boolean = aggregateExpressions.exists {
59-
case AggregateExpression(agg: TypedImperativeAggregate[_], _, _, _) => true
60-
case _ => false
61-
}
62-
63-
val aggBufferAttributesSupportedByHashAggregate = HashAggregateExec.supportsAggregate(
57+
val useHash = HashAggregateExec.supportsAggregate(
6458
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
65-
66-
if (aggBufferAttributesSupportedByHashAggregate && !hasTypedImperativeAggregate) {
59+
if (useHash) {
6760
HashAggregateExec(
6861
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
6962
groupingExpressions = groupingExpressions,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,7 @@ class SortBasedAggregationIterator(
5555

5656
val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
5757

58-
val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
59-
60-
val hasTypedImperativeAggregate = aggregateFunctions.exists {
61-
case agg: TypedImperativeAggregate[_] => true
62-
case _ => false
63-
}
64-
65-
val useUnsafeBuffer = allFieldsMutable && !hasTypedImperativeAggregate
58+
val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
6659

6760
val buffer = if (useUnsafeBuffer) {
6861
val unsafeProjection =

0 commit comments

Comments
 (0)