Skip to content

Commit a171791

Browse files
cloud-fan authored and dongjoon-hyun committed
[SPARK-46378][SQL] Still remove Sort after converting Aggregate to Project
### What changes were proposed in this pull request?

This is a follow-up of apache#33397 to avoid sub-optimal plans. After converting `Aggregate` to `Project`, there is information lost: `Aggregate` doesn't care about the data order of inputs, but `Project` cares. `EliminateSorts` can remove `Sort` below `Aggregate`, but it doesn't work anymore if we convert `Aggregate` to `Project`.

This PR fixes this issue by tagging the `Project` to be order-irrelevant if it's converted from `Aggregate`. Then `EliminateSorts` optimizes the tagged `Project`.

### Why are the changes needed?

avoid sub-optimal plans

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

new test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#44310 from cloud-fan/sort.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent b7b58e3 commit a171791

File tree

4 files changed

+22
-1
lines changed

4 files changed

+22
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -404,6 +404,8 @@ package object dsl {
404404

405405
def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
406406

407+
def localLimit(limitExpr: Expression): LogicalPlan = LocalLimit(limitExpr, logicalPlan)
408+
407409
def offset(offsetExpr: Expression): LogicalPlan = Offset(offsetExpr, logicalPlan)
408410

409411
def join(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -757,7 +757,9 @@ object LimitPushDown extends Rule[LogicalPlan] {
757757
LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join)))
758758
// Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only.
759759
case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly =>
760-
Limit(le, Project(a.aggregateExpressions, LocalLimit(le, a.child)))
760+
val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
761+
project.setTagValue(Project.dataOrderIrrelevantTag, ())
762+
Limit(le, project)
761763
case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly =>
762764
Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child))))
763765
// Merge offset value and limit value into LocalLimit and pushes down LocalLimit through Offset.
@@ -1563,6 +1565,8 @@ object EliminateSorts extends Rule[LogicalPlan] {
15631565
right = recursiveRemoveSort(originRight, true))
15641566
case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
15651567
g.copy(child = recursiveRemoveSort(originChild, true))
1568+
case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined =>
1569+
p.copy(child = recursiveRemoveSort(p.child, true))
15661570
}
15671571

15681572
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
9898

9999
object Project {
100100
val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output")
101+
// Project with this tag means it doesn't care about the data order of its input. We only set
102+
// this tag when the Project was converted from grouping-only Aggregate.
103+
val dataOrderIrrelevantTag: TreeNodeTag[Unit] = TreeNodeTag[Unit]("data_order_irrelevant")
101104

102105
def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = {
103106
assert(plan.resolved)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -478,4 +478,16 @@ class EliminateSortsSuite extends AnalysisTest {
478478

479479
comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze)
480480
}
481+
482+
test("SPARK-46378: Still remove Sort after converting Aggregate to Project") {
483+
val originalPlan = testRelation.orderBy($"a".asc)
484+
.groupBy($"a")($"a")
485+
.limit(1)
486+
487+
val correctAnswer = testRelation.localLimit(1)
488+
.select($"a")
489+
.limit(1)
490+
491+
comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze)
492+
}
481493
}

0 commit comments

Comments
 (0)