Skip to content
Prev Previous commit
Next Next commit
Handle case of one distinct grouping with superficially different fun…
…ction children to Spark strategies
  • Loading branch information
bersprockets committed Oct 11, 2022
commit 9938252d65861651601cef2db24ea12fa5a1ce16
Original file line number Diff line number Diff line change
Expand Up @@ -405,28 +405,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
} else {
// We may have one distinct group only because we grouped using ExpressionSet.
// To prevent SparkStrategies from complaining during sanity check, we need to check whether
// the original list of aggregate expressions had multiple distinct groups and, if so,
// patch that list so we have only one distinct group.
val funcChildren = distinctAggs.flatMap { e =>
e.aggregateFunction.children.filter(!_.foldable)
}
val funcChildrenLookup = funcChildren.map { e =>
(e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
}.toMap

if (funcChildrenLookup.keySet.size > funcChildrenLookup.values.toSet.size) {
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case e: Expression =>
funcChildrenLookup.getOrElse(e, e)
}.asInstanceOf[NamedExpression]
}
a.copy(aggregateExpressions = patchedAggExpressions)
} else {
a
}
a
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
if (functionsWithDistinct.map(
_.aggregateFunction.children.filterNot(_.foldable).toSet).distinct.length > 1) {
val distinctAggChildSets = functionsWithDistinct.map { ae =>
ExpressionSet(ae.aggregateFunction.children.filterNot(_.foldable))
}.distinct
if (distinctAggChildSets.length > 1) {
// This is a sanity check. We should not reach here when we have multiple distinct
// column sets. Our `RewriteDistinctAggregates` should take care this case.
throw new IllegalStateException(
Expand Down Expand Up @@ -560,7 +562,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct
// aggregates have different column expressions.
val distinctExpressions =
functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I'm a little confused. Why do we change here? Since all children are semantically equivalent, we can just pick the first distinct function. If we need to look up the child later, we should make sure it uses ExpressionSet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I'm a little confused.

Ah yes, possibly I was too. I had not read all of planAggregateWithOneDistinct yet, and I see the creation of rewrittenDistinctFunctions, where I can possibly take advantage of semantic equivalence.

functionsWithDistinct.flatMap(
_.aggregateFunction.children.filterNot(_.foldable)).distinct
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My implementation here has an odd effect in the case where all child sets are semantically equivalent but cosmetically different, e.g.:

explain select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
from v1
group by k;

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[k#87], functions=[sum(distinct (c1#88 + 1)#99), avg(distinct (1 + c1#88)#100), count(distinct (1 + C1#88)#101)])
   +- Exchange hashpartitioning(k#87, 200), ENSURE_REQUIREMENTS, [plan_id=136]
      +- HashAggregate(keys=[k#87], functions=[partial_sum(distinct (c1#88 + 1)#99), partial_avg(distinct (1 + c1#88)#100), partial_count(distinct (1 + C1#88)#101)])
         +- HashAggregate(keys=[k#87, (c1#88 + 1)#99, (1 + c1#88)#100, (1 + C1#88)#101], functions=[])
            +- Exchange hashpartitioning(k#87, (c1#88 + 1)#99, (1 + c1#88)#100, (1 + C1#88)#101, 200), ENSURE_REQUIREMENTS, [plan_id=132]
               +- HashAggregate(keys=[k#87, (c1#88 + 1) AS (c1#88 + 1)#99, (1 + c1#88) AS (1 + c1#88)#100, (1 + C1#88) AS (1 + C1#88)#101], functions=[])
                  +- LocalTableScan [k#87, c1#88]

The grouping keys in the first aggregate should include the children of the distinct aggregations, and they do. But because I kept the children as cosmetically different (I no longer patch them in RewriteDistinctAggregates when handling the fall-through case), the grouping keys now include each cosmetic variation (c1 + 1, 1 + c1, and 1 + C1). If I remove one cosmetic variation, the final aggregate gets an error (because one of the aggregation expressions will refer to attributes that were not output in previous plan nodes).

My earlier implementation (where I patch the aggregate expressions in the fall-through case so there are no more cosmetic variations) doesn't have this oddity:

explain select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
from v1
group by k;

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[k#8], functions=[sum(distinct (c1#9 + 1)#20), avg(distinct (c1#9 + 1)#20), count(distinct (c1#9 + 1)#20)])
   +- Exchange hashpartitioning(k#8, 200), ENSURE_REQUIREMENTS, [plan_id=30]
      +- HashAggregate(keys=[k#8], functions=[partial_sum(distinct (c1#9 + 1)#20), partial_avg(distinct (c1#9 + 1)#20), partial_count(distinct (c1#9 + 1)#20)])
         +- HashAggregate(keys=[k#8, (c1#9 + 1)#20], functions=[])
            +- Exchange hashpartitioning(k#8, (c1#9 + 1)#20, 200), ENSURE_REQUIREMENTS, [plan_id=26]
               +- HashAggregate(keys=[k#8, (c1#9 + 1) AS (c1#9 + 1)#20], functions=[])
                  +- LocalTableScan [k#8, c1#9]

Also my earlier implementation seems about 22% faster for the case where all child sets are semantically equivalent but cosmetically different. I assume because the rows output from the first physical aggregation are narrower (but I have not dug down too deep on this).

val normalizedNamedDistinctExpressions = distinctExpressions.map { e =>
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here
// because `distinctExpressions` is not extracted during logical phase.
Expand Down