@@ -34,7 +34,10 @@ abstract class LogicalPlan
with QueryPlanConstraints
with Logging {

/** Metadata fields that can be projected from this node */
/**
* Metadata fields that can be projected from this node.
* Should be overridden if the plan does not propagate its children's output.
*/
def metadataOutput: Seq[Attribute] = children.flatMap(_.metadataOutput)

/** Returns true if this subtree has data from a streaming data source. */
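As an illustrative aside (not part of this diff): a minimal, self-contained sketch of the default-versus-override contract above, using hypothetical toy classes and plain strings in place of Spark's LogicalPlan and Attribute:

// Toy model of the metadataOutput contract above; TestPlan/TestLeaf/TestBarrier
// are hypothetical names and metadata columns are reduced to plain strings.
abstract class TestPlan {
  def children: Seq[TestPlan]
  // Default: propagate the children's metadata columns, as LogicalPlan does.
  def metadataOutput: Seq[String] = children.flatMap(_.metadataOutput)
}

case class TestLeaf(meta: Seq[String]) extends TestPlan {
  override def children: Seq[TestPlan] = Nil
  override def metadataOutput: Seq[String] = meta
}

// A node that defines its own output (like Project or Aggregate below)
// overrides the default and stops the propagation.
case class TestBarrier(child: TestPlan) extends TestPlan {
  override def children: Seq[TestPlan] = Seq(child)
  override def metadataOutput: Seq[String] = Nil
}

// TestLeaf(Seq("index", "_partition")).metadataOutput == Seq("index", "_partition")
// TestBarrier(TestLeaf(Seq("index", "_partition"))).metadataOutput == Nil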
@@ -61,6 +61,7 @@ object Subquery {
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
extends OrderPreservingUnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def metadataOutput: Seq[Attribute] = Nil
override def maxRows: Option[Long] = child.maxRows

override lazy val resolved: Boolean = {
@@ -187,6 +188,8 @@ case class Intersect(
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}

override def metadataOutput: Seq[Attribute] = Nil

override protected lazy val validConstraints: ExpressionSet =
leftConstraints.union(rightConstraints)

@@ -207,6 +210,8 @@ case class Except(
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output

override def metadataOutput: Seq[Attribute] = Nil

override protected lazy val validConstraints: ExpressionSet = leftConstraints
}

@@ -270,6 +275,8 @@ case class Union(
}
}

override def metadataOutput: Seq[Attribute] = Nil

override lazy val resolved: Boolean = {
// allChildrenCompatible needs to be evaluated after childrenResolved
def allChildrenCompatible: Boolean =
@@ -364,6 +371,17 @@ case class Join(
}
}

override def metadataOutput: Seq[Attribute] = {
joinType match {
case ExistenceJoin(_) =>
left.metadataOutput
case LeftExistence(_) =>
left.metadataOutput
case _ =>
children.flatMap(_.metadataOutput)
}
}

override protected lazy val validConstraints: ExpressionSet = {
joinType match {
case _: InnerLike if condition.isDefined =>
@@ -440,6 +458,7 @@ case class InsertIntoDir(
extends UnaryNode {

override def output: Seq[Attribute] = Seq.empty
override def metadataOutput: Seq[Attribute] = Nil
override lazy val resolved: Boolean = false
}

@@ -466,6 +485,8 @@ case class View(

override def output: Seq[Attribute] = child.output

override def metadataOutput: Seq[Attribute] = Nil

override def simpleString(maxFields: Int): String = {
s"View (${desc.identifier}, ${output.mkString("[", ",", "]")})"
}
@@ -647,6 +668,7 @@ case class Aggregate(
}

override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
override def metadataOutput: Seq[Attribute] = Nil
override def maxRows: Option[Long] = {
if (groupingExpressions.isEmpty) {
Some(1L)
@@ -782,6 +804,8 @@ case class Expand(
override lazy val references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))

override def metadataOutput: Seq[Attribute] = Nil

override def producedAttributes: AttributeSet = AttributeSet(output diff child.output)

// This operator can reuse attributes (for example making them null when doing a roll up) so
@@ -818,6 +842,7 @@ case class Pivot(
}
groupByExprsOpt.getOrElse(Seq.empty).map(_.toAttribute) ++ pivotAgg
}
override def metadataOutput: Seq[Attribute] = Nil
}

/**
@@ -2690,6 +2690,109 @@ class DataSourceV2SQLSuite
}
}

test("SPARK-34923: do not propagate metadata columns through Project") {
val t1 = s"${catalogAndNamespace}table"
withTable(t1) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
"PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")

assertThrows[AnalysisException] {
sql(s"SELECT index, _partition from (SELECT id, data FROM $t1)")
}
assertThrows[AnalysisException] {
spark.table(t1).select("id", "data").select("index", "_partition")
}
}
}
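
A hedged usage note, not a test from this PR: once the metadata columns are selected explicitly they become regular output attributes of the Project and survive later projections, so the restriction above only applies to columns that were never materialized. A sketch in the same style as the surrounding tests:

  // Hypothetical sketch, not part of this PR: metadata columns selected up
  // front become regular attributes and can be re-projected afterwards.
  test("sketch: materialize metadata columns before projecting") {
    val t1 = s"${catalogAndNamespace}table"
    withTable(t1) {
      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
        "PARTITIONED BY (bucket(4, id), id)")
      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")

      val df = spark.table(t1)
        .select("id", "data", "index", "_partition")
        .select("id", "_partition")
      // Expected values mirror the other tests in this suite.
      checkAnswer(df, Seq(Row(1, "3/1"), Row(2, "0/2"), Row(3, "1/3")))
    }
  }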

test("SPARK-34923: do not propagate metadata columns through View") {
val t1 = s"${catalogAndNamespace}table"
val view = "view"

withTable(t1) {
withTempView(view) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
"PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")
sql(s"CACHE TABLE $view AS SELECT * FROM $t1")
assertThrows[AnalysisException] {
sql(s"SELECT index, _partition FROM $view")
}
}
}
}

test("SPARK-34923: propagate metadata columns through Filter") {
val t1 = s"${catalogAndNamespace}table"
withTable(t1) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
"PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")

val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 WHERE id > 1")
val dfQuery = spark.table(t1).where("id > 1").select("id", "data", "index", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
}
}
}

test("SPARK-34923: propagate metadata columns through Sort") {
val t1 = s"${catalogAndNamespace}table"
withTable(t1) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
"PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")

val sqlQuery = spark.sql(s"SELECT id, data, index, _partition FROM $t1 ORDER BY id")
val dfQuery = spark.table(t1).orderBy("id").select("id", "data", "index", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
}
}
}

test("SPARK-34923: propagate metadata columns through RepartitionBy") {
val t1 = s"${catalogAndNamespace}table"
withTable(t1) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
"PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")

val sqlQuery = spark.sql(
s"SELECT /*+ REPARTITION_BY_RANGE(3, id) */ id, data, index, _partition FROM $t1")
val tbl = spark.table(t1)
val dfQuery = tbl.repartitionByRange(3, tbl.col("id"))
.select("id", "data", "index", "_partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
}
}
}

test("SPARK-34923: propagate metadata columns through SubqueryAlias") {
val t1 = s"${catalogAndNamespace}table"
val sbq = "sbq"
withTable(t1) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
"PARTITIONED BY (bucket(4, id), id)")
sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")

val sqlQuery = spark.sql(
s"SELECT $sbq.id, $sbq.data, $sbq.index, $sbq._partition FROM $t1 as $sbq")
val dfQuery = spark.table(t1).as(sbq).select(
s"$sbq.id", s"$sbq.data", s"$sbq.index", s"$sbq._partition")

Seq(sqlQuery, dfQuery).foreach { query =>
checkAnswer(query, Seq(Row(1, "a", 0, "3/1"), Row(2, "b", 0, "0/2"), Row(3, "c", 0, "1/3")))
}
}
}
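
Also not part of this diff, but since Union.metadataOutput is set to Nil above, a sketch of the mirror-image test (following the pattern of the Project test earlier) might look like:

  // Hypothetical sketch, not part of this PR: Union.metadataOutput is Nil, so
  // metadata columns should not resolve on top of a UNION.
  test("sketch: do not propagate metadata columns through Union") {
    val t1 = s"${catalogAndNamespace}table"
    withTable(t1) {
      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
        "PARTITIONED BY (bucket(4, id), id)")
      sql(s"INSERT INTO $t1 VALUES (1, 'a'), (2, 'b'), (3, 'c')")

      assertThrows[AnalysisException] {
        sql(s"SELECT index, _partition FROM " +
          s"(SELECT * FROM $t1 UNION ALL SELECT * FROM $t1)")
      }
    }
  }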

private def testNotSupportedV2Command(sqlCommand: String, sqlParams: String): Unit = {
val e = intercept[AnalysisException] {
sql(s"$sqlCommand $sqlParams")