Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
ae1186f
[SPARK-34581][SQL] Don't optimize out grouping expressions from aggre…
peter-toth Mar 21, 2021
5ab9f75
comment fix
peter-toth Mar 21, 2021
2293fd4
move logic to the beginning of optimization, simplify test
peter-toth Mar 22, 2021
3de19ca
regenerate approved plans
peter-toth Mar 22, 2021
04e61c5
Merge branch 'master' into SPARK-34581-keep-grouping-expressions
peter-toth Mar 23, 2021
6e05f14
define GroupingExpression as TaggingExpression
peter-toth Mar 23, 2021
09f1a85
move test to SQLQueryTestSuite
peter-toth Mar 24, 2021
f46b89d
add more explanation
peter-toth Mar 24, 2021
56589a3
Merge commit 'c8233f1be5c2f853f42cda367475eb135a83afd5' into SPARK-34…
peter-toth Mar 26, 2021
ea95bff
Merge commit '3951e3371a83578a81474ed99fb50d59f27aac62' into SPARK-34…
peter-toth Mar 31, 2021
7ea2306
Merge commit '89ae83d19b9652348a685550c2c49920511160d5' into SPARK-34…
peter-toth Apr 1, 2021
468534f
Merge commit '65da9287bc5112564836a555cd2967fc6b05856f' into SPARK-34…
peter-toth Apr 2, 2021
977c0bf
new GroupingExprRef approach
peter-toth Mar 27, 2021
c2ba804
simplify
peter-toth Apr 11, 2021
0622444
minor fixes
peter-toth Apr 12, 2021
343f35e
Merge commit 'e40fce919ab77f5faeb0bbd34dc86c56c04adbaa' into SPARK-34…
peter-toth Apr 12, 2021
2e79eb9
review fixes
peter-toth Apr 13, 2021
cff9b9a
fix latest test failures, add new test case
peter-toth Apr 14, 2021
78296a8
better non-deterministic test case
peter-toth Apr 14, 2021
72c173b
make new rules non excludable
peter-toth Apr 15, 2021
34f0439
Merge branch 'fork/master' into SPARK-34581-keep-grouping-expressions
peter-toth Apr 15, 2021
fb3a19d
fix validConstraints, minor changes
peter-toth Apr 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
simplify
  • Loading branch information
peter-toth committed Apr 11, 2021
commit c2ba80457bd86d11ad26311bbc3c42607f33b19a
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ class Analyzer(override val catalogManager: CatalogManager)

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child, false)
Aggregate(groups, assignAliases(aggs), child)

case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child)
if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) =>
Expand Down Expand Up @@ -599,7 +599,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val aggregations = constructAggregateExprs(
finalGroupByExpressions, aggregationExprs, groupByAliases, groupingAttrs, gid)

Aggregate(groupingAttrs, aggregations, expand, false)
Aggregate(groupingAttrs, aggregations, expand)
}

private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
Expand Down Expand Up @@ -746,15 +746,14 @@ class Analyzer(override val catalogManager: CatalogManager)
case _ => Alias(pivotColumn, "__pivot_col")()
}
val bigGroup = groupByExprs :+ namedPivotCol
val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child, false)
val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
val pivotAggs = namedAggExps.map { a =>
Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues)
.toAggregateExpression()
, "__pivot_" + a.sql)()
}
val groupByExprsAttr = groupByExprs.map(_.toAttribute)
val secondAgg =
Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg, false)
val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg)
val pivotAggAttribute = pivotAggs.map(_.toAttribute)
val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
Expand Down Expand Up @@ -791,7 +790,7 @@ class Analyzer(override val catalogManager: CatalogManager)
Alias(filteredAggregate, outputName(value, aggregate))()
}
}
Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child, false)
Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
}
}

Expand Down Expand Up @@ -1407,8 +1406,7 @@ class Analyzer(override val catalogManager: CatalogManager)
if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError()
} else {
a.copy(aggrExprWithGroupingRefs =
buildExpandedProjectList(a.aggregateExpressions, a.child))
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
}
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
Expand Down Expand Up @@ -1821,7 +1819,7 @@ class Analyzer(override val catalogManager: CatalogManager)
throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size, ordinal)
case o => o
}
Aggregate(newGroups, aggs, child, false)
Aggregate(newGroups, aggs, child)
}
}

Expand Down Expand Up @@ -1919,8 +1917,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
// All the missing attributes are grouping expressions, valid case.
(newExprs,
a.copy(aggrExprWithGroupingRefs = aggExprs ++ missingAttrs, child = newChild))
(newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
} else {
// Need to add non-grouping attributes, invalid case.
(exprs, a)
Expand Down Expand Up @@ -2241,7 +2238,7 @@ class Analyzer(override val catalogManager: CatalogManager)
object GlobalAggregates extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case Project(projectList, child) if containsAggregates(projectList) =>
Aggregate(Nil, projectList, child, false)
Aggregate(Nil, projectList, child)
}

def containsAggregates(exprs: Seq[Expression]): Boolean = {
Expand Down Expand Up @@ -2290,7 +2287,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())

val aggregateWithExtraOrdering = aggregate.copy(
aggrExprWithGroupingRefs = aggregate.aggregateExpressions ++ aliasedOrdering)
aggregateExpressions = aggregate.aggregateExpressions ++ aliasedOrdering)

val resolvedAggregate: Aggregate =
executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate]
Expand Down Expand Up @@ -2344,7 +2341,7 @@ class Analyzer(override val catalogManager: CatalogManager)
} else {
Project(aggregate.output,
Sort(finalSortOrders, global,
aggregate.copy(aggrExprWithGroupingRefs = originalAggExprs ++ needsPushDown)))
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
Expand All @@ -2371,8 +2368,7 @@ class Analyzer(override val catalogManager: CatalogManager)
Aggregate(
agg.groupingExpressions,
Alias(filterCond, "havingCondition")() :: Nil,
agg.child,
false)
agg.child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
Expand Down Expand Up @@ -2427,7 +2423,7 @@ class Analyzer(override val catalogManager: CatalogManager)
val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
Project(agg.output,
Filter(resolvedHavingCond,
agg.copy(aggrExprWithGroupingRefs = agg.aggregateExpressions ++ aggregateExpressions)))
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
} else {
filter
}
Expand Down Expand Up @@ -2557,7 +2553,7 @@ class Analyzer(override val catalogManager: CatalogManager)
other :: Nil
}

val newAgg = Aggregate(groupList, newAggList, child, false)
val newAgg = Aggregate(groupList, newAggList, child)
Project(projectExprs.toList, newAgg)

case p @ Project(projectList, _) if hasAggFunctionInGenerator(projectList) =>
Expand Down Expand Up @@ -2867,7 +2863,7 @@ class Analyzer(override val catalogManager: CatalogManager)
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, false)
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
// Add a Filter operator for conditions in the Having clause.
val withFilter = Filter(condition, withAggregate)
val withWindow = addWindow(windowExpressions, withFilter)
Expand All @@ -2884,7 +2880,7 @@ class Analyzer(override val catalogManager: CatalogManager)
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child, false)
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
// Add Window operators.
val withWindow = addWindow(windowExpressions, withAggregate)

Expand Down Expand Up @@ -3542,7 +3538,7 @@ object CleanupAliases extends Rule[LogicalPlan] with AliasHelper {

case Aggregate(grouping, aggs, child) =>
val cleanedAggs = aggs.map(trimNonTopLevelAliases)
Aggregate(grouping.map(trimAliases), cleanedAggs, child, false)
Aggregate(grouping.map(trimAliases), cleanedAggs, child)

case Window(windowExprs, partitionSpec, orderSpec, child) =>
val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
case oldVersion @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(
aggrExprWithGroupingRefs = newAliases(aggregateExpressions))))
aggregateExpressions = newAliases(aggregateExpressions))))

// We don't search the child plan recursively for the same reason as the above Project.
case _ @ Aggregate(_, aggregateExpressions, _)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ object UnsupportedOperationChecker extends Logging {
// Since the Distinct node will be replaced to Aggregate in the optimizer rule
// [[ReplaceDistinctWithAggregate]], here we also need to check all Distinct node by
// assuming it as Aggregate.
case d @ Distinct(c: LogicalPlan) if d.isStreaming =>
Aggregate(c.output, c.output, c, false)
case d @ Distinct(c: LogicalPlan) if d.isStreaming => Aggregate(c.output, c.output, c)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ object UpdateGroupingExprRefNullability extends Rule[LogicalPlan] {
case a: Aggregate =>
val nullabilities = a.groupingExpressions.map(_.nullable).toArray

val newAggrExprWithGroupingRefs =
a.aggrExprWithGroupingRefs.map(_.transform {
val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if g.nullable != nullabilities(g.ordinal) =>
g.copy(nullable = nullabilities(g.ordinal))
}.asInstanceOf[NamedExpression])

a.copy(aggrExprWithGroupingRefs = newAggrExprWithGroupingRefs)
a.copy(aggregateExpressions = newAggregateExpressions)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.optimizer.EnforceGroupingReferencesInAggregates
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -408,7 +407,7 @@ package object dsl {
case ne: NamedExpression => ne
case e => Alias(e, e.toString)()
}
Aggregate(groupingExprs, aliasedExprs, logicalPlan, false)
Aggregate(groupingExprs, aliasedExprs, logicalPlan)
}

def having(
Expand Down Expand Up @@ -467,7 +466,7 @@ package object dsl {
def analyze: LogicalPlan = {
val analyzed = analysis.SimpleAnalyzer.execute(logicalPlan)
analysis.SimpleAnalyzer.checkAnalysis(analyzed)
EnforceGroupingReferencesInAggregates(EliminateSubqueryAliases(analyzed))
EliminateSubqueryAliases(analyzed)
}

def hint(name: String, parameters: Any*): LogicalPlan =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ trait AliasHelper {
protected def getAliasMap(plan: Aggregate): AttributeMap[Alias] = {
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression or PythonUDF, and create a map from the alias to the expression
val aliasMap = plan.aggregateExpressions.collect {
val aliasMap = plan.aggregateExpressionsWithoutGroupingRefs.collect {
case a: Alias if a.child.find(e => e.isInstanceOf[AggregateExpression] ||
PythonUDF.isGroupedAggPandasUDF(e)).isEmpty =>
(a.toAttribute, a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
object EnforceGroupingReferencesInAggregates extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan transform {
case a: Aggregate if !a.enforceGroupingReferences =>
Aggregate(a.groupingExpressions, a.aggrExprWithGroupingRefs, a.child)
case a: Aggregate =>
Aggregate.withGroupingRefs(a.groupingExpressions, a.aggregateExpressions, a.child)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
val aliasMap = getAliasMap(lower)

val newAggregate = Aggregate(
val newAggregate = Aggregate.withGroupingRefs(
child = lower.child,
groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
aggregateExpressions = upper.aggregateExpressions.map(
Expand Down Expand Up @@ -752,8 +752,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) =>
p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
case p @ Project(_, a: Aggregate) if !a.outputSet.subsetOf(p.references) =>
p.copy(child =
a.copy(aggrExprWithGroupingRefs = a.aggrExprWithGroupingRefs.filter(p.references.contains)))
p.copy(
child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
case a @ Project(_, e @ Expand(_, _, grandChild)) if !e.outputSet.subsetOf(a.references) =>
val newOutput = e.output.filter(a.references.contains(_))
val newProjects = e.projections.map { proj =>
Expand Down Expand Up @@ -879,8 +879,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
p
} else {
Aggregate(agg.groupingExpressions,
buildCleanedProjectList(p.projectList, agg.aggregateExpressions), agg.child)
agg.copy(aggregateExpressions = buildCleanedProjectList(
p.projectList, agg.aggregateExpressions))
}
case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
if isRenaming(l1, l2) =>
Expand Down Expand Up @@ -1250,7 +1250,6 @@ object EliminateSorts extends Rule[LogicalPlan] {

def checkValidAggregateExpression(expr: Expression): Boolean = expr match {
case _: AttributeReference => true
case _: GroupingExprRef => true
case ae: AggregateExpression => isOrderIrrelevantAggFunction(ae.aggregateFunction)
case _: UserDefinedExpression => false
case e => e.children.forall(checkValidAggregateExpression)
Expand Down Expand Up @@ -1985,15 +1984,15 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
val droppedGroupsBefore =
grouping.scanLeft(0)((n, e) => n + (if (e.foldable) 1 else 0)).toArray

val newAggrExprWithGroupingReferences =
a.aggrExprWithGroupingRefs.map(_.transform {
val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
}.asInstanceOf[NamedExpression])

a.copy(
groupingExpressions = newGrouping,
aggrExprWithGroupingRefs = newAggrExprWithGroupingReferences)
a.copy(
groupingExpressions = newGrouping,
aggregateExpressions = newAggregateExpressions)
} else {
// All grouping expressions are literals. We should not drop them all, because this can
// change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
Expand Down Expand Up @@ -2024,15 +2023,15 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
})
).toArray

val newAggrExprWithGroupingReferences =
a.aggrExprWithGroupingRefs.map(_.transform {
val newAggregateExpressions =
a.aggregateExpressions.map(_.transform {
case g: GroupingExprRef if droppedGroupsBefore(g.ordinal) > 0 =>
g.copy(ordinal = g.ordinal - droppedGroupsBefore(g.ordinal))
}.asInstanceOf[NamedExpression])

a.copy(
groupingExpressions = newGrouping,
aggrExprWithGroupingRefs = newAggrExprWithGroupingReferences)
aggregateExpressions = newAggregateExpressions)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, expressions, child) =>
case a @ Aggregate(grouping, _, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
val rewriteExprs = a.aggregateExpressionsWithoutGroupingRefs
.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
Filter(predicate, createProject())
} else {
// According to SQL standard, HAVING without GROUP BY means global aggregate.
withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter, false))
withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter))
}
} else if (aggregationClause != null) {
val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter)
Expand Down Expand Up @@ -924,7 +924,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
val groupingSets =
ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq)
Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)),
selectExpressions, query, false)
selectExpressions, query)
} else {
// GROUP BY .... (WITH CUBE | WITH ROLLUP)?
val mappedGroupByExpressions = if (ctx.CUBE != null) {
Expand All @@ -934,7 +934,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
} else {
groupByExpressions
}
Aggregate(mappedGroupByExpressions, selectExpressions, query, false)
Aggregate(mappedGroupByExpressions, selectExpressions, query)
}
} else {
val groupByExpressions =
Expand Down Expand Up @@ -978,7 +978,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
"`GROUP BY CUBE(a, b), ROLLUP(a, c)` is not supported.",
ctx)
}
Aggregate(groupByExpressions.toSeq, selectExpressions, query, false)
Aggregate(groupByExpressions.toSeq, selectExpressions, query)
}
}

Expand Down
Loading