Skip to content
Prev Previous commit
Next Next commit
Rename
  • Loading branch information
bersprockets committed Oct 11, 2022
commit 165f558590dc46ffe50ebcbc9fcc9cd81f801c96
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
}

def rewrite(aRaw: Aggregate): Aggregate = {
def rewrite(aOrig: Aggregate): Aggregate = {
// make children of distinct aggregations the same if they are different
// only because of superficial differences.
val a = getSanitizedAggregate(aRaw)
// only because of superficial reasons, e.g.:
// "1 + col1" vs "col1 + 1", both become "1 + col1"
// or
// "col1" vs "Col1", both become "col1"
val a = reduceDistinctAggregateGroups(aOrig)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we just canonicalize the function inputs when group by them? e.g. change e.aggregateFunction.children.filter(!_.foldable).toSet to ExpressionSet(e.aggregateFunction.children.filter(!_.foldable))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I am working on it, just working through some small complications.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made the change to use ExpressionSet and also commented on some of the issues.

I still prefer 'sanitizing' each original function child to use the first semantically equivalent child, in essence creating a new set of "original" children, as it bypasses some complexities (in particular the one where we may lose some of the original children as keys when we group by ExpressionSet).


val aggExpressions = collectAggregateExprs(a)
val distinctAggs = aggExpressions.filter(_.isDistinct)
Expand Down Expand Up @@ -248,6 +251,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
val groupByAttrs = groupByMap.map(_._2)

def patchAggregateFunctionChildren(
af: AggregateFunction)(
attrs: Expression => Option[Expression]): AggregateFunction = {
val newChildren = af.children.map(c => attrs(c).getOrElse(c))
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
}

// Setup unique distinct aggregate children.
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this is necessary, but it's better to use ExpressionSet(distinctAggGroups.keySet.flatten).toSeq, instead of calling .distinct on Seq[Expression]

val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
Expand Down Expand Up @@ -409,14 +419,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}}
}

private def patchAggregateFunctionChildren(
af: AggregateFunction)(
attrs: Expression => Option[Expression]): AggregateFunction = {
val newChildren = af.children.map(c => attrs(c).getOrElse(c))
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
}

private def getSanitizedAggregate(a: Aggregate): Aggregate = {
private def reduceDistinctAggregateGroups(a: Aggregate): Aggregate = {
val aggExpressions = collectAggregateExprs(a)
val distinctAggs = aggExpressions.filter(_.isDistinct)

Expand All @@ -436,6 +439,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
(e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems the code lead to find out itself.

}.toMap

val funcChildrenPatched = funcChildren.map { e =>
funcChildrenLookup.getOrElse(e, e)
}

if (funcChildren.distinct.size == funcChildrenPatched.distinct.size) {
return a;
}

val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case e: Expression =>
Expand Down