diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index d6a39ecf53b8..f1e0017d3f87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -116,8 +116,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions
-      .filter(_.isDistinct)
-      .groupBy(_.aggregateFunction.children.toSet)
+      .filter(e => e.isDistinct && e.children.exists(!_.foldable))
+      .groupBy(_.aggregateFunction.children.filter(!_.foldable).toSet)
 
     // Check if the aggregates contains functions that do not support partial aggregation.
     val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
@@ -136,8 +136,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
       def patchAggregateFunctionChildren(
           af: AggregateFunction)(
-          attrs: Expression => Expression): AggregateFunction = {
-        af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction]
+          attrs: Expression => Option[Expression]): AggregateFunction = {
+        val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+        af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
       }
 
       // Setup unique distinct aggregate children.
@@ -161,7 +162,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           val operators = expressions.map { e =>
             val af = e.aggregateFunction
             val naf = patchAggregateFunctionChildren(af) { x =>
-              evalWithinGroup(id, distinctAggChildAttrLookup(x))
+              distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
             }
             (e, e.copy(aggregateFunction = naf, isDistinct = false))
           }
@@ -170,8 +171,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
 
       // Setup expand for the 'regular' aggregate expressions.
-      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
-      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
+      // Regular aggregates are the non-distinct ones; foldable children are left out.
+      val regularAggExprs = aggExpressions
+        .filter(e => !e.isDistinct && e.children.exists(!_.foldable))
+      val regularAggChildren = regularAggExprs
+        .flatMap(_.aggregateFunction.children.filter(!_.foldable))
+        .distinct
       val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
 
       // Setup aggregates for 'regular' aggregate expressions.
@@ -179,7 +184,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
       val regularAggOperatorMap = regularAggExprs.map { e =>
         // Perform the actual aggregation in the initial aggregate.
-        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
+        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
         val operator = Alias(e.copy(aggregateFunction = af), e.sql)()
 
         // Select the result of the first aggregate in the last aggregate.