
Commit 5dc6b10

karenfeng authored and cloud-fan committed
[SPARK-34923][SQL] Metadata output should be empty for more plans
### What changes were proposed in this pull request?

Changes the metadata propagation framework. Previously, most `LogicalPlan`s propagated their children's `metadataOutput`. This did not make sense in cases where the `LogicalPlan` did not even propagate its children's `output`. This PR sets the metadata output to `Nil` for plans that do not propagate their children's `output`. Notably, `Project` and `View` no longer have metadata output.

### Why are the changes needed?

Previously, `SELECT m FROM (SELECT a FROM tb)` would output `m` if it were a metadata column. This did not make sense.

### Does this PR introduce any user-facing change?

Yes. Now, `SELECT m FROM (SELECT a FROM tb)` fails with an `AnalysisException`.

### How was this patch tested?

Added unit tests. Not every case is covered, as the cases are fairly extensive, but the new tests cover the major plans (and an existing test already covers `Join`).

Closes #32017 from karenfeng/spark-34923.

Authored-by: Karen Feng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 3b634f6)
Signed-off-by: Wenchen Fan <[email protected]>
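For a concrete picture of the user-facing change, here is a hedged sketch. It assumes a DSv2 table `tb` whose scan exposes a `_partition` metadata column, as the tests below do; the table and column names are illustrative only.

// Hedged sketch of the user-facing change. `tb` is a hypothetical DSv2 table
// whose scan exposes a `_partition` metadata column, as in the tests below.

// Still resolves: the scan relation itself carries the metadata column.
spark.sql("SELECT id, _partition FROM tb").show()

// After this commit, the inner Project no longer propagates metadataOutput,
// so the outer query can no longer see `_partition`:
spark.sql("SELECT _partition FROM (SELECT id FROM tb)")  // throws AnalysisException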
1 parent 96f981b commit 5dc6b10

3 files changed: +132 −1 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 4 additions & 1 deletion
@@ -33,7 +33,10 @@ abstract class LogicalPlan
   with QueryPlanConstraints
   with Logging {

-  /** Metadata fields that can be projected from this node */
+  /**
+   * Metadata fields that can be projected from this node.
+   * Should be overridden if the plan does not propagate its children's output.
+   */
   def metadataOutput: Seq[Attribute] = children.flatMap(_.metadataOutput)

   /** Returns true if this subtree has data from a streaming data source. */
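In effect, the base class keeps pass-through as the default and makes opting out a one-line override. A minimal self-contained sketch of that contract (simplified stand-ins, not the real Catalyst classes):

// Minimal sketch, not the real Catalyst hierarchy: a default that propagates
// children's metadata, and an opt-out for nodes that redefine `output`.
case class Attr(name: String)

abstract class PlanSketch {
  def children: Seq[PlanSketch]
  def output: Seq[Attr]
  // Default mirrors LogicalPlan.metadataOutput: pass metadata through.
  def metadataOutput: Seq[Attr] = children.flatMap(_.metadataOutput)
}

// A Project-like node controls its own output, so it must also stop
// hidden metadata columns from leaking past it.
case class ProjectSketch(projectList: Seq[Attr], child: PlanSketch) extends PlanSketch {
  def children: Seq[PlanSketch] = Seq(child)
  def output: Seq[Attr] = projectList
  override def metadataOutput: Seq[Attr] = Nil
}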

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 25 additions & 0 deletions
@@ -59,6 +59,7 @@ object Subquery {
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
     extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+  override def metadataOutput: Seq[Attribute] = Nil
   override def maxRows: Option[Long] = child.maxRows

   override lazy val resolved: Boolean = {
@@ -185,6 +186,8 @@ case class Intersect(
     leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
   }

+  override def metadataOutput: Seq[Attribute] = Nil
+
   override protected lazy val validConstraints: ExpressionSet =
     leftConstraints.union(rightConstraints)
@@ -205,6 +208,8 @@ case class Except(
   /** We don't use right.output because those rows get excluded from the set. */
   override def output: Seq[Attribute] = left.output

+  override def metadataOutput: Seq[Attribute] = Nil
+
   override protected lazy val validConstraints: ExpressionSet = leftConstraints
 }
@@ -268,6 +273,8 @@ case class Union(
     }
   }

+  override def metadataOutput: Seq[Attribute] = Nil
+
   override lazy val resolved: Boolean = {
     // allChildrenCompatible needs to be evaluated after childrenResolved
     def allChildrenCompatible: Boolean =
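The set operators (`Intersect`, `Except`, `Union`) all opt out: their output is derived from the left or first child's attributes, so per-child metadata columns have no single well-defined source. A hedged illustration, with hypothetical tables `t1` and `t2` whose scans expose `_partition`:

// Illustrative only: t1 and t2 are hypothetical DSv2 tables with identical
// data schemas whose scans expose a `_partition` metadata column.
val unioned = spark.table("t1").union(spark.table("t2"))

// After this commit Union reports metadataOutput = Nil, so a metadata column
// whose value would differ per child can no longer be referenced above it:
unioned.select("_partition")  // throws AnalysisException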
@@ -343,6 +350,17 @@ case class Join(
     }
   }

+  override def metadataOutput: Seq[Attribute] = {
+    joinType match {
+      case ExistenceJoin(_) =>
+        left.metadataOutput
+      case LeftExistence(_) =>
+        left.metadataOutput
+      case _ =>
+        children.flatMap(_.metadataOutput)
+    }
+  }
+
   override protected lazy val validConstraints: ExpressionSet = {
     joinType match {
       case _: InnerLike if condition.isDefined =>
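`Join` is the one plan that keeps a refined propagation rule: semi, anti, and existence joins emit only left-side rows and columns, so only the left child's metadata remains meaningful; all other join types still propagate from both children. A hedged usage sketch (table and column names are hypothetical):

// Hypothetical DSv2 tables whose scans expose a `_partition` metadata column.
val orders = spark.table("orders")
val customers = spark.table("customers")

// A LEFT SEMI join emits only left-side rows and columns, so
// Join.metadataOutput is left.metadataOutput: the left child's
// `_partition` remains addressable above the join.
orders.join(customers, Seq("customer_id"), "left_semi")
  .select("order_id", "_partition")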
@@ -419,6 +437,7 @@ case class InsertIntoDir(
     extends UnaryNode {

   override def output: Seq[Attribute] = Seq.empty
+  override def metadataOutput: Seq[Attribute] = Nil
   override lazy val resolved: Boolean = false
 }
@@ -449,6 +468,8 @@ case class View(

   override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))

+  override def metadataOutput: Seq[Attribute] = Nil
+
   override def simpleString(maxFields: Int): String = {
     s"View (${desc.identifier}, ${output.mkString("[", ",", "]")})"
   }
@@ -616,6 +637,7 @@ case class Aggregate(
   }

   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
+  override def metadataOutput: Seq[Attribute] = Nil
   override def maxRows: Option[Long] = {
     if (groupingExpressions.isEmpty) {
       Some(1L)
@@ -751,6 +773,8 @@ case class Expand(
   override lazy val references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))

+  override def metadataOutput: Seq[Attribute] = Nil
+
   override def producedAttributes: AttributeSet = AttributeSet(output diff child.output)

   // This operator can reuse attributes (for example making them null when doing a roll up) so
@@ -813,6 +837,7 @@ case class Pivot(
     }
     groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg
   }
+  override def metadataOutput: Seq[Attribute] = Nil
 }

 /**

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala

Lines changed: 103 additions & 0 deletions
@@ -2794,6 +2794,109 @@ class DataSourceV2SQLSuite
     }.getMessage
     assert(errMsg.contains(expectedError))
   }
+
+  test("SPARK-34923: do not propagate metadata columns through Project") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      assertThrows[AnalysisException] {
+        sql(s"SELECT index, _partition from (SELECT id, data FROM $t1)")
+      }
+      assertThrows[AnalysisException] {
+        spark.table(t1).select("id", "data").select("index", "_partition")
+      }
+    }
+  }
+
+  test("SPARK-34923: do not propagate metadata columns through View") {
+    val t1 = s"${catalogAndNamespace}table"
+    val view = "view"
+
+    withTable(t1) {
+      withTempView(view) {
+        sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+          "PARTITIONED BY (bucket(4, id), id)")
+        sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+        sql(s"CACHE TABLE $view AS SELECT * FROM $t1")
+        assertThrows[AnalysisException] {
+          sql(s"SELECT index, _partition FROM $view")
+        }
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through Filter") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 WHERE id > 1")
+      val dfQuery = spark.table(t1).where("id > 1").select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through Sort") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 ORDER BY id")
+      val dfQuery = spark.table(t1).orderBy("id").select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through RepartitionBy") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      val sqlQuery = spark.sql(
+        s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $t1")
+      val tbl = spark.table(t1)
+      val dfQuery = tbl.repartitionByRange(3, tbl.col("id"))
+        .select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through SubqueryAlias") {
+    val t1 = s"${catalogAndNamespace}table"
+    val sbq = "sbq"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      val sqlQuery = spark.sql(
+        s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $t1 as $sbq")
+      val dfQuery = spark.table(t1).as(sbq).select(
+        s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
 }
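For context, the `index` and `_partition` columns these tests select are metadata columns advertised by the suite's in-memory test catalog. A rough, simplified sketch of how a DSv2 connector declares such columns (not the test catalog's actual implementation):

import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, Table}
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

// Sketch only: the remaining Table members (name, schema, capabilities, ...)
// are omitted; the point is where metadata columns are declared.
abstract class PartitionedTableSketch extends Table with SupportsMetadataColumns {
  override def metadataColumns: Array[MetadataColumn] = Array(
    new MetadataColumn {
      override def name: String = "index"       // row position within its partition
      override def dataType: DataType = IntegerType
    },
    new MetadataColumn {
      override def name: String = "_partition"  // rendered partition values
      override def dataType: DataType = StringType
    })
}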
