diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala index 91080b15727d..840fcae8c691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala @@ -116,10 +116,28 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { // For example, for a query `SELECT name.first FROM contacts WHERE name IS NOT NULL`, // we don't need to read nested fields of `name` struct other than `first` field. val (rootFields, optRootFields) = (projectionRootFields ++ filterRootFields) - .distinct.partition(_.contentAccessed) + .distinct.partition(!_.prunedIfAnyChildAccessed) optRootFields.filter { opt => - !rootFields.exists(_.field.name == opt.field.name) + !rootFields.exists { root => + root.field.name == opt.field.name && { + // Checking if current optional root field can be pruned. + // For each required root field, we merge it with the optional root field: + // 1. If this optional root field has nested fields and any nested field of it is used + // in the query, the merged field type must equal to the optional root field type. + // We can prune this optional root field. For example, for optional root field + // `struct>`, if its field + // `struct>` is used, we don't need to add this optional + // root field. + // 2. If this optional root field has no nested fields, the merged field type equals + // to the optional root field only if they are the same. If they are, we can prune + // this optional root field too. + val rootFieldType = StructType(Array(root.field)) + val optFieldType = StructType(Array(opt.field)) + val merged = optFieldType.merge(rootFieldType) + merged.sameType(optFieldType) + } + } } ++ rootFields } @@ -213,11 +231,11 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { // don't actually use any nested fields. These root field accesses might be excluded later // if there are any nested fields accesses in the query plan. case IsNotNull(SelectedField(field)) => - RootField(field, derivedFromAtt = false, contentAccessed = false) :: Nil + RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil case IsNull(SelectedField(field)) => - RootField(field, derivedFromAtt = false, contentAccessed = false) :: Nil + RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil case IsNotNull(_: Attribute) | IsNull(_: Attribute) => - expr.children.flatMap(getRootFields).map(_.copy(contentAccessed = false)) + expr.children.flatMap(getRootFields).map(_.copy(prunedIfAnyChildAccessed = true)) case _ => expr.children.flatMap(getRootFields) } @@ -271,9 +289,9 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { /** * This represents a "root" schema field (aka top-level, no-parent). `field` is the * `StructField` for field name and datatype. `derivedFromAtt` indicates whether it - * was derived from an attribute or had a proper child. `contentAccessed` means whether - * it was accessed with its content by the expressions refer it. + * was derived from an attribute or had a proper child. `prunedIfAnyChildAccessed` means + * whether this root field can be pruned if any of child field is used in the query. */ private case class RootField(field: StructField, derivedFromAtt: Boolean, - contentAccessed: Boolean = true) + prunedIfAnyChildAccessed: Boolean = false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala index 434c4414edeb..966190e12c6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.SchemaPruningTest import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.FileSourceScanExec +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -217,6 +218,41 @@ class ParquetSchemaPruningSuite Row("Y.") :: Nil) } + testSchemaPruning("select one complex field and having is null predicate on another " + + "complex field") { + val query = sql("select * from contacts") + .where("name.middle is not null") + .select( + "id", + "name.first", + "name.middle", + "name.last" + ) + .where("last = 'Jones'") + .select(count("id")).toDF() + checkScan(query, + "struct>") + checkAnswer(query, Row(0) :: Nil) + } + + testSchemaPruning("select one deep nested complex field and having is null predicate on " + + "another deep nested complex field") { + val query = sql("select * from contacts") + .where("employer.company.address is not null") + .selectExpr( + "id", + "name.first", + "name.middle", + "name.last", + "employer.id as employer_id" + ) + .where("employer_id = 0") + .select(count("id")).toDF() + checkScan(query, + "struct>>") + checkAnswer(query, Row(1) :: Nil) + } + private def testSchemaPruning(testName: String)(testThunk: => Unit) { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {