-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-18137][SQL]Fix RewriteDistinctAggregates UnresolvedException when a UDAF has a foldable TypeCheck #15668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…hild;if has unfoldable children,it will only expand the unfoldable children
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPl | |
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.types.IntegerType | ||
|
|
||
| /** | ||
| /* | ||
| * This rule rewrites an aggregate query with distinct aggregations into an expanded double | ||
| * aggregation in which the regular aggregation expressions and every distinct clause is aggregated | ||
| * in a separate group. The results are then combined in a second aggregate. | ||
|
|
@@ -115,9 +115,19 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| } | ||
|
|
||
| // Extract distinct aggregate expressions. | ||
| val distinctAggGroups = aggExpressions | ||
| .filter(_.isDistinct) | ||
| .groupBy(_.aggregateFunction.children.toSet) | ||
| val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy{ | ||
|
||
| e => | ||
| if (e.aggregateFunction.children.exists(!_.foldable)) { | ||
|
||
| // Only expand the unfoldable children | ||
| e.aggregateFunction.children.filter(!_.foldable).toSet | ||
| } else { | ||
| // If aggregateFunction's children are all foldable | ||
| // we must expand at least one of the children (here we take the first child), | ||
| // or If we don't, we will get the wrong result, for example: | ||
| // count(distinct 1) will be explained to count(1) after the rewrite function. | ||
| e.aggregateFunction.children.take(1).toSet | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a good catch. It would be great if we could git rid of this by constant folding (not needed in this PR). Another way of getting rid of this, would be by creating a separate processing group for these distincts. |
||
| } | ||
| } | ||
|
|
||
| // Check if the aggregates contains functions that do not support partial aggregation. | ||
| val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial) | ||
|
|
@@ -134,27 +144,19 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
|
|
||
| // Functions used to modify aggregate functions and their inputs. | ||
| def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) | ||
| def patchAggregateFunctionChildren( | ||
| af: AggregateFunction)( | ||
| attrs: Expression => Expression): AggregateFunction = { | ||
| af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction] | ||
| def patchAggregateFunctionChildren(af: AggregateFunction)( | ||
|
||
| attrs: Expression => Option[Expression]): AggregateFunction = { | ||
| val newChildren = af.children.map(c => attrs(c).getOrElse(c)) | ||
| af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] | ||
| } | ||
|
|
||
| // Setup unique distinct aggregate children. | ||
| val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct | ||
| val distinctAggChildFoldable = distinctAggChildren.filter(_.foldable) | ||
| // 1.only unfoldable child should be expand | ||
| // 2.if foldable child mapped to AttributeRefference using expressionAttributePair, | ||
| // the udaf function(such as ApproximatePercentile) | ||
| // which has a foldable TypeCheck will failed,because AttributeRefference is unfoldable | ||
| val distinctAggChildUnFoldableAttrMap = distinctAggChildren | ||
| .filter(!_.foldable).map(expressionAttributePair) | ||
|
|
||
| val distinctAggChildrenUnFoldableAttrs = distinctAggChildUnFoldableAttrMap.map(_._2) | ||
| val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) | ||
| val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) | ||
|
|
||
| // Setup expand & aggregate operators for distinct aggregate expressions. | ||
| val distinctAggChildAttrLookup = (distinctAggChildUnFoldableAttrMap | ||
| ++ distinctAggChildFoldable.map(c => c -> c)).toMap | ||
| val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap | ||
| val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { | ||
| case ((group, expressions), i) => | ||
| val id = Literal(i + 1) | ||
|
|
@@ -169,7 +171,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| val operators = expressions.map { e => | ||
| val af = e.aggregateFunction | ||
| val naf = patchAggregateFunctionChildren(af) { x => | ||
| evalWithinGroup(id, distinctAggChildAttrLookup(x)) | ||
| distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _)) | ||
| } | ||
| (e, e.copy(aggregateFunction = naf, isDistinct = false)) | ||
| } | ||
|
|
@@ -178,20 +180,20 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| } | ||
|
|
||
| // Setup expand for the 'regular' aggregate expressions. | ||
| val regularAggExprs = aggExpressions.filter(!_.isDistinct) | ||
| val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct | ||
| val regularAggChildFoldable = regularAggChildren.filter(_.foldable) | ||
| val regularAggChildUnFoldable = regularAggChildren.filter(!_.foldable) | ||
| val regularAggChildUnFoldableAttrMap = regularAggChildUnFoldable | ||
| .map(expressionAttributePair) | ||
| val regularAggChildUnFoldableAttrs = regularAggChildUnFoldableAttrMap.map(_._2) | ||
| // only expand unfoldable children | ||
| val regularAggExprs = aggExpressions | ||
| .filter(e => !e.isDistinct && e.children.exists(!_.foldable)) | ||
| val regularAggChildren = regularAggExprs | ||
| .flatMap(_.aggregateFunction.children.filter(!_.foldable)) | ||
| .distinct | ||
| val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) | ||
|
|
||
| // Setup aggregates for 'regular' aggregate expressions. | ||
| val regularGroupId = Literal(0) | ||
| val regularAggChildAttrLookup = (regularAggChildUnFoldableAttrMap | ||
| ++ regularAggChildFoldable.map(c => c -> c)).toMap | ||
| val regularAggChildAttrLookup = regularAggChildAttrMap.toMap | ||
| val regularAggOperatorMap = regularAggExprs.map { e => | ||
| // Perform the actual aggregation in the initial aggregate. | ||
| val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) | ||
| val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get) | ||
| val operator = Alias(e.copy(aggregateFunction = af), e.sql)() | ||
|
|
||
| // Select the result of the first aggregate in the last aggregate. | ||
|
|
@@ -219,13 +221,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| Seq(a.groupingExpressions ++ | ||
| distinctAggChildren.map(nullify) ++ | ||
| Seq(regularGroupId) ++ | ||
| regularAggChildUnFoldable) | ||
| regularAggChildren) | ||
| } else { | ||
| Seq.empty[Seq[Expression]] | ||
| } | ||
|
|
||
| // Construct the distinct aggregate input projections. | ||
| val regularAggNulls = regularAggChildUnFoldable.map(nullify) | ||
| val regularAggNulls = regularAggChildren.map(nullify) | ||
| val distinctAggProjections = distinctAggOperatorMap.map { | ||
| case (projection, _) => | ||
| a.groupingExpressions ++ | ||
|
|
@@ -236,22 +238,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| // Construct the expand operator. | ||
| val expand = Expand( | ||
| regularAggProjection ++ distinctAggProjections, | ||
| groupByAttrs ++ distinctAggChildrenUnFoldableAttrs ++ Seq(gid) | ||
| ++ regularAggChildUnFoldableAttrs, | ||
| groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), | ||
| a.child) | ||
|
|
||
| // Construct the first aggregate operator. This de-duplicates the all the children of | ||
| // distinct operators, and applies the regular aggregate operators. | ||
| val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildrenUnFoldableAttrs :+ gid | ||
| val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid | ||
| val firstAggregate = Aggregate( | ||
| firstAggregateGroupBy, | ||
| firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), | ||
| expand) | ||
|
|
||
| // Construct the second aggregate | ||
| val transformations: Map[Expression, Expression] = | ||
| (distinctAggOperatorMap.flatMap(_._2) ++ | ||
| regularAggOperatorMap.map(e => (e._1, e._3))).toMap | ||
| (distinctAggOperatorMap.flatMap(_._2) ++ | ||
|
||
| regularAggOperatorMap.map(e => (e._1, e._3))).toMap | ||
|
|
||
| val patchedAggExpressions = a.aggregateExpressions.map { e => | ||
| e.transformDown { | ||
|
|
@@ -274,9 +275,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| private def nullify(e: Expression) = Literal.create(null, e.dataType) | ||
|
|
||
| private def expressionAttributePair(e: Expression) = | ||
| // We are creating a new reference here instead of reusing the attribute in case of a | ||
| // NamedExpression. This is done to prevent collisions between distinct and regular aggregate | ||
| // children, in this case attribute reuse causes the input of the regular aggregate to bound to | ||
| // the (nulled out) input of the distinct aggregate. | ||
| // We are creating a new reference here instead of reusing the attribute in case of a | ||
|
||
| // NamedExpression. This is done to prevent collisions between distinct and regular aggregate | ||
| // children, in this case attribute reuse causes the input of the regular aggregate to bound to | ||
| // the (nulled out) input of the distinct aggregate. | ||
| e -> AttributeReference(e.sql, e.dataType, nullable = true)() | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -150,14 +150,22 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { | |
| } | ||
|
|
||
| test("Generic UDAF aggregates") { | ||
| checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999))" + | ||
| ", count(distinct key),sum(distinct key) FROM src LIMIT 1"), | ||
| sql("SELECT max(key), count(distinct key),sum(distinct key) FROM src LIMIT 1") | ||
| checkAnswer(sql("SELECT percentile_approx(2, 0.99999), " + | ||
|
||
| "sum(distinct 1), count(distinct 1,2,3,4) FROM src LIMIT 1"), | ||
| sql("SELECT 2, 1, 1 FROM src LIMIT 1") | ||
| .collect().toSeq) | ||
|
|
||
| checkAnswer(sql("SELECT ceiling(percentile_approx(distinct key, 0.99999))" + | ||
| ", count(distinct key), sum(distinct key), " + | ||
| "count(distinct 1), sum(distinct 1), sum(1) FROM src LIMIT 1"), | ||
| sql("SELECT max(key), count(distinct key), sum(distinct key)," + | ||
| " 1, 1, sum(1) FROM src LIMIT 1") | ||
| .collect().toSeq) | ||
|
|
||
| checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.09999 + 0.9))" + | ||
| ", count(distinct key),sum(distinct key),1 FROM src LIMIT 1"), | ||
| sql("SELECT max(key), count(distinct key),sum(distinct key), 1 FROM src LIMIT 1") | ||
| checkAnswer(sql("SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999))" + | ||
| ", count(distinct key), sum(distinct key), " + | ||
| "count(distinct 1), sum(distinct 1), sum(1) FROM src LIMIT 1"), | ||
| sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1") | ||
| .collect().toSeq) | ||
|
|
||
| checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you revert this? this breaks scaladoc.