diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7129c6984cf3f..cef0e06c357f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -34,7 +34,10 @@ abstract class LogicalPlan
   with QueryPlanConstraints
   with Logging {
 
-  /** Metadata fields that can be projected from this node */
+  /**
+   * Metadata fields that can be projected from this node.
+   * Should be overridden if the plan does not propagate its children's output.
+   */
   def metadataOutput: Seq[Attribute] = children.flatMap(_.metadataOutput)
 
   /** Returns true if this subtree has data from a streaming data source. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 9461dbf9f3d1b..49b09cc74c7cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -61,6 +61,7 @@ object Subquery {
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
   extends OrderPreservingUnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+  override def metadataOutput: Seq[Attribute] = Nil
   override def maxRows: Option[Long] = child.maxRows
 
   override lazy val resolved: Boolean = {
@@ -187,6 +188,8 @@ case class Intersect(
       leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
     }
 
+  override def metadataOutput: Seq[Attribute] = Nil
+
   override protected lazy val validConstraints: ExpressionSet =
     leftConstraints.union(rightConstraints)
 
@@ -207,6 +210,8 @@ case class Except(
   /** We don't use right.output because those rows get excluded from the set. */
   override def output: Seq[Attribute] = left.output
 
+  override def metadataOutput: Seq[Attribute] = Nil
+
   override protected lazy val validConstraints: ExpressionSet = leftConstraints
 }
 
@@ -270,6 +275,8 @@ case class Union(
     }
   }
 
+  override def metadataOutput: Seq[Attribute] = Nil
+
   override lazy val resolved: Boolean = {
     // allChildrenCompatible needs to be evaluated after childrenResolved
     def allChildrenCompatible: Boolean =
@@ -364,6 +371,17 @@ case class Join(
     }
   }
 
+  override def metadataOutput: Seq[Attribute] = {
+    joinType match {
+      case ExistenceJoin(_) =>
+        left.metadataOutput
+      case LeftExistence(_) =>
+        left.metadataOutput
+      case _ =>
+        children.flatMap(_.metadataOutput)
+    }
+  }
+
   override protected lazy val validConstraints: ExpressionSet = {
     joinType match {
       case _: InnerLike if condition.isDefined =>
@@ -440,6 +458,7 @@ case class InsertIntoDir(
   extends UnaryNode {
 
   override def output: Seq[Attribute] = Seq.empty
+  override def metadataOutput: Seq[Attribute] = Nil
   override lazy val resolved: Boolean = false
 }
 
@@ -466,6 +485,8 @@ case class View(
 
   override def output: Seq[Attribute] = child.output
 
+  override def metadataOutput: Seq[Attribute] = Nil
+
   override def simpleString(maxFields: Int): String = {
     s"View (${desc.identifier}, ${output.mkString("[", ",", "]")})"
   }
@@ -647,6 +668,7 @@ case class Aggregate(
   }
 
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
+  override def metadataOutput: Seq[Attribute] = Nil
   override def maxRows: Option[Long] = {
     if (groupingExpressions.isEmpty) {
       Some(1L)
@@ -782,6 +804,8 @@ case class Expand(
   override lazy val references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))
 
+  override def metadataOutput: Seq[Attribute] = Nil
+
   override def producedAttributes: AttributeSet = AttributeSet(output diff child.output)
 
   // This operator can reuse attributes (for example making them null when doing a roll up) so
@@ -818,6 +842,7 @@ case class Pivot(
     }
     groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg
   }
+  override def metadataOutput: Seq[Attribute] = Nil
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index c4abed32bf624..e0dcdc794b7c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -2690,6 +2690,109 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("SPARK-34923: do not propagate metadata columns through Project") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      assertThrows[AnalysisException] {
+        sql(s"SELECT index, _partition from (SELECT id, data FROM $t1)")
+      }
+      assertThrows[AnalysisException] {
+        spark.table(t1).select("id", "data").select("index", "_partition")
+      }
+    }
+  }
+
+  test("SPARK-34923: do not propagate metadata columns through View") {
+    val t1 = s"${catalogAndNamespace}table"
+    val view = "view"
+
+    withTable(t1) {
+      withTempView(view) {
+        sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+          "PARTITIONED BY (bucket(4, id), id)")
+        sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+        sql(s"CACHE TABLE $view AS SELECT * FROM $t1")
+        assertThrows[AnalysisException] {
+          sql(s"SELECT index, _partition FROM $view")
+        }
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through Filter") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 WHERE id > 1")
+      val dfQuery = spark.table(t1).where("id > 1").select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through Sort") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 ORDER BY id")
+      val dfQuery = spark.table(t1).orderBy("id").select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through RepartitionBy") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      val sqlQuery = spark.sql(
+        s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $t1")
+      val tbl = spark.table(t1)
+      val dfQuery = tbl.repartitionByRange(3, tbl.col("id"))
+        .select("id", "data", "index", "_partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
+  test("SPARK-34923: propagate metadata columns through SubqueryAlias") {
+    val t1 = s"${catalogAndNamespace}table"
+    val sbq = "sbq"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
+        "PARTITIONED BY (bucket(4, id), id)")
+      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
+
+      val sqlQuery = spark.sql(
+        s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $t1 as $sbq")
+      val dfQuery = spark.table(t1).as(sbq).select(
+        s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")
+
+      Seq(sqlQuery, dfQuery).foreach { query =>
+        checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
+      }
+    }
+  }
+
   private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = {
     val e = intercept[AnalysisException] {
       sql(s"$sqlCommand $sqlParams")