Cast.java
@@ -42,4 +42,9 @@ public Cast(Expression expression, DataType dataType) {

@Override
public Expression[] children() { return new Expression[]{ expression() }; }

@Override
public String toString() {
Contributor: Is this the only missing one? Maybe we should add a default implementation in the base class Expression using ToStringSQLBuilder?

Contributor Author: No, some V2 expressions are missing toString too, e.g. Extract. Adding a default implementation in the base class Expression using ToStringSQLBuilder is a good idea.

Contributor Author: I will fix it.

return "CAST(" + expression.describe() + " AS " + dataType.typeName() + ")";
}
}
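For context, here is a minimal usage sketch of the new `toString` (the column name and type below are illustrative, not taken from this diff): it renders the same SQL-like text as `describe()`, so the expression prints readably in plans and logs.

```scala
// Hedged sketch: exercising the new Cast.toString from Scala; "dept" is a made-up column.
import org.apache.spark.sql.connector.expressions.{Cast, FieldReference}
import org.apache.spark.sql.types.StringType

val cast = new Cast(FieldReference("dept"), StringType)
println(cast) // prints: CAST(dept AS string)
```

If the base `Expression` later gains a default `toString` built on `ToStringSQLBuilder`, as discussed in the thread above, per-class overrides like this one could become redundant.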
V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -189,12 +189,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// +- ScanBuilderHolder[group_col_0#10, agg_func_0#21, agg_func_1#22]
// Later, we build the `Scan` instance and convert ScanBuilderHolder to DataSourceV2ScanRelation.
// scalastyle:on
val groupOutput = normalizedGroupingExpr.zipWithIndex.map { case (e, i) =>
AttributeReference(s"group_col_$i", e.dataType)()
val groupOutputMap = normalizedGroupingExpr.zipWithIndex.map { case (e, i) =>
AttributeReference(s"group_col_$i", e.dataType)() -> e
}
val aggOutput = finalAggExprs.zipWithIndex.map { case (e, i) =>
AttributeReference(s"agg_func_$i", e.dataType)()
val groupOutput = groupOutputMap.unzip._1
val aggOutputMap = finalAggExprs.zipWithIndex.map { case (e, i) =>
AttributeReference(s"agg_func_$i", e.dataType)() -> e
}
val aggOutput = aggOutputMap.unzip._1
val newOutput = groupOutput ++ aggOutput
val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
normalizedGroupingExpr.zipWithIndex.foreach { case (expr, ordinal) =>
@@ -204,6 +206,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}

holder.pushedAggregate = Some(translatedAgg)
holder.pushedAggOutputMap = AttributeMap(groupOutputMap ++ aggOutputMap)
holder.output = newOutput
logInfo(
s"""
@@ -408,14 +411,20 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
(operation, isPushed && !isPartiallyPushed)
case s @ Sort(order, _, operation @ PhysicalOperation(project, Nil, sHolder: ScanBuilderHolder))
// Without building the Scan, we do not know the resulting column names after aggregate
// push-down, and thus can't push down Top-N which needs to know the ordering column names.
// TODO: we can support simple cases like GROUP BY columns directly and ORDER BY the same
// columns, which we know the resulting column names: the original table columns.
if sHolder.pushedAggregate.isEmpty &&
CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
if CollapseProject.canCollapseExpressions(order, project, alwaysInline = true) =>
val aliasMap = getAliasMap(project)
val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
val aliasReplacedOrder = order.map(replaceAlias(_, aliasMap))
val newOrder = if (sHolder.pushedAggregate.isDefined) {
// `ScanBuilderHolder` has different output columns after aggregate push-down. Here we
// replace those attributes in the ordering expressions with the original grouping and
// aggregate expressions, which reference the table's own output columns.
aliasReplacedOrder.map {
_.transform {
case a: Attribute => sHolder.pushedAggOutputMap.getOrElse(a, a)
}.asInstanceOf[SortOrder]
}
} else {
aliasReplacedOrder.asInstanceOf[Seq[SortOrder]]
}
val normalizedOrders = DataSourceStrategy.normalizeExprs(
newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
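A sketch of the rewrite above under assumed inputs: after aggregate push-down the `Sort` node references the synthetic scan output, so each ordering expression is translated back to the original expression before `DataSourceStrategy.translateSortOrders` pushes it to the source.

```scala
// Hedged sketch: turning ORDER BY group_col_0 back into ORDER BY DEPT using the
// pushedAggOutputMap recorded earlier. Names and types are illustrative.
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeMap, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.types.LongType

val dept = AttributeReference("DEPT", LongType)()
val groupCol0 = AttributeReference("group_col_0", LongType)()
val pushedAggOutputMap: AttributeMap[Expression] = AttributeMap(Seq(groupCol0 -> dept))

val order = SortOrder(groupCol0, Ascending) // what the Sort node sees after push-down
val rewritten = order.transform {
  case a: Attribute => pushedAggOutputMap.getOrElse(a, a)
}.asInstanceOf[SortOrder]
// rewritten now orders by DEPT, so it can be normalized against sHolder.relation.output
// and translated exactly as before.
```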
@@ -545,6 +554,8 @@ case class ScanBuilderHolder(
var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate]

var pushedAggregate: Option[Aggregation] = None

var pushedAggOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression]
}

// A wrapper for v1 scan to carry the translated filters and the handled ones, along with
253 changes: 215 additions & 38 deletions sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -775,59 +775,46 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df5,
Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true)))

val name = udf { (x: String) => x.matches("cat|dav|amy") }
val sub = udf { (x: String) => x.substring(0, 3) }
val df6 = spark.read
.table("h2.test.employee")
.groupBy("DEPT").sum("SALARY")
.orderBy("DEPT")
.select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
.filter(name($"shortName"))
.sort($"SALARY".desc)
.limit(1)
// LIMIT is pushed down only if all the filters are pushed down
checkSortRemoved(df6, false)
checkLimitRemoved(df6, false)
checkPushedInfo(df6,
"PushedAggregates: [SUM(SALARY)]",
"PushedFilters: []",
"PushedGroupByExpressions: [DEPT]")
checkAnswer(df6, Seq(Row(1, 19000.00)))
checkPushedInfo(df6, "PushedFilters: []")
checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy")))

val name = udf { (x: String) => x.matches("cat|dav|amy") }
val sub = udf { (x: String) => x.substring(0, 3) }
val df7 = spark.read
.table("h2.test.employee")
.select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
.filter(name($"shortName"))
.sort($"SALARY".desc)
.sort(sub($"NAME"))
.limit(1)
// LIMIT is pushed down only if all the filters are pushed down
checkSortRemoved(df7, false)
checkLimitRemoved(df7, false)
checkPushedInfo(df7, "PushedFilters: []")
checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
checkAnswer(df7, Seq(Row(2, "alex", 12000.00, 1200.0, false)))

val df8 = spark.read
.table("h2.test.employee")
.sort(sub($"NAME"))
.limit(1)
checkSortRemoved(df8, false)
checkLimitRemoved(df8, false)
checkPushedInfo(df8, "PushedFilters: []")
checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))

val df9 = spark.read
.table("h2.test.employee")
.select($"DEPT", $"name", $"SALARY",
when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
.sort("key", "dept", "SALARY")
.limit(3)
checkSortRemoved(df9)
checkLimitRemoved(df9)
checkPushedInfo(df9,
checkSortRemoved(df8)
checkLimitRemoved(df8)
checkPushedInfo(df8,
"PushedFilters: []",
"PushedTopN: " +
"ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
"ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
checkAnswer(df9,
"PushedTopN: ORDER BY " +
"[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END" +
" ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3")
checkAnswer(df8,
Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))

val df10 = spark.read
val df9 = spark.read
.option("partitionColumn", "dept")
.option("lowerBound", "0")
.option("upperBound", "2")
@@ -837,14 +824,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
when(($"SALARY" > 8000).and($"SALARY" < 10000), $"salary").otherwise(0).as("key"))
.orderBy($"key", $"dept", $"SALARY")
.limit(3)
checkSortRemoved(df10, false)
checkLimitRemoved(df10, false)
checkPushedInfo(df10,
checkSortRemoved(df9, false)
checkLimitRemoved(df9, false)
checkPushedInfo(df9,
"PushedFilters: []",
"PushedTopN: " +
"ORDER BY [CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
"ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3,")
checkAnswer(df10,
"PushedTopN: ORDER BY " +
"[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END " +
"ASC NULLS FIRST, DEPT ASC NULLS FIRST, SALARY ASC NULLS FIRST] LIMIT 3")
checkAnswer(df9,
Seq(Row(1, "amy", 10000, 0), Row(2, "david", 10000, 0), Row(2, "alex", 12000, 0)))
}

@@ -873,6 +860,196 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(df2, Seq(Row(2, "david", 10000.00)))
}

test("scan with aggregate push-down, top N push-down and offset push-down") {
val df1 = spark.read
.table("h2.test.employee")
.groupBy("DEPT").sum("SALARY")
.orderBy("DEPT")

val paging1 = df1.offset(1).limit(1)
checkSortRemoved(paging1)
checkLimitRemoved(paging1)
checkPushedInfo(paging1,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 2")
checkAnswer(paging1, Seq(Row(2, 22000.00)))

val topN1 = df1.limit(1)
checkSortRemoved(topN1)
checkLimitRemoved(topN1)
checkPushedInfo(topN1,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedTopN: ORDER BY [DEPT ASC NULLS FIRST] LIMIT 1")
checkAnswer(topN1, Seq(Row(1, 19000.00)))

val df2 = spark.read
.table("h2.test.employee")
.select($"DEPT".cast("string").as("my_dept"), $"SALARY")
.groupBy("my_dept").sum("SALARY")
.orderBy("my_dept")

val paging2 = df2.offset(1).limit(1)
checkSortRemoved(paging2)
checkLimitRemoved(paging2)
checkPushedInfo(paging2,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [CAST(DEPT AS string)]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2")
checkAnswer(paging2, Seq(Row("2", 22000.00)))

val topN2 = df2.limit(1)
checkSortRemoved(topN2)
checkLimitRemoved(topN2)
checkPushedInfo(topN2,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [CAST(DEPT AS string)]",
"PushedFilters: []",
"PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1")
checkAnswer(topN2, Seq(Row("1", 19000.00)))

val df3 = spark.read
.table("h2.test.employee")
.groupBy("dept").sum("SALARY")
.orderBy($"dept".cast("string"))

val paging3 = df3.offset(1).limit(1)
checkSortRemoved(paging3)
checkLimitRemoved(paging3)
checkPushedInfo(paging3,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 2")
checkAnswer(paging3, Seq(Row(2, 22000.00)))

val topN3 = df3.limit(1)
checkSortRemoved(topN3)
checkLimitRemoved(topN3)
checkPushedInfo(topN3,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedTopN: ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST] LIMIT 1")
checkAnswer(topN3, Seq(Row(1, 19000.00)))

val df4 = spark.read
.table("h2.test.employee")
.groupBy("DEPT", "IS_MANAGER").sum("SALARY")
.orderBy("DEPT", "IS_MANAGER")

val paging4 = df4.offset(1).limit(1)
checkSortRemoved(paging4)
checkLimitRemoved(paging4)
checkPushedInfo(paging4,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT, IS_MANAGER]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2")
checkAnswer(paging4, Seq(Row(1, true, 10000.00)))

val topN4 = df4.limit(1)
checkSortRemoved(topN4)
checkLimitRemoved(topN4)
checkPushedInfo(topN4,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT, IS_MANAGER]",
"PushedFilters: []",
"PushedTopN: ORDER BY [DEPT ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1")
checkAnswer(topN4, Seq(Row(1, false, 9000.00)))

val df5 = spark.read
.table("h2.test.employee")
.select($"SALARY", $"IS_MANAGER", $"DEPT".cast("string").as("my_dept"))
.groupBy("my_dept", "IS_MANAGER").sum("SALARY")
.orderBy("my_dept", "IS_MANAGER")

val paging5 = df5.offset(1).limit(1)
checkSortRemoved(paging5)
checkLimitRemoved(paging5)
checkPushedInfo(paging5,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: " +
"ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 2")
checkAnswer(paging5, Seq(Row("1", true, 10000.00)))

val topN5 = df5.limit(1)
checkSortRemoved(topN5)
checkLimitRemoved(topN5)
checkPushedInfo(topN5,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [CAST(DEPT AS string), IS_MANAGER]",
"PushedFilters: []",
"PushedTopN: " +
"ORDER BY [CAST(DEPT AS string) ASC NULLS FIRST, IS_MANAGER ASC NULLS FIRST] LIMIT 1")
checkAnswer(topN5, Seq(Row("1", false, 9000.00)))

val df6 = spark.read
.table("h2.test.employee")
.select($"DEPT", $"SALARY")
.groupBy("dept").agg(sum("SALARY"))
.orderBy(sum("SALARY"))

val paging6 = df6.offset(1).limit(1)
checkSortRemoved(paging6)
checkLimitRemoved(paging6)
checkPushedInfo(paging6,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2")
checkAnswer(paging6, Seq(Row(1, 19000.00)))

val topN6 = df6.limit(1)
checkSortRemoved(topN6)
checkLimitRemoved(topN6)
checkPushedInfo(topN6,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1")
checkAnswer(topN6, Seq(Row(6, 12000.00)))

val df7 = spark.read
.table("h2.test.employee")
.select($"DEPT", $"SALARY")
.groupBy("dept").agg(sum("SALARY").as("total"))
.orderBy("total")

val paging7 = df7.offset(1).limit(1)
checkSortRemoved(paging7)
checkLimitRemoved(paging7)
checkPushedInfo(paging7,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedOffset: OFFSET 1",
"PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 2")
checkAnswer(paging7, Seq(Row(1, 19000.00)))

val topN7 = df7.limit(1)
checkSortRemoved(topN7)
checkLimitRemoved(topN7)
checkPushedInfo(topN7,
"PushedAggregates: [SUM(SALARY)]",
"PushedGroupByExpressions: [DEPT]",
"PushedFilters: []",
"PushedTopN: ORDER BY [SUM(SALARY) ASC NULLS FIRST] LIMIT 1")
checkAnswer(topN7, Seq(Row(6, 12000.00)))
}
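One detail worth calling out in the paging cases above: the DataFrames ask for `offset(1).limit(1)`, yet every `PushedTopN` string says `LIMIT 2`. When an offset is pushed together with a Top-N, the pushed Top-N limit has to cover the offset rows as well, which is why each `pagingN` assertion pairs `PushedOffset: OFFSET 1` with a Top-N limit of 2 while the plain `topN` cases show `LIMIT 1`.

```scala
// Illustrative arithmetic only (not an API call): the pushed Top-N limit for the
// paging cases equals offset + limit, matching "PushedTopN: ... LIMIT 2" above.
val offset = 1
val limit = 1
val pushedTopNLimit = offset + limit
assert(pushedTopNLimit == 2)
```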

test("scan with filter push-down") {
val df = spark.table("h2.test.people").filter($"id" > 1)
checkFiltersRemoved(df)