Skip to content
Next Next commit
First attempt
  • Loading branch information
bersprockets committed Oct 11, 2022
commit a5a6fc0582b90b619f5ec732ca87165c83b519ee
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,27 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val aggExpressions = collectAggregateExprs(a)
val distinctAggs = aggExpressions.filter(_.isDistinct)

val funcChildren = distinctAggs.flatMap { e =>
e.aggregateFunction.children.filter(!_.foldable)
}

// For each function child, find the first instance that is semantically equivalent.
// E.g., assume funcChildren is the following three expressions:
// [('a + 1), (1 + 'a), 'b]
// then we want the map to be:
// Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b)
// That is, both ('a + 1) and (1 + 'a) map to ('a + 1).
// This is an n^2 operation, where n is the number of distinct aggregate children, but it
// happens only once every time this rule is called.
val funcChildrenLookup = funcChildren.map { e =>
(e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
}.toMap

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).map { fc =>
funcChildrenLookup.getOrElse(fc, fc)
}.toSet
if (unfoldableChildren.nonEmpty) {
// Only expand the unfoldable children
unfoldableChildren
Expand Down Expand Up @@ -292,7 +310,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
af
} else {
patchAggregateFunctionChildren(af) { x =>
distinctAggChildAttrLookup.get(x)
// x might not exactly match any key of distinctAggChildAttrLookup
// only because `distinctAggChildAttrLookup`'s keys have been de-duped
// based on semantic equivalence. So we need to translate x to the
// semantic equivalent that we are actually using.
val x2 = funcChildrenLookup(x)
distinctAggChildAttrLookup.get(x2)
}
}
val newCondition = if (condition.isDefined) {
Expand Down