Skip to content
Prev Previous commit
Update
  • Loading branch information
bersprockets committed Oct 11, 2022
commit f7d29df9ac7541c5fe727a6fa037fd9e6a3d9a07
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// Setup unique distinct aggregate children.
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this is necessary, but it's better to use ExpressionSet(distinctAggGroups.keySet.flatten).toSeq, instead of calling .distinct on Seq[Expression]

val distinctAggChildAttrMap = distinctAggChildren.map { e =>
ExpressionSet(Seq(e)) -> AttributeReference(e.sql, e.dataType, nullable = true)()
e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we can update expressionAttributePair.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expressionAttributePair is used in two other places, though, for regular aggregate children and filter expressions where the key does not need to be canonicalized.

}
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
// Setup all the filters in distinct aggregate.
Expand Down Expand Up @@ -293,9 +293,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val naf = if (af.children.forall(_.foldable)) {
af
} else {
patchAggregateFunctionChildren(af) { x1 =>
val es = ExpressionSet(Seq(x1))
distinctAggChildAttrLookup.get(es)
patchAggregateFunctionChildren(af) { x =>
distinctAggChildAttrLookup.get(x.canonicalized)
}
}
val newCondition = if (condition.isDefined) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct
// aggregates have different column expressions.
val distinctExpressions =
functionsWithDistinct.flatMap(
_.aggregateFunction.children.filterNot(_.foldable)).distinct
functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable)
val normalizedNamedDistinctExpressions = distinctExpressions.map { e =>
// Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here
// because `distinctExpressions` is not extracted during logical phase.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,17 @@ object AggUtils {
}

// 3. Create an Aggregate operator for partial aggregation (for distinct)
val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions.map(_.canonicalized),
distinctAttributes)
val rewrittenDistinctFunctions = functionsWithDistinct.map {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
// to AttributeReferences.
case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) =>
aggregateFunction.transformDown(distinctColumnAttributeLookup)
.asInstanceOf[AggregateFunction]
aggregateFunction.transformDown {
case e: Expression if distinctColumnAttributeLookup.contains(e.canonicalized) =>
distinctColumnAttributeLookup(e.canonicalized)
}.asInstanceOf[AggregateFunction]
case agg =>
throw new IllegalArgumentException(
"Non-distinct aggregate is found in functionsWithDistinct " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
// 2 distinct columns with different order
val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
assertNoExpand(query3.queryExecution.executedPlan)

// SPARK-40382: 1 distinct expression with cosmetic differences
val query4 = sql("SELECT sum(DISTINCT j), max(DISTINCT J) FROM v GROUP BY i")
assertNoExpand(query4.queryExecution.executedPlan)
}
}

Expand Down