Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[SPARK-18137][SQL]Fix RewriteDistinctAggregates UnresolvedException w…
…hen the UDAF has a foldable TypeCheck
  • Loading branch information
root authored and root committed Oct 28, 2016
commit 7029e891ba25a026a8daf2664180166ee387bba5
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,19 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Setup unique distinct aggregate children.
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
val distinctAggChildFoldable = distinctAggChildren.filter(_.foldable)
// 1.only unfoldable child should be expand
// 2.if foldable child mapped to AttributeRefference using expressionAttributePair,
// the udaf function(such as ApproximatePercentile)
// which has a foldable TypeCheck will failed,because AttributeRefference is unfoldable
val distinctAggChildUnFoldableAttrMap = distinctAggChildren
.filter(!_.foldable).map(expressionAttributePair)

val distinctAggChildrenUnFoldableAttrs = distinctAggChildUnFoldableAttrMap.map(_._2)

// Setup expand & aggregate operators for distinct aggregate expressions.
val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
val distinctAggChildAttrLookup = (distinctAggChildUnFoldableAttrMap
++ distinctAggChildFoldable.map(c => c -> c)).toMap
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)
Expand All @@ -172,11 +180,15 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// Setup expand for the 'regular' aggregate expressions.
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)

val regularAggChildFoldable = regularAggChildren.filter(_.foldable)
val regularAggChildUnFoldable = regularAggChildren.filter(!_.foldable)
val regularAggChildUnFoldableAttrMap = regularAggChildUnFoldable
.map(expressionAttributePair)
val regularAggChildUnFoldableAttrs = regularAggChildUnFoldableAttrMap.map(_._2)
// Setup aggregates for 'regular' aggregate expressions.
val regularGroupId = Literal(0)
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggChildAttrLookup = (regularAggChildUnFoldableAttrMap
++ regularAggChildFoldable.map(c => c -> c)).toMap
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
Expand Down Expand Up @@ -207,13 +219,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
Seq(a.groupingExpressions ++
distinctAggChildren.map(nullify) ++
Seq(regularGroupId) ++
regularAggChildren)
regularAggChildUnFoldable)
} else {
Seq.empty[Seq[Expression]]
}

// Construct the distinct aggregate input projections.
val regularAggNulls = regularAggChildren.map(nullify)
val regularAggNulls = regularAggChildUnFoldable.map(nullify)
val distinctAggProjections = distinctAggOperatorMap.map {
case (projection, _) =>
a.groupingExpressions ++
Expand All @@ -224,12 +236,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// Construct the expand operator.
val expand = Expand(
regularAggProjection ++ distinctAggProjections,
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
groupByAttrs ++ distinctAggChildrenUnFoldableAttrs ++ Seq(gid)
++ regularAggChildUnFoldableAttrs,
a.child)

// Construct the first aggregate operator. This de-duplicates the all the children of
// distinct operators, and applies the regular aggregate operators.
val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildrenUnFoldableAttrs :+ gid
val firstAggregate = Aggregate(
firstAggregateGroupBy,
firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,16 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}

test("Generic UDAF aggregates") {
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999))" +
", count(distinct key),sum(distinct key) FROM src LIMIT 1"),
sql("SELECT max(key), count(distinct key),sum(distinct key) FROM src LIMIT 1")
.collect().toSeq)

checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.09999 + 0.9))" +
", count(distinct key),sum(distinct key),1 FROM src LIMIT 1"),
sql("SELECT max(key), count(distinct key),sum(distinct key), 1 FROM src LIMIT 1")
.collect().toSeq)

checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)

Expand Down