From a556ee63fffeeb4df35784871151f0f5fba42b05 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 27 Sep 2017 20:50:15 +0800 Subject: [PATCH 1/3] improve documents and minor clean up --- .../java/org/apache/spark/sql/sources/v2/ReadSupport.java | 5 ++--- .../apache/spark/sql/sources/v2/ReadSupportWithSchema.java | 5 ++--- .../org/apache/spark/sql/sources/v2/reader/DataReader.java | 4 ++++ .../spark/sql/sources/v2/reader/DataSourceV2Reader.java | 2 +- .../org/apache/spark/sql/sources/v2/reader/ReadTask.java | 3 ++- .../main/scala/org/apache/spark/sql/DataFrameReader.scala | 5 ++--- .../sql/execution/datasources/v2/DataSourceV2ScanExec.scala | 2 +- 7 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index ab5254a688d5..ee489ad0f608 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,9 +30,8 @@ public interface ReadSupport { /** * Creates a {@link DataSourceV2Reader} to scan the data from this data source. * - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. */ DataSourceV2Reader createReader(DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index c13aeca2ef36..74e81a2c84d6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -39,9 +39,8 @@ public interface ReadSupportWithSchema { * physical schema of the underlying storage of this data source reader, e.g. * CSV files, JSON files, etc, while this reader may not read data with full * schema, as column pruning or other optimizations may happen. - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. */ DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index cfafc1a57679..95e091569b61 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -24,6 +24,10 @@ /** * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data * for a RDD partition. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source + * readers that mix in {@link SupportsScanUnsafeRow}. 
  */
 @InterfaceStability.Evolving
 public interface DataReader<T> extends Closeable {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
index fb4d5c0d7ae4..5989a4ac8440 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
@@ -30,7 +30,7 @@
  * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader(
  * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}.
  * It can mix in various query optimization interfaces to speed up the data scan. The actual scan
- * logic should be delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}.
+ * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}.
  *
  * There are mainly 3 kinds of query optimizations:
  *   1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java
index 7885bfcdd49e..01362df0978c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java
@@ -27,7 +27,8 @@
  * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}.
  *
  * Note that, the read task will be serialized and sent to executors, then the data reader will be
- * created on executors and do the actual reading.
+ * created on executors and do the actual reading. So {@link ReadTask} must be serializable and
+ * {@link DataReader} doesn't need to be.
*/ @InterfaceStability.Evolving public interface ReadTask extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 78b668c04fd5..17966eecfc05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -184,7 +184,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val dataSource = cls.newInstance() val options = new DataSourceV2Options(extraOptions.asJava) val reader = (cls.newInstance(), userSpecifiedSchema) match { @@ -194,8 +193,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { case (ds: ReadSupport, None) => ds.createReader(options) - case (_: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $dataSource.") + case (ds: ReadSupportWithSchema, None) => + throw new AnalysisException(s"A schema needs to be specified when using $ds.") case (ds: ReadSupport, Some(schema)) => val reader = ds.createReader(options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7999c0ceb574..9352217ae844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -74,7 +74,7 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) override def preferredLocations: Array[String] = rowReadTask.preferredLocations override def createReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema)) + new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema).resolveAndBind()) } } From e8e8feeeb54ae3e4f79157edb8b4f69886036cd0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 4 Oct 2017 09:37:25 +0800 Subject: [PATCH 2/3] push down operators to data source before planning --- .../SupportsPushDownCatalystFilters.java | 8 + .../v2/reader/SupportsPushDownFilters.java | 8 + .../spark/sql/execution/SparkOptimizer.scala | 4 +- .../v2/DataSourceReaderHolder.scala | 64 ++++++++ .../datasources/v2/DataSourceV2Relation.scala | 6 +- .../datasources/v2/DataSourceV2ScanExec.scala | 18 +-- .../datasources/v2/DataSourceV2Strategy.scala | 60 +------- .../v2/PushDownOperatorsToDataSource.scala | 140 ++++++++++++++++++ .../sources/v2/JavaAdvancedDataSourceV2.java | 5 + .../sql/sources/v2/DataSourceV2Suite.scala | 2 + 10 files changed, 240 insertions(+), 75 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 19d706238ec8..d6091774d75a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ 
b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -40,4 +40,12 @@ public interface SupportsPushDownCatalystFilters { * Pushes down filters, and returns unsupported filters. */ Expression[] pushCatalystFilters(Expression[] filters); + + /** + * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * It's possible that there is no filters in the query and + * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for + * this case. + */ + Expression[] pushedCatalystFilters(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d4b509e7080f..d6f297c01337 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.sources.Filter; /** @@ -35,4 +36,11 @@ public interface SupportsPushDownFilters { * Pushes down filters, and returns unsupported filters. */ Filter[] pushFilters(Filter[] filters); + + /** + * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} + * is never called, empty array should be returned for this case. + */ + Filter[] pushedFilters(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00ff4c8ac310..1c8e4050978d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions +import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -31,7 +32,8 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 000000000000..086d9b4f1de4 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder and defines equals/hashCode methods. + */ +trait DataSourceReaderHolder { + def fullOutput: Seq[AttributeReference] + def reader: DataSourceV2Reader + + override def equals(other: Any): Boolean = other match { + case other: DataSourceV2Relation => + val basicEquals = this.fullOutput == other.fullOutput && + this.reader.getClass == other.reader.getClass && + this.reader.readSchema() == other.reader.readSchema() + + val samePushedFilters = (this.reader, other.reader) match { + case (l: SupportsPushDownCatalystFilters, r: SupportsPushDownCatalystFilters) => + l.pushedCatalystFilters().toSeq == r.pushedCatalystFilters().toSeq + case (l: SupportsPushDownFilters, r: SupportsPushDownFilters) => + l.pushedFilters().toSeq == r.pushedFilters().toSeq + case _ => true + } + + basicEquals && samePushedFilters + + case _ => false + } + + override def hashCode(): Int = { + val state = Seq(fullOutput, reader.getClass, reader.readSchema()) + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSeq + case s: SupportsPushDownFilters => s.pushedFilters().toSeq + case _ => Nil + } + (state :+ filters).map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => + fullOutput.find(_.name == name).get + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 3c9b598fd07c..900330c57df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode { + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends 
LeafNode with DataSourceReaderHolder { override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 9352217ae844..32e8d168c146 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -29,20 +29,12 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.StructType +/** + * Physical plan node for scanning data from a data source. + */ case class DataSourceV2ScanExec( - fullOutput: Array[AttributeReference], - @transient reader: DataSourceV2Reader, - // TODO: these 3 parameters are only used to determine the equality of the scan node, however, - // the reader also have this information, and ideally we can just rely on the equality of the - // reader. The only concern is, the reader implementation is outside of Spark and we have no - // control. - readSchema: StructType, - @transient filters: ExpressionSet, - hashPartitionKeys: Seq[String]) extends LeafExecNode { - - def output: Seq[Attribute] = readSchema.map(_.name).map { name => - fullOutput.find(_.name == name).get - } + fullOutput: Seq[AttributeReference], + @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { override def references: AttributeSet = AttributeSet.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index b80f695b2a87..f2cda002245e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -29,64 +29,8 @@ import org.apache.spark.sql.sources.v2.reader._ object DataSourceV2Strategy extends Strategy { // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projects, filters, DataSourceV2Relation(output, reader)) => - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(filters.toArray) - - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. 
- val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => filters - } - - val attrMap = AttributeMap(output.zip(output)) - val projectSet = AttributeSet(projects.flatMap(_.references)) - val filterSet = AttributeSet(stayUpFilters.flatMap(_.references)) - - // Match original case of attributes. - // TODO: nested fields pruning - val requiredColumns = (projectSet ++ filterSet).toSeq.map(attrMap) - reader match { - case r: SupportsPushDownRequiredColumns => - r.pruneColumns(requiredColumns.toStructType) - case _ => - } - - val scan = DataSourceV2ScanExec( - output.toArray, - reader, - reader.readSchema(), - ExpressionSet(filters), - Nil) - - val filterCondition = stayUpFilters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) - - val withProject = if (projects == withFilter.output) { - withFilter - } else { - ProjectExec(projects, withFilter) - } - - withProject :: Nil + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala new file mode 100644 index 000000000000..0c1708131ae4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.v2.reader._ + +/** + * Pushes down various operators to the underlying data source for better performance. Operators are + * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you + * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the + * data source should execute FILTER before LIMIT. And required columns are calculated at the end, + * because when more operators are pushed down, we may need less columns at Spark side. 
+ */ +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { + override def apply(plan: LogicalPlan): LogicalPlan = { + // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may + // appear in many places for column pruning. + // TODO: Ideally column pruning should be implemented via a plan property that is propagated + // top-down, then we can simplify the logic here and only collect target operators. + val filterPushed = plan transformUp { + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + // Non-deterministic expressions are stateful and we must keep the input sequence unchanged + // to avoid changing the result. This means, we can't evaluate the filter conditions that + // are after the first non-deterministic condition ahead. Here we only try to push down + // deterministic conditions that are before the first non-deterministic condition. + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val stayUpFilters: Seq[Expression] = reader match { + case r: SupportsPushDownCatalystFilters => + r.pushCatalystFilters(candidates.toArray) + + case r: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet + val unhandledPredicates = translatedMap.filter { case (_, f) => + unhandledFilters.contains(f) + }.keys + + nonConvertiblePredicates ++ unhandledPredicates + + case _ => candidates + } + + val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And) + val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) + if (withFilter.output == fields) { + withFilter + } else { + Project(fields, withFilter) + } + } + + // TODO: add more push down rules. + + // TODO: nested fields pruning + def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { + plan match { + case Project(projectList, child) => + val required = projectList.filter(requiredByParent.contains).flatMap(_.references) + pushDownRequiredColumns(child, required) + + case Filter(condition, child) => + val required = requiredByParent ++ condition.references + pushDownRequiredColumns(child, required) + + case DataSourceV2Relation(fullOutput, reader) => reader match { + case r: SupportsPushDownRequiredColumns => + // Match original case of attributes. + val attrMap = AttributeMap(fullOutput.zip(fullOutput)) + val requiredColumns = requiredByParent.map(attrMap) + r.pruneColumns(requiredColumns.toStructType) + case _ => + } + + // TODO: there may be more operators can be used to calculate required columns, we can add + // more and more in the future. 
+ case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) + } + } + + pushDownRequiredColumns(filterPushed, filterPushed.output) + // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. + RemoveRedundantProject(filterPushed) + } + + /** + * Finds a Filter node(with an optional Project child) above data source relation. + */ + object FilterAndProject { + // returns the project list, the filter condition and the data source relation. + def unapply(plan: LogicalPlan) + : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + + case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + + case Filter(condition, Project(fields, r: DataSourceV2Relation)) + if fields.forall(_.deterministic) => + val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) + val substituted = condition.transform { + case a: Attribute => attributeMap.getOrElse(a, a) + } + Some((fields, substituted, r)) + + case _ => None + } + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 7aacf0346d2f..da2c13f70c52 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -54,6 +54,11 @@ public Filter[] pushFilters(Filter[] filters) { return new Filter[0]; } + @Override + public Filter[] pushedFilters() { + return filters; + } + @Override public List> createReadTasks() { List> res = new ArrayList<>(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 9ce93d7ae926..f238e565dc2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -129,6 +129,8 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { Array.empty } + override def pushedFilters(): Array[Filter] = filters + override def readSchema(): StructType = { requiredSchema } From 200cd204aa25c1571216047ba2da523fb14a612b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 12 Oct 2017 15:15:48 +0800 Subject: [PATCH 3/3] address comments --- .../v2/DataSourceReaderHolder.scala | 48 ++++++++++--------- .../datasources/v2/DataSourceV2Relation.scala | 2 + .../datasources/v2/DataSourceV2ScanExec.scala | 2 + 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala index 086d9b4f1de4..6093df26630c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -23,39 +23,43 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.sources.v2.reader._ /** - * A base class for data source reader holder and defines equals/hashCode methods. + * A base class for data source reader holder with customized equals/hashCode methods. 
*/ trait DataSourceReaderHolder { + + /** + * The full output of the data source reader, without column pruning. + */ def fullOutput: Seq[AttributeReference] - def reader: DataSourceV2Reader - override def equals(other: Any): Boolean = other match { - case other: DataSourceV2Relation => - val basicEquals = this.fullOutput == other.fullOutput && - this.reader.getClass == other.reader.getClass && - this.reader.readSchema() == other.reader.readSchema() + /** + * The held data source reader. + */ + def reader: DataSourceV2Reader - val samePushedFilters = (this.reader, other.reader) match { - case (l: SupportsPushDownCatalystFilters, r: SupportsPushDownCatalystFilters) => - l.pushedCatalystFilters().toSeq == r.pushedCatalystFilters().toSeq - case (l: SupportsPushDownFilters, r: SupportsPushDownFilters) => - l.pushedFilters().toSeq == r.pushedFilters().toSeq - case _ => true - } + /** + * The metadata of this data source reader that can be used for equality test. + */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + } - basicEquals && samePushedFilters + def canEqual(other: Any): Boolean + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } case _ => false } override def hashCode(): Int = { - val state = Seq(fullOutput, reader.getClass, reader.readSchema()) - val filters: Any = reader match { - case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSeq - case s: SupportsPushDownFilters => s.pushedFilters().toSeq - case _ => Nil - } - (state :+ filters).map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) } lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 900330c57df6..7eb99a645001 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -25,6 +25,8 @@ case class DataSourceV2Relation( fullOutput: Seq[AttributeReference], reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] + override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 32e8d168c146..addc12a3f090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -36,6 +36,8 @@ case class DataSourceV2ScanExec( fullOutput: Seq[AttributeReference], 
     @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder {
 
+  override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]
+
   override def references: AttributeSet = AttributeSet.empty
 
   override lazy val metrics = Map(
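To make the read-path contract documented in these patches concrete, here is a minimal sketch of a source built against the interfaces shown above, closely modeled on the AdvancedDataSourceV2 test source this series touches. The names (SimpleSource, SimpleReadTask) and the two integer columns are invented for illustration; the interfaces and method signatures are the ones from the diffs above, so treat this as a sketch rather than a reference implementation.

import java.util.{ArrayList => JArrayList, List => JList}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.types.StructType

class SimpleSource extends DataSourceV2 with ReadSupport {

  class Reader extends DataSourceV2Reader
    with SupportsPushDownRequiredColumns with SupportsPushDownFilters {

    // Full schema of the data; pruneColumns() may replace it with a narrower one.
    private var requiredSchema = new StructType().add("i", "int").add("j", "int")
    private var filters = Array.empty[Filter]

    override def readSchema(): StructType = requiredSchema

    override def pruneColumns(requiredSchema: StructType): Unit = {
      this.requiredSchema = requiredSchema
    }

    // Claim every filter (return none back to Spark) and remember what was pushed,
    // so pushedFilters() can report it for the reader-equality checks added here.
    override def pushFilters(filters: Array[Filter]): Array[Filter] = {
      this.filters = filters
      Array.empty
    }

    override def pushedFilters(): Array[Filter] = filters

    override def createReadTasks(): JList[ReadTask[Row]] = {
      // A real source would also use `filters` here to skip data at scan time.
      val tasks = new JArrayList[ReadTask[Row]]()
      tasks.add(new SimpleReadTask(0, 5, requiredSchema))
      tasks.add(new SimpleReadTask(5, 10, requiredSchema))
      tasks
    }
  }

  override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
}

// The ReadTask is serialized and sent to executors; the DataReader it creates there is not
// serialized, matching the ReadTask/DataReader contract documented in this series.
class SimpleReadTask(start: Int, end: Int, requiredSchema: StructType)
  extends ReadTask[Row] with DataReader[Row] {

  private var current = start - 1

  override def createReader(): DataReader[Row] = new SimpleReadTask(start, end, requiredSchema)

  override def next(): Boolean = {
    current += 1
    current < end
  }

  override def get(): Row = {
    val values = requiredSchema.map(_.name).map {
      case "i" => current
      case "j" => -current
    }
    Row.fromSeq(values)
  }

  override def close(): Unit = {}
}

With such a source, a query like spark.read.format(classOf[SimpleSource].getName).load().filter("i > 3").select("i") would exercise exactly the path changed here: DataSourceV2Relation wraps the reader, PushDownOperatorsToDataSource pushes the filter and required columns into it during optimization, and DataSourceV2Strategy then plans a bare DataSourceV2ScanExec.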