Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
if distinct agg only has foldable children,it will expand the first c…
…hild;if has unfoldable children,it will only expand the unfoldable children
  • Loading branch information
root authored and root committed Nov 7, 2016
commit 8a6dd8daf11f7a0c29b3afc04706ccddc390a1bf
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPl
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.IntegerType

/**
/*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you revert this? this breaks scaladoc.

* This rule rewrites an aggregate query with distinct aggregations into an expanded double
* aggregation in which the regular aggregation expressions and every distinct clause is aggregated
* in a separate group. The results are then combined in a second aggregate.
Expand Down Expand Up @@ -115,9 +115,19 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}

// Extract distinct aggregate expressions.
val distinctAggGroups = aggExpressions
.filter(_.isDistinct)
.groupBy(_.aggregateFunction.children.toSet)
val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Space between groupBy and bracket.

e =>
if (e.aggregateFunction.children.exists(!_.foldable)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just materialize the nonFoldables. Instead of filtering them twice.

// Only expand the unfoldable children
e.aggregateFunction.children.filter(!_.foldable).toSet
} else {
// If aggregateFunction's children are all foldable
// we must expand at least one of the children (here we take the first child),
// or If we don't, we will get the wrong result, for example:
// count(distinct 1) will be explained to count(1) after the rewrite function.
e.aggregateFunction.children.take(1).toSet
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good catch. It would be great if we could git rid of this by constant folding (not needed in this PR). Another way of getting rid of this, would be by creating a separate processing group for these distincts.

}
}

// Check if the aggregates contains functions that do not support partial aggregation.
val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
Expand All @@ -134,27 +144,19 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Functions used to modify aggregate functions and their inputs.
def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
def patchAggregateFunctionChildren(
af: AggregateFunction)(
attrs: Expression => Expression): AggregateFunction = {
af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction]
def patchAggregateFunctionChildren(af: AggregateFunction)(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Style, please keep this the way it was.

attrs: Expression => Option[Expression]): AggregateFunction = {
val newChildren = af.children.map(c => attrs(c).getOrElse(c))
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
}

// Setup unique distinct aggregate children.
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
val distinctAggChildFoldable = distinctAggChildren.filter(_.foldable)
// 1.only unfoldable child should be expand
// 2.if foldable child mapped to AttributeRefference using expressionAttributePair,
// the udaf function(such as ApproximatePercentile)
// which has a foldable TypeCheck will failed,because AttributeRefference is unfoldable
val distinctAggChildUnFoldableAttrMap = distinctAggChildren
.filter(!_.foldable).map(expressionAttributePair)

val distinctAggChildrenUnFoldableAttrs = distinctAggChildUnFoldableAttrMap.map(_._2)
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)

// Setup expand & aggregate operators for distinct aggregate expressions.
val distinctAggChildAttrLookup = (distinctAggChildUnFoldableAttrMap
++ distinctAggChildFoldable.map(c => c -> c)).toMap
val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)
Expand All @@ -169,7 +171,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af) { x =>
evalWithinGroup(id, distinctAggChildAttrLookup(x))
distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
Expand All @@ -178,20 +180,20 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}

// Setup expand for the 'regular' aggregate expressions.
val regularAggExprs = aggExpressions.filter(!_.isDistinct)
val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
val regularAggChildFoldable = regularAggChildren.filter(_.foldable)
val regularAggChildUnFoldable = regularAggChildren.filter(!_.foldable)
val regularAggChildUnFoldableAttrMap = regularAggChildUnFoldable
.map(expressionAttributePair)
val regularAggChildUnFoldableAttrs = regularAggChildUnFoldableAttrMap.map(_._2)
// only expand unfoldable children
val regularAggExprs = aggExpressions
.filter(e => !e.isDistinct && e.children.exists(!_.foldable))
val regularAggChildren = regularAggExprs
.flatMap(_.aggregateFunction.children.filter(!_.foldable))
.distinct
val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)

// Setup aggregates for 'regular' aggregate expressions.
val regularGroupId = Literal(0)
val regularAggChildAttrLookup = (regularAggChildUnFoldableAttrMap
++ regularAggChildFoldable.map(c => c -> c)).toMap
val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
val regularAggOperatorMap = regularAggExprs.map { e =>
// Perform the actual aggregation in the initial aggregate.
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup.get)
val operator = Alias(e.copy(aggregateFunction = af), e.sql)()

// Select the result of the first aggregate in the last aggregate.
Expand Down Expand Up @@ -219,13 +221,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
Seq(a.groupingExpressions ++
distinctAggChildren.map(nullify) ++
Seq(regularGroupId) ++
regularAggChildUnFoldable)
regularAggChildren)
} else {
Seq.empty[Seq[Expression]]
}

// Construct the distinct aggregate input projections.
val regularAggNulls = regularAggChildUnFoldable.map(nullify)
val regularAggNulls = regularAggChildren.map(nullify)
val distinctAggProjections = distinctAggOperatorMap.map {
case (projection, _) =>
a.groupingExpressions ++
Expand All @@ -236,22 +238,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
// Construct the expand operator.
val expand = Expand(
regularAggProjection ++ distinctAggProjections,
groupByAttrs ++ distinctAggChildrenUnFoldableAttrs ++ Seq(gid)
++ regularAggChildUnFoldableAttrs,
groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
a.child)

// Construct the first aggregate operator. This de-duplicates the all the children of
// distinct operators, and applies the regular aggregate operators.
val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildrenUnFoldableAttrs :+ gid
val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
val firstAggregate = Aggregate(
firstAggregateGroupBy,
firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
expand)

// Construct the second aggregate
val transformations: Map[Expression, Expression] =
(distinctAggOperatorMap.flatMap(_._2) ++
regularAggOperatorMap.map(e => (e._1, e._3))).toMap
(distinctAggOperatorMap.flatMap(_._2) ++
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert this change

regularAggOperatorMap.map(e => (e._1, e._3))).toMap

val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
Expand All @@ -274,9 +275,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
private def nullify(e: Expression) = Literal.create(null, e.dataType)

private def expressionAttributePair(e: Expression) =
// We are creating a new reference here instead of reusing the attribute in case of a
// NamedExpression. This is done to prevent collisions between distinct and regular aggregate
// children, in this case attribute reuse causes the input of the regular aggregate to bound to
// the (nulled out) input of the distinct aggregate.
// We are creating a new reference here instead of reusing the attribute in case of a
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert this change

// NamedExpression. This is done to prevent collisions between distinct and regular aggregate
// children, in this case attribute reuse causes the input of the regular aggregate to bound to
// the (nulled out) input of the distinct aggregate.
e -> AttributeReference(e.sql, e.dataType, nullable = true)()
}
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,22 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}

test("Generic UDAF aggregates") {
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999))" +
", count(distinct key),sum(distinct key) FROM src LIMIT 1"),
sql("SELECT max(key), count(distinct key),sum(distinct key) FROM src LIMIT 1")
checkAnswer(sql("SELECT percentile_approx(2, 0.99999), " +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use multiline strings for these tests.

"sum(distinct 1), count(distinct 1,2,3,4) FROM src LIMIT 1"),
sql("SELECT 2, 1, 1 FROM src LIMIT 1")
.collect().toSeq)

checkAnswer(sql("SELECT ceiling(percentile_approx(distinct key, 0.99999))" +
", count(distinct key), sum(distinct key), " +
"count(distinct 1), sum(distinct 1), sum(1) FROM src LIMIT 1"),
sql("SELECT max(key), count(distinct key), sum(distinct key)," +
" 1, 1, sum(1) FROM src LIMIT 1")
.collect().toSeq)

checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.09999 + 0.9))" +
", count(distinct key),sum(distinct key),1 FROM src LIMIT 1"),
sql("SELECT max(key), count(distinct key),sum(distinct key), 1 FROM src LIMIT 1")
checkAnswer(sql("SELECT ceiling(percentile_approx(distinct key, 0.9 + 0.09999))" +
", count(distinct key), sum(distinct key), " +
"count(distinct 1), sum(distinct 1), sum(1) FROM src LIMIT 1"),
sql("SELECT max(key), count(distinct key), sum(distinct key), 1, 1, sum(1) FROM src LIMIT 1")
.collect().toSeq)

checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
Expand Down