fix review comments
clockfly committed Aug 24, 2016
commit 7190eb0c2a4dce2c5b84c29fb90bb2def23a3520
@@ -432,6 +432,12 @@ abstract class DeclarativeAggregate
* 4. After processing all input aggregation objects of current group (group by key), the framework
* calls method `eval(buffer: T)` to generate the final output for this group.
* 5. The framework moves on to next group, until all groups have been processed.
+ *
+ * NOTE: SQL with TypedImperativeAggregate functions is planned in sort based aggregation,
+ * instead of hash based aggregation, as TypedImperativeAggregate uses BinaryType as the
+ * aggregation buffer's storage format, which is not supported by hash based aggregation.
+ * Hash based aggregation only supports aggregation buffers of mutable types (like LongType
+ * and IntegerType, which have fixed length and can be mutated in place in UnsafeRow).
*/
abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
Contributor

Isn't this the wrong way around? Isn't ImperativeAggregate the untyped version of a TypedImperativeAggregate? Much like Dataset and DataFrame?

I know this has been done for engineering purposes, but I still wonder if we shouldn't reverse the hierarchy here.

Contributor

ImperativeAggregate only defines the interface. It does not specify which buffer types are accepted, right?
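
For readers following the NOTE in the doc comment above, here is a minimal, self-contained sketch of why a BinaryType buffer steers planning toward sort based aggregation. It is not Spark's code: the `DataType` objects, `isMutable`, and `useHash` below are hypothetical stand-ins that only mimic the shape of the check, assuming hash based aggregation requires every buffer field to be a fixed-length type that can be mutated in place.

```scala
// Hypothetical stand-ins for Spark's data types (NOT org.apache.spark.sql.types).
sealed trait DataType
case object LongType extends DataType
case object IntegerType extends DataType
case object BinaryType extends DataType

object PlanChoiceSketch {
  // Assumption: fixed-length types can be overwritten in place,
  // while variable-length BinaryType cannot.
  def isMutable(dt: DataType): Boolean = dt match {
    case LongType | IntegerType => true
    case BinaryType => false
  }

  // Hash based aggregation is picked only when every buffer field is mutable;
  // otherwise the planner would fall back to sort based aggregation.
  def useHash(bufferTypes: Seq[DataType]): Boolean =
    bufferTypes.forall(isMutable)

  def main(args: Array[String]): Unit = {
    println(useHash(Seq(LongType, IntegerType))) // prints "true"
    println(useHash(Seq(LongType, BinaryType)))  // prints "false"
  }
}
```

Under that assumption, an aggregate whose buffer is stored as BinaryType fails the check by its buffer type alone, so no separate test for the aggregate's class would be needed.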


@@ -507,8 +513,9 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
}
}

+ private[this] val anyObjectType = ObjectType(classOf[AnyRef])
private def getField[U](input: InternalRow, fieldIndex: Int): U = {
Contributor

Seems we only need getField(input: InternalRow, fieldIndex: Int): T?
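
A rough sketch of what that suggestion might look like, with the helper's result type fixed to the class's own type parameter `T` instead of a separate `U`. Everything here is hypothetical: `SimpleRow` is a one-method stand-in for `InternalRow` (whose `get` also takes a data type argument, as the diff shows), and `TypedAggSketch` and `StringAgg` exist only for illustration.

```scala
// Hypothetical one-method row abstraction standing in for InternalRow.
trait SimpleRow { def get(ordinal: Int): AnyRef }

abstract class TypedAggSketch[T] {
  // Suggested shape: no extra type parameter, the helper always yields T.
  protected def getField(input: SimpleRow, fieldIndex: Int): T =
    input.get(fieldIndex).asInstanceOf[T]
}

object GetFieldSketch {
  // Illustrative subclass whose aggregation buffer object is a String.
  final class StringAgg extends TypedAggSketch[String] {
    def read(row: SimpleRow): String = getField(row, 0)
  }

  def main(args: Array[String]): Unit = {
    val row = new SimpleRow { def get(ordinal: Int): AnyRef = "buffer-0" }
    println(new StringAgg().read(row)) // prints "buffer-0"
  }
}
```

The cast is still unchecked at runtime because of type erasure, so this only narrows the signature; it does not add type safety.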

-   input.get(fieldIndex, null).asInstanceOf[U]
+   input.get(fieldIndex, anyObjectType).asInstanceOf[U]
}

final override lazy val aggBufferAttributes: Seq[AttributeReference] = {
@@ -54,16 +54,9 @@ object AggUtils {
initialInputBufferOffset: Int = 0,
resultExpressions: Seq[NamedExpression] = Nil,
child: SparkPlan): SparkPlan = {

- val hasTypedImperativeAggregate: Boolean = aggregateExpressions.exists {
-   case AggregateExpression(agg: TypedImperativeAggregate[_], _, _, _) => true
-   case _ => false
- }
-
- val aggBufferAttributesSupportedByHashAggregate = HashAggregateExec.supportsAggregate(
+ val useHash = HashAggregateExec.supportsAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))

- if (aggBufferAttributesSupportedByHashAggregate && !hasTypedImperativeAggregate) {
+ if (useHash) {
HashAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
@@ -55,14 +55,7 @@ class SortBasedAggregationIterator(

val genericMutableBuffer = new GenericMutableRow(bufferRowSize)

- val allFieldsMutable = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)
-
- val hasTypedImperativeAggregate = aggregateFunctions.exists {
-   case agg: TypedImperativeAggregate[_] => true
-   case _ => false
- }
-
- val useUnsafeBuffer = allFieldsMutable && !hasTypedImperativeAggregate
+ val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable)

val buffer = if (useUnsafeBuffer) {
val unsafeProjection =