Skip to content
Prev Previous commit
Next Next commit
Simplify
  • Loading branch information
bersprockets committed Oct 11, 2022
commit 4a40f910f4fd44a526d5959fc255e5702aa29151
Original file line number Diff line number Diff line change
Expand Up @@ -213,32 +213,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
}

def rewrite(a: Aggregate): Aggregate = {
def rewrite(aRaw: Aggregate): Aggregate = {
// make children of distinct aggregations the same if they are different
// only because of superficial differences.
val a = getSanitizedAggregate(aRaw)

val aggExpressions = collectAggregateExprs(a)
val distinctAggs = aggExpressions.filter(_.isDistinct)

val funcChildren = distinctAggs.flatMap { e =>
e.aggregateFunction.children.filter(!_.foldable)
}

// For each function child, find the first instance that is semantically equivalent.
// E.g., assume funcChildren is the following three expressions:
// [('a + 1), (1 + 'a), 'b]
// then we want the map to be:
// Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b)
// That is, both ('a + 1) and (1 + 'a) map to ('a + 1).
// This is an n^2 operation, where n is the number of distinct aggregate children, but it
// happens only once every time this rule is called.
val funcChildrenLookup = funcChildren.map { e =>
(e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
}.toMap

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).map { fc =>
funcChildrenLookup.getOrElse(fc, fc)
}.toSet
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
if (unfoldableChildren.nonEmpty) {
// Only expand the unfoldable children
unfoldableChildren
Expand All @@ -253,42 +238,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
}

// get the count of aggregation groups that takes into account
// even superficial differences in the function children
val distictAggGroupsCount = aggExpressions.filter(_.isDistinct).map { e =>
val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
if (unfoldableChildren.nonEmpty) {
unfoldableChildren
} else {
e.aggregateFunction.children.take(1).toSet
}
}.toSet.size

def patchAggregateFunctionChildren(
af: AggregateFunction)(
attrs: Expression => Option[Expression]): AggregateFunction = {
val newChildren = af.children.map(c => attrs(c).getOrElse(c))
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
}

// Aggregation strategy can handle queries with a single distinct group without filter clause.
if (distinctAggGroups.size == 1 && distictAggGroupsCount > 1
&& !distinctAggs.exists(_.filter.isDefined)) {
// we have multiple groups only because of
// superficial differences. Make them the same so that SparkStrategies
// doesn't complain during sanity check. That is, if we have an aggList of:
// [count(distinct b + 1), sum(distinct 1 + b), sum(c)]
// Change it to:
// [count(distinct b + 1), sum(distinct b + 1), sum(c)]
// therefore we have distinct aggregations over only one expression
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case e: Expression =>
funcChildrenLookup.getOrElse(e, e)
}.asInstanceOf[NamedExpression]
}
a.copy(aggregateExpressions = patchedAggExpressions)
} else if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
// Create the attributes for the grouping id and the group by clause.
val gid = AttributeReference("gid", IntegerType, nullable = false)()
val groupByMap = a.groupingExpressions.collect {
Expand Down Expand Up @@ -337,12 +288,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
af
} else {
patchAggregateFunctionChildren(af) { x =>
// x might not exactly match any key of distinctAggChildAttrLookup
// only because `distinctAggChildAttrLookup`'s keys have been de-duped
// based on semantic equivalence. So we need to translate x to the
// semantic equivalent that we are actually using.
val x2 = funcChildrenLookup.getOrElse(x, x)
distinctAggChildAttrLookup.get(x2)
distinctAggChildAttrLookup.get(x)
}
}
val newCondition = if (condition.isDefined) {
Expand Down Expand Up @@ -463,6 +409,42 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}}
}

/**
 * Rebuilds an aggregate function with substituted children.
 *
 * Each child of `af` is passed to `attrs`; when `attrs` returns `Some(replacement)` the
 * replacement is used, and when it returns `None` the original child is kept.
 *
 * @param af    the aggregate function whose children are to be patched
 * @param attrs lookup that optionally supplies a replacement for a given child
 * @return the result of `af.withNewChildren` on the substituted children, cast to
 *         [[AggregateFunction]]
 */
private def patchAggregateFunctionChildren(
    af: AggregateFunction)(
    attrs: Expression => Option[Expression]): AggregateFunction = {
  val substituted = af.children.map { child =>
    attrs(child) match {
      case Some(replacement) => replacement
      case None => child
    }
  }
  val rebuilt = af.withNewChildren(substituted)
  rebuilt.asInstanceOf[AggregateFunction]
}

/**
 * Returns a copy of `a` whose aggregate expressions have been rewritten so that
 * non-foldable children of distinct aggregate functions that are semantically equal
 * (per `semanticEquals`) are all replaced by the first such child encountered.
 *
 * @param a the aggregate whose `aggregateExpressions` are to be normalized
 * @return `a.copy(...)` with the patched aggregate expressions; all other fields are
 *         carried over unchanged by `copy`
 */
private def getSanitizedAggregate(a: Aggregate): Aggregate = {
val aggExpressions = collectAggregateExprs(a)
// Only aggregate expressions flagged as distinct take part in the normalization.
val distinctAggs = aggExpressions.filter(_.isDistinct)

// Collect the non-foldable children of every distinct aggregate function.
val funcChildren = distinctAggs.flatMap { e =>
e.aggregateFunction.children.filter(!_.foldable)
}

// For each function child, find the first instance that is semantically equivalent.
// E.g., assume funcChildren is the following three expressions:
// [('a + 1), (1 + 'a), 'b]
// then we want the map to be:
// Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b)
// That is, both ('a + 1) and (1 + 'a) map to ('a + 1).
// This is an n^2 operation, where n is the number of distinct aggregate children, but it
// happens only once every time this rule is called.
val funcChildrenLookup = funcChildren.map { e =>
// NOTE(review): `e` is itself an element of `funcChildren`, so `find` is expected to
// match `e` at the latest. The `getOrElse(e)` fallback only matters if `semanticEquals`
// can be false for an expression compared with itself (presumably non-deterministic
// expressions) -- TODO confirm against `Expression.semanticEquals`.
(e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the code lead to find out itself.

}.toMap

// Walk each aggregate expression top-down and swap any node that is a key of
// `funcChildrenLookup` (ordinary map-key lookup) for its mapped representative; nodes
// that are not keys are returned unchanged.
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case e: Expression =>
funcChildrenLookup.getOrElse(e, e)
// The transformed tree is cast back to NamedExpression, the type `e` is used as here.
}.asInstanceOf[NamedExpression]
}
a.copy(aggregateExpressions = patchedAggExpressions)
}

private def nullify(e: Expression) = Literal.create(null, e.dataType)

private def expressionAttributePair(e: Expression) =
Expand Down