diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java
index 26b97b46fe2e..44111913f124 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java
@@ -42,4 +42,9 @@ public Cast(Expression expression, DataType dataType) {
 
   @Override
   public Expression[] children() { return new Expression[]{ expression() }; }
+
+  @Override
+  public String toString() {
+    return "CAST(" + expression.describe() + " AS " + dataType.typeName() + ")";
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index c7b09904df41..27daa899583e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -189,12 +189,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       //   +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22]
       // Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation.
       // scalastyle:on
-      val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) =>
-        AttributeReference(s"group_col_$i", e.dataType)()
+      val groupOutputMap = normalizedGroupingExpr.zipWithIndex.map { case (e, i) =>
+        AttributeReference(s"group_col_$i", e.dataType)() -> e
       }
-      val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) =>
-        AttributeReference(s"agg_func_$i", e.dataType)()
+      val groupOutput = groupOutputMap.unzip._1
+      val aggOutputMap = finalAggExprs.zipWithIndex.map { case (e, i) =>
+        AttributeReference(s"agg_func_$i", e.dataType)() -> e
       }
+      val aggOutput = aggOutputMap.unzip._1
       val newOutput = groupOutput ++ aggOutput
       val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
       normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) =>
@@ -204,6 +206,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       }
 
       holder.pushedAggregate = Some(translatedAgg)
+      holder.pushedAggOutputMap = AttributeMap(groupOutputMap ++ aggOutputMap)
       holder.output = newOutput
       logInfo(
         s"""
@@ -408,14 +411,20 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
        }
        (operation, isPushed && !isPartiallyPushed)
     case s @ Sort(order, _, operation @ PhysicalOperation(project, Nil, sHolder: ScanBuilderHolder))
-        // Without building the Scan, we do not know the resulting column names after aggregate
-        // push-down, and thus can't push down Top-N which needs to know the ordering column names.
-        // TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same
-        //       columns, which we know the resulting column names: the original table columns.
-        if sHolder.pushedAggregate.isEmpty &&
-          CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
+        if CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
       val aliasMap = getAliasMap(project)
-      val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
+      val aliasReplacedOrder = order.map(replaceAlias(_, aliasMap))
+      val newOrder = if (sHolder.pushedAggregate.isDefined) {
+        // `ScanBuilderHolder` has different output columns after aggregate push-down. Here we
+        // replace the attributes in ordering expressions with the original table output columns.
+        aliasReplacedOrder.map {
+          _.transform {
+            case a: Attribute => sHolder.pushedAggOutputMap.getOrElse(a, a)
+          }.asInstanceOf[SortOrder]
+        }
+      } else {
+        aliasReplacedOrder.asInstanceOf[Seq[SortOrder]]
+      }
       val normalizedOrders = DataSourceStrategy.normalizeExprs(
         newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
       val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
@@ -545,6 +554,8 @@ case class ScanBuilderHolder(
   var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate]
 
   var pushedAggregate: Option[Aggregation] = None
+
+  var pushedAggOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression]
 }
 
 // A wrapper for v1 scan to carry the translated filters and the handled ones, along with
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 02dff0973fe1..a8c770f46cd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -775,59 +775,46 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0, false),
       Row(1, "amy", 10000.00, 1000.0, true)))
 
+    val name = udf { (x: String) => x.matches("cat|dav|amy") }
+    val sub = udf { (x: String) => x.substring(0, 3) }
     val df6 = spark.read
       .table("h2.test.employee")
-      .groupBy("DEPT").sum("SALARY")
-      .orderBy("DEPT")
+      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+      .filter(name($"shortName"))
+      .sort($"SALARY".desc)
       .limit(1)
+    // LIMIT is pushed down only if all the filters are pushed down
     checkSortRemoved(df6, false)
     checkLimitRemoved(df6, false)
-    checkPushedInfo(df6,
-      "PushedAggregates: [SUM(SALARY)]",
-      "PushedFilters: []",
-      "PushedGroupByExpressions: [DEPT]")
-    checkAnswer(df6, Seq(Row(1, 19000.00)))
+    checkPushedInfo(df6, "PushedFilters: []")
+    checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy")))
 
-    val name = udf { (x: String) => x.matches("cat|dav|amy") }
-    val sub = udf { (x: String) => x.substring(0, 3) }
     val df7 = spark.read
       .table("h2.test.employee")
-      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
-      .filter(name($"shortName"))
-      .sort($"SALARY".desc)
+      .sort(sub($"NAME"))
       .limit(1)
-    // LIMIT is pushed down only if all the filters are pushed down
     checkSortRemoved(df7, false)
     checkLimitRemoved(df7, false)
     checkPushedInfo(df7, "PushedFilters: []")
-    checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
+    checkAnswer(df7, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
 
     val df8 = spark.read
-      .table("h2.test.employee")
-      .sort(sub($"NAME"))
-      .limit(1)
-    checkSortRemoved(df8, false)
-    checkLimitRemoved(df8, false)
-    checkPushedInfo(df8, "PushedFilters: []")
-    checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
-
-    val df9 = spark.read
       .table("h2.test.employee")
       .select($"DEPT", $"name", $"SALARY",
         when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
       .sort("key", "dept", "SALARY")
       .limit(3)
-    checkSortRemoved(df9)
-    checkLimitRemoved(df9)
-    checkPushedInfo(df9,
+    checkSortRemoved(df8)
+    checkLimitRemoved(df8)
+    checkPushedInfo(df8,
       "PushedFilters: []",
-      "PushedTopN: " +
-        "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
-        "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
-    checkAnswer(df9,
+      "PushedTopN: ORDER BY " +
+        "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END" +
+ + " ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3") + checkAnswer(df8, Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0))) - val df10 = spark.read + val df9 = spark.read .option("partitionColumn", "dept") .option("lowerBound", "0") .option("upperBound", "2") @@ -837,14 +824,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key")) .orderBy($"key", $"dept", $"SALARY") .limit(3) - checkSortRemoved(df10, false) - checkLimitRemoved(df10, false) - checkPushedInfo(df10, + checkSortRemoved(df9, false) + checkLimitRemoved(df9, false) + checkPushedInfo(df9, "PushedFilters: []", - "PushedTopN: " + - "ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " + - "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,") - checkAnswer(df10, + "PushedTopN: ORDER BY " + + "[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " + + "ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3") + checkAnswer(df9, Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0))) } @@ -873,6 +860,196 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df2, Seq(Row(2, "david", 10000.00))) } + test("scan with aggregate push-down, top N push-down and offset push-down") { + val df1 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .orderBy("DEPT") + + val paging1 = df1.offset(1).limit(1) + checkSortRemoved(paging1) + checkLimitRemoved(paging1) + checkPushedInfo(paging1, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging1, Seq(Row(2, 22000.00))) + + val topN1 = df1.limit(1) + checkSortRemoved(topN1) + checkLimitRemoved(topN1) + checkPushedInfo(topN1, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN1, Seq(Row(1, 19000.00))) + + val df2 = spark.read + .table("h2.test.employee") + .select($"DEPT".cast("string").as("my_dept"), $"SALARY") + .groupBy("my_dept").sum("SALARY") + .orderBy("my_dept") + + val paging2 = df2.offset(1).limit(1) + checkSortRemoved(paging2) + checkLimitRemoved(paging2) + checkPushedInfo(paging2, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string)]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging2, Seq(Row("2", 22000.00))) + + val topN2 = df2.limit(1) + checkSortRemoved(topN2) + checkLimitRemoved(topN2) + checkPushedInfo(topN2, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [CAST(DEPT AS string)]", + "PushedFilters: []", + "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN2, Seq(Row("1", 19000.00))) + + val df3 = spark.read + .table("h2.test.employee") + .groupBy("dept").sum("SALARY") + .orderBy($"dept".cast("string")) + + val paging3 = df3.offset(1).limit(1) + checkSortRemoved(paging3) + checkLimitRemoved(paging3) + checkPushedInfo(paging3, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY 
+    checkAnswer(paging3, Seq(Row(2, 22000.00)))
+
+    val topN3 = df3.limit(1)
+    checkSortRemoved(topN3)
+    checkLimitRemoved(topN3)
+    checkPushedInfo(topN3,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN3, Seq(Row(1, 19000.00)))
+
+    val df4 = spark.read
+      .table("h2.test.employee")
+      .groupBy("DEPT", "IS_MANAGER").sum("SALARY")
+      .orderBy("DEPT", "IS_MANAGER")
+
+    val paging4 = df4.offset(1).limit(1)
+    checkSortRemoved(paging4)
+    checkLimitRemoved(paging4)
+    checkPushedInfo(paging4,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT, IS_MANAGER]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging4, Seq(Row(1, true, 10000.00)))
+
+    val topN4 = df4.limit(1)
+    checkSortRemoved(topN4)
+    checkLimitRemoved(topN4)
+    checkPushedInfo(topN4,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT, IS_MANAGER]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN4, Seq(Row(1, false, 9000.00)))
+
+    val df5 = spark.read
+      .table("h2.test.employee")
+      .select($"SALARY", $"IS_MANAGER", $"DEPT".cast("string").as("my_dept"))
+      .groupBy("my_dept", "IS_MANAGER").sum("SALARY")
+      .orderBy("my_dept", "IS_MANAGER")
+
+    val paging5 = df5.offset(1).limit(1)
+    checkSortRemoved(paging5)
+    checkLimitRemoved(paging5)
+    checkPushedInfo(paging5,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: " +
+        "ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging5, Seq(Row("1", true, 10000.00)))
+
+    val topN5 = df5.limit(1)
+    checkSortRemoved(topN5)
+    checkLimitRemoved(topN5)
+    checkPushedInfo(topN5,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]",
+      "PushedFilters: []",
+      "PushedTopN: " +
+        "ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN5, Seq(Row("1", false, 9000.00)))
+
+    val df6 = spark.read
+      .table("h2.test.employee")
+      .select($"DEPT", $"SALARY")
+      .groupBy("dept").agg(sum("SALARY"))
+      .orderBy(sum("SALARY"))
+
+    val paging6 = df6.offset(1).limit(1)
+    checkSortRemoved(paging6)
+    checkLimitRemoved(paging6)
+    checkPushedInfo(paging6,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
+      "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2")
+    checkAnswer(paging6, Seq(Row(1, 19000.00)))
+
+    val topN6 = df6.limit(1)
+    checkSortRemoved(topN6)
+    checkLimitRemoved(topN6)
+    checkPushedInfo(topN6,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1")
+    checkAnswer(topN6, Seq(Row(6, 12000.00)))
+
+    val df7 = spark.read
+      .table("h2.test.employee")
+      .select($"DEPT", $"SALARY")
+      .groupBy("dept").agg(sum("SALARY").as("total"))
+      .orderBy("total")
+
+    val paging7 = df7.offset(1).limit(1)
+    checkSortRemoved(paging7)
+    checkLimitRemoved(paging7)
+    checkPushedInfo(paging7,
+      "PushedAggregates: [SUM(SALARY)]",
+      "PushedGroupByExpressions: [DEPT]",
+      "PushedFilters: []",
+      "PushedOffset: OFFSET 1",
"PushedOffset: OFFSET 1", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2") + checkAnswer(paging7, Seq(Row(1, 19000.00))) + + val topN7 = df7.limit(1) + checkSortRemoved(topN7) + checkLimitRemoved(topN7) + checkPushedInfo(topN7, + "PushedAggregates: [SUM(SALARY)]", + "PushedGroupByExpressions: [DEPT]", + "PushedFilters: []", + "PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1") + checkAnswer(topN7, Seq(Row(6, 12000.00))) + } + test("scan with filter push-down") { val df = spark.table("h2.test.people").filter($"id" > 1) checkFiltersRemoved(df)