Closed
Check the approach to check partial agg based on logical plan instead
c21 committed Dec 1, 2021
commit cff1424c07d8423bd9d05c8f001b136dcbb26a75
@@ -93,47 +93,14 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
    * Check if `partialAgg` to be partial aggregate of `finalAgg`.
    */
   private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = {
Contributor


This looks like reverse engineering `AggUtils`. Could we just link the partial and final aggs when they are constructed?

Contributor Author


@tanelk - yeah, I agree this is mostly reverse engineering and we can do a better job here. I tried linking the partial and final aggs in `AggUtils` and checking whether the linked physical plan is the same as the planned one. This does not quite work because planning is top-down: the linked partial agg is not the same as the planned partial agg (the linked partial agg still contains a `PlanLater` operator).

I found a more elegant way to do it: check that the logical plans linked to both aggs are the same. Updated.
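
To illustrate the idea outside of Spark, here is a minimal, self-contained sketch of the logical-link check. Everything in it is a hypothetical stand-in: `LogicalAgg`, `HashAgg`, and `PartialAggCheck` are toy types invented for this example, and case-class equality stands in for `sameResult`. It assumes both physical aggregates planned from one logical aggregate carry a link back to that same logical node.

```scala
// Toy model only: these types are hypothetical stand-ins, not Spark classes.
sealed trait Mode
case object Partial extends Mode
case object Final extends Mode

// Stand-in for a logical aggregate node; case-class equality plays the role of sameResult.
final case class LogicalAgg(groupBy: Seq[String], aggs: Seq[String])

// Stand-in for a physical hash aggregate that remembers the logical node it was planned from.
final case class HashAgg(mode: Mode, logicalLink: Option[LogicalAgg])

object PartialAggCheck {
  // True only if the modes are Partial/Final and both nodes link to an equal logical aggregate.
  def isPartialAgg(partialAgg: HashAgg, finalAgg: HashAgg): Boolean =
    partialAgg.mode == Partial && finalAgg.mode == Final &&
      ((finalAgg.logicalLink, partialAgg.logicalLink) match {
        case (Some(agg1), Some(agg2)) => agg1 == agg2
        case _ => false // a missing link on either side means we cannot prove the pairing
      })

  def main(args: Array[String]): Unit = {
    val logical = LogicalAgg(Seq("k"), Seq("sum(v)"))
    val partialAgg = HashAgg(Partial, Some(logical))
    val finalAgg = HashAgg(Final, Some(logical))
    println(isPartialAgg(partialAgg, finalAgg))
  }
}
```

The check is conservative by design: when either link is absent it answers `false`, so an unrelated pair of aggregates is never treated as a partial/final pair.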

Contributor Author


cc @cloud-fan for review, thanks.

-    val partialGroupExprs = partialAgg.groupingExpressions
-    val finalGroupExprs = finalAgg.groupingExpressions
-    val partialAggExprs = partialAgg.aggregateExpressions
-    val finalAggExprs = finalAgg.aggregateExpressions
-    val partialAggAttrs = partialAggExprs.flatMap(_.aggregateFunction.aggBufferAttributes)
-    val finalAggAttrs = finalAggExprs.map(_.resultAttribute)
-    val partialResultExprs = partialGroupExprs ++
-      partialAggExprs.flatMap(_.aggregateFunction.inputAggBufferAttributes)
-
-    val groupExprsEqual = partialGroupExprs.length == finalGroupExprs.length &&
-      partialGroupExprs.zip(finalGroupExprs).forall {
-        case (e1, e2) => e1.semanticEquals(e2)
-      }
-    val aggExprsEqual = partialAggExprs.length == finalAggExprs.length &&
-      partialAggExprs.forall(_.mode == Partial) && finalAggExprs.forall(_.mode == Final) &&
-      partialAggExprs.zip(finalAggExprs).forall {
-        case (e1, e2) => e1.aggregateFunction.semanticEquals(e2.aggregateFunction)
-      }
-    val isPartialAggAttrsValid = partialAggAttrs.length == partialAgg.aggregateAttributes.length &&
-      partialAggAttrs.zip(partialAgg.aggregateAttributes).forall {
-        case (a1, a2) => a1.semanticEquals(a2)
+    if (partialAgg.aggregateExpressions.forall(_.mode == Partial) &&
+        finalAgg.aggregateExpressions.forall(_.mode == Final)) {
+      (finalAgg.logicalLink, partialAgg.logicalLink) match {
+        case (Some(agg1), Some(agg2)) => agg1.sameResult(agg2)
+        case _ => false
       }
-    val isFinalAggAttrsValid = finalAggAttrs.length == finalAgg.aggregateAttributes.length &&
-      finalAggAttrs.zip(finalAgg.aggregateAttributes).forall {
-        case (a1, a2) => a1.semanticEquals(a2)
-      }
-    val isPartialResultExprsValid =
-      partialResultExprs.length == partialAgg.resultExpressions.length &&
-        partialResultExprs.zip(partialAgg.resultExpressions).forall {
-          case (a1, a2) => a1.semanticEquals(a2)
-        }
-    val isRequiredDistributionValid =
-      partialAgg.requiredChildDistributionExpressions.isEmpty &&
-        finalAgg.requiredChildDistributionExpressions.exists { exprs =>
-          exprs.length == finalGroupExprs.length &&
-            exprs.zip(finalGroupExprs).forall {
-              case (e1, e2) => e1.semanticEquals(e2)
-            }
-        }
-
-    groupExprsEqual && aggExprsEqual && isPartialAggAttrsValid && isFinalAggAttrsValid &&
-      isPartialResultExprsValid && isRequiredDistributionValid
+    } else {
+      false
+    }
   }
 }