From b698defbb640f11e2a9b9368f1be4e7c97d104f8 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Fri, 21 Jun 2024 18:30:03 +0200 Subject: [PATCH 01/16] initial working version --- .../apache/spark/sql/sources/filters.scala | 12 ++ .../sql/execution/DataSourceScanExec.scala | 2 +- .../datasources/DataSourceStrategy.scala | 60 +++++--- .../datasources/DataSourceUtils.scala | 21 +-- .../execution/datasources/FileFormat.scala | 6 - .../datasources/FileSourceStrategy.scala | 6 +- .../PruneFileSourcePartitions.scala | 3 +- .../datasources/v2/FileScanBuilder.scala | 15 +- .../datasources/v2/PushDownUtils.scala | 5 +- .../spark/sql/FileBasedDataSourceSuite.scala | 50 +++--- ...CollatedFilterPushDownToReadersSuite.scala | 143 ++++++++++++++++++ .../datasources/DataSourceStrategySuite.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 3 +- 13 files changed, 243 insertions(+), 85 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index a52bca106605..2c6b2dadd81f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -352,6 +352,18 @@ case class StringContains(attribute: String, value: String) extends Filter { Array(toV2Column(attribute), LiteralValue(UTF8String.fromString(value), StringType))) } +/** + * A. + * @param filter a. + * @param fullyTranslated a. + */ +@Evolving +case class TranslatedFilter(filter: Filter, fullyTranslated: Boolean) { + def withFilter(newFilter: Filter): TranslatedFilter = { + copy(filter = newFilter) + } +} + /** * A filter that always evaluates to `true`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 2ebbb9664f67..220742338b05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -413,7 +413,7 @@ trait FileSourceScanLike extends DataSourceScanExec { scalarSubqueryReplaced.filterNot(_.references.exists { case FileSourceConstantMetadataAttribute(_) => true case _ => false - }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown).map(_.filter)) } // This field may execute subquery expressions and should not be accessed during planning. 
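The TranslatedFilter wrapper added to filters.scala above carries a fullyTranslated flag so that a predicate which is only widened for pushdown (for example, a comparison on a collated column replaced by AlwaysTrue()) is still re-evaluated by Spark after the scan; the DataSourceStrategy changes that follow propagate that flag through And/Or. A minimal sketch of that propagation, using simplified stand-in filter types rather than the real org.apache.spark.sql.sources classes (all names below are illustrative):

object TranslatedFilterSketch {
  // Simplified stand-ins for org.apache.spark.sql.sources.Filter (illustrative only).
  sealed trait SimpleFilter
  case object AlwaysTrue extends SimpleFilter
  case class EqualTo(attribute: String, value: Any) extends SimpleFilter
  case class And(left: SimpleFilter, right: SimpleFilter) extends SimpleFilter
  case class Or(left: SimpleFilter, right: SimpleFilter) extends SimpleFilter

  // Mirrors the TranslatedFilter case class added to filters.scala above.
  case class TranslatedFilter(filter: SimpleFilter, fullyTranslated: Boolean) {
    def withFilter(newFilter: SimpleFilter): TranslatedFilter = copy(filter = newFilter)
  }

  // A conjunction or disjunction is fully translated only if both sides are;
  // this mirrors the And/Or cases of translateFilterWithMapping in the next hunks.
  def and(l: TranslatedFilter, r: TranslatedFilter): TranslatedFilter =
    TranslatedFilter(And(l.filter, r.filter), l.fullyTranslated && r.fullyTranslated)

  def or(l: TranslatedFilter, r: TranslatedFilter): TranslatedFilter =
    TranslatedFilter(Or(l.filter, r.filter), l.fullyTranslated && r.fullyTranslated)

  def main(args: Array[String]): Unit = {
    // "c1 = 'aaa' OR c0 = 'aaa'" where c1 is UTF8_LCASE collated: the collated
    // side is widened to AlwaysTrue and flagged as partial, so the disjunction
    // can still be pushed, but Spark keeps the original predicate in its plan.
    val collatedSide = TranslatedFilter(AlwaysTrue, fullyTranslated = false)
    val binarySide = TranslatedFilter(EqualTo("c0", "aaa"), fullyTranslated = true)
    val pushed = or(collatedSide, binarySide)
    assert(!pushed.fullyTranslated)
    println(pushed) // TranslatedFilter(Or(AlwaysTrue,EqualTo(c0,aaa)),false)
  }
}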
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7cda347ce581..6738ad68069b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -54,7 +54,7 @@ import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.{PartitioningUtils => CatalystPartitioningUtils} -import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} +import org.apache.spark.sql.util.{CaseInsensitiveStringMap} import org.apache.spark.unsafe.types.UTF8String /** @@ -573,10 +573,11 @@ object DataSourceStrategy /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * - * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + * @return a `Some[TranslatedFilter]` if the input [[Expression]] is at least partially + * convertible, otherwise a `None`. */ protected[sql] def translateFilter( - predicate: Expression, supportNestedPredicatePushdown: Boolean): Option[Filter] = { + predicate: Expression, supportNestedPredicatePushdown: Boolean): Option[TranslatedFilter] = { translateFilterWithMapping(predicate, None, supportNestedPredicatePushdown) } @@ -588,19 +589,24 @@ object DataSourceStrategy * translated [[Filter]]. The map is used for rebuilding * [[Expression]] from [[Filter]]. * @param nestedPredicatePushdownEnabled Whether nested predicate pushdown is enabled. - * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + * @return a `Some[TranslatedFilter]` if the input [[Expression]] is at least partially + * convertible, otherwise a `None`. 
*/ protected[sql] def translateFilterWithMapping( predicate: Expression, - translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]], + translatedFilterToExpr: Option[mutable.HashMap[Filter, Expression]], nestedPredicatePushdownEnabled: Boolean) - : Option[Filter] = { + : Option[TranslatedFilter] = { - def translateAndRecordLeafNodeFilter(filter: Expression): Option[Filter] = { + def translateAndRecordLeafNodeFilter(filter: Expression, canBeFullyTranslated: Boolean = true) + : Option[TranslatedFilter] = { val translatedFilter = - translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled)) + translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled)) match { + case Some(f) => Some(TranslatedFilter(f, canBeFullyTranslated)) + case None => None + } if (translatedFilter.isDefined && translatedFilterToExpr.isDefined) { - translatedFilterToExpr.get(translatedFilter.get) = predicate + translatedFilterToExpr.get(translatedFilter.get.filter) = predicate } translatedFilter } @@ -621,7 +627,8 @@ object DataSourceStrategy left, translatedFilterToExpr, nestedPredicatePushdownEnabled) rightFilter <- translateFilterWithMapping( right, translatedFilterToExpr, nestedPredicatePushdownEnabled) - } yield sources.And(leftFilter, rightFilter) + } yield TranslatedFilter(sources.And(leftFilter.filter, rightFilter.filter), + leftFilter.fullyTranslated && rightFilter.fullyTranslated) case expressions.Or(left, right) => for { @@ -629,24 +636,28 @@ object DataSourceStrategy left, translatedFilterToExpr, nestedPredicatePushdownEnabled) rightFilter <- translateFilterWithMapping( right, translatedFilterToExpr, nestedPredicatePushdownEnabled) - } yield sources.Or(leftFilter, rightFilter) + } yield TranslatedFilter(sources.Or(leftFilter.filter, rightFilter.filter), + leftFilter.fullyTranslated && rightFilter.fullyTranslated) - case notNull @ expressions.IsNotNull(_: AttributeReference) => + case notNull @ expressions.IsNotNull(_: AttributeReference | _: GetStructField) => // Not null filters on attribute references can always be pushed, also for collated columns. translateAndRecordLeafNodeFilter(notNull) - case isNull @ expressions.IsNull(_: AttributeReference) => + case isNull @ expressions.IsNull(_: AttributeReference | _: GetStructField) => // Is null filters on attribute references can always be pushed, also for collated columns. translateAndRecordLeafNodeFilter(isNull) - case p if p.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) => - // The filter cannot be pushed and we widen it to be AlwaysTrue(). This is only valid if - // the result of the filter is not negated by a Not expression it is wrapped in. - translateAndRecordLeafNodeFilter(Literal.TrueLiteral) + case p if DataSourceUtils.hasNonUTF8BinaryCollation(p) => + // The filter cannot be pushed and we widen it to be AlwaysTrue() and set it + // as partially translated so it has to get evaluated by the engine as well. + // This is only valid if the result of the filter is not negated by a + // Not expression it is wrapped in. 
+ translateAndRecordLeafNodeFilter(Literal.TrueLiteral, canBeFullyTranslated = false) case expressions.Not(child) => translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) - .map(sources.Not) + .map(translatedFilter => + translatedFilter.withFilter(sources.Not(translatedFilter.filter))) case other => translateAndRecordLeafNodeFilter(other) @@ -693,21 +704,22 @@ object DataSourceStrategy // If a predicate is not in this map, it means it cannot be pushed down. val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) // SPARK-41636: we keep the order of the predicates to avoid CodeGenerator cache misses - val translatedMap: Map[Expression, Filter] = ListMap(predicates.flatMap { p => + val translatedMap: Map[Expression, TranslatedFilter] = ListMap(predicates.flatMap { p => translateFilter(p, supportNestedPredicatePushdown).map(f => p -> f) }: _*) - val pushedFilters: Seq[Filter] = translatedMap.values.toSeq + val pushedFilters: Seq[Filter] = translatedMap.values.map(_.filter).toSeq - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonconvertiblePredicates = predicates.filterNot(translatedMap.contains) + // Catalyst predicate expressions that cannot be fully converted to data source filters. + val nonconvertiblePredicates = predicates.filter(predicate => + !translatedMap.contains(predicate) || !translatedMap(predicate).fullyTranslated) // Data source filters that cannot be handled by `relation`. An unhandled filter means // the data source cannot guarantee the rows returned can pass the filter. // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet + val unhandledFilters = relation.unhandledFilters(pushedFilters.toArray).toSet val unhandledPredicates = translatedMap.filter { case (p, f) => - unhandledFilters.contains(f) + unhandledFilters.contains(f.filter) }.keys val handledFilters = pushedFilters.toSet -- unhandledFilters diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index b6aee77577a4..9f79fa6b3ae0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -28,7 +28,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.{SparkException, SparkUpgradeException} import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, ExpressionSet, GetStructField, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -281,22 +281,9 @@ object DataSourceUtils extends PredicateHelper { } /** - * Determines whether a filter should be pushed down to the data source or not. 
- * - * @param expression The filter expression to be evaluated. - * @param isCollationPushDownSupported Whether the data source supports collation push down. - * @return A boolean indicating whether the filter should be pushed down or not. + * Determines whether a filter references any columns with non-UTF8 binary collation. */ - def shouldPushFilter(expression: Expression, isCollationPushDownSupported: Boolean): Boolean = { - if (!expression.deterministic) return false - - isCollationPushDownSupported || !expression.exists { - case childExpression @ (_: Attribute | _: GetStructField) => - // don't push down filters for types with non-binary sortable collation - // as it could lead to incorrect results - SchemaUtils.hasNonUTF8BinaryCollation(childExpression.dataType) - - case _ => false - } + def hasNonUTF8BinaryCollation(expression: Expression): Boolean = { + expression.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 0785b0cbe9e2..36c59950fe20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -223,12 +223,6 @@ trait FileFormat { */ def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] = FileFormat.BASE_METADATA_EXTRACTORS - - /** - * Returns whether the file format supports filter push down - * for non utf8 binary collated columns. - */ - def supportsCollationPushDown: Boolean = false } object FileFormat { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index d31cb111924b..921872620a63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -160,11 +160,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { // - filters that need to be evaluated again after the scan val filterSet = ExpressionSet(filters) - val filtersToPush = filters.filter(f => - DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown)) - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filtersToPush, l.output) + filters.filter(_.deterministic), l.output) val partitionColumns = l.resolve( @@ -206,6 +203,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { DataSourceUtils.supportNestedPredicatePushdown(fsRelation) val pushedFilters = dataFilters .flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + .map(_.filter) logInfo(log"Pushed Filters: ${MDC(PUSHED_FILTERS, pushedFilters.mkString(","))}") // Predicates with both partition keys and attributes need to be evaluated after the scan. 
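The new DataSourceUtils.hasNonUTF8BinaryCollation helper above only inspects the data types of an expression's attribute references and delegates the per-type check to SchemaUtils.hasNonUTF8BinaryCollation. A rough, self-contained illustration of what such a type walk has to cover (strings nested inside structs, arrays and maps can also carry a non-default collation), written against simplified stand-in types, not Spark's real DataType hierarchy:

object CollationTypeWalkSketch {
  // Illustrative stand-ins for Spark's DataType tree, not the real classes;
  // the actual recursion lives in SchemaUtils.hasNonUTF8BinaryCollation.
  sealed trait DType
  case class StringT(collation: String) extends DType
  case object IntT extends DType
  case class StructT(fields: Seq[(String, DType)]) extends DType
  case class ArrayT(element: DType) extends DType
  case class MapT(key: DType, value: DType) extends DType

  // Only the default UTF8_BINARY collation is binary sortable; under any other
  // collation the reader's byte-wise comparisons are wrong, so a filter that
  // touches such a column cannot be pushed down as-is.
  def hasNonUTF8BinaryCollation(dt: DType): Boolean = dt match {
    case StringT(collation) => collation != "UTF8_BINARY"
    case StructT(fields) => fields.exists { case (_, t) => hasNonUTF8BinaryCollation(t) }
    case ArrayT(element) => hasNonUTF8BinaryCollation(element)
    case MapT(key, value) => hasNonUTF8BinaryCollation(key) || hasNonUTF8BinaryCollation(value)
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    // A struct<f1: string collate UTF8_LCASE> reference is flagged, so a
    // predicate over it is widened to AlwaysTrue() instead of being pushed.
    assert(hasNonUTF8BinaryCollation(StructT(Seq("f1" -> StringT("UTF8_LCASE")))))
    assert(!hasNonUTF8BinaryCollation(ArrayT(StringT("UTF8_BINARY"))))
    assert(!hasNonUTF8BinaryCollation(IntT))
  }
}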
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index b0431d1df398..1dffea4e1bc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -63,8 +63,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { _)) if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty => val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => !SubqueryExpression.hasSubquery(f) && - DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown)), + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), logicalRelation.output) val (partitionKeyFilters, _) = DataSourceUtils .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 7cd2779f86f9..342ec79c2077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -70,21 +70,20 @@ abstract class FileScanBuilder( } override def pushFilters(filters: Seq[Expression]): Seq[Expression] = { - val (filtersToPush, filtersToRemain) = filters.partition( - f => DataSourceUtils.shouldPushFilter(f, supportsCollationPushDown)) + val (deterministicFilters, nonDeterminsticFilters) = filters.partition(_.deterministic) val (partitionFilters, dataFilters) = - DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filtersToPush) + DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters) this.partitionFilters = partitionFilters this.dataFilters = dataFilters val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] for (filterExpr <- dataFilters) { val translated = DataSourceStrategy.translateFilter(filterExpr, true) if (translated.nonEmpty) { - translatedFilters += translated.get + translatedFilters += translated.get.filter } } pushedDataFilters = pushDataFilters(translatedFilters.toArray) - dataFilters ++ filtersToRemain + dataFilters ++ nonDeterminsticFilters } override def pushedFilters: Array[Predicate] = pushedDataFilters.map(_.toV2) @@ -96,12 +95,6 @@ abstract class FileScanBuilder( */ protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter] - /** - * Returns whether the file scan builder supports filter pushdown - * for non utf8 binary collated columns. 
- */ - protected def supportsCollationPushDown: Boolean = false - private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 34a1adcb6e09..dca35fe21c4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -58,7 +58,10 @@ object PushDownUtils { if (translated.isEmpty) { untranslatableExprs += filterExpr } else { - translatedFilters += translated.get + translatedFilters += translated.get.filter + if (!translated.get.fullyTranslated) { + untranslatableExprs += filterExpr + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 5fbe88a09e7c..a72de62f93f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -1265,7 +1265,7 @@ class FileBasedDataSourceSuite extends QueryTest df.write.format(format).save(path.getAbsolutePath) // filter and expected result - val filters = Seq( + val filterTypes = Seq( ("==", Seq(Row("aaa"), Row("AAA"))), ("!=", Seq(Row("bbb"))), ("<", Seq()), @@ -1273,28 +1273,44 @@ class FileBasedDataSourceSuite extends QueryTest (">", Seq(Row("bbb"))), (">=", Seq(Row("aaa"), Row("AAA"), Row("bbb")))) - filters.foreach { filter => - val readback = spark.read - .format(format) - .load(path.getAbsolutePath) - .where(s"c1 ${filter._1} collate('aaa', $collation)") - .where(s"str ${filter._1} struct(collate('aaa', $collation))") - .where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)") - .where(s"arr ${filter._1} array(collate('aaa', $collation))") - .where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))") - .where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))") - .select("c1") - - val explain = readback.queryExecution.explainString( - ExplainMode.fromString("extended")) - assert(explain.contains("PushedFilters: []")) - checkAnswer(readback, filter._2) + filterTypes.foreach { filterType => + Seq( + s"c1 $filterType collate('aaa', $collation)", + s"str $filterType struct(collate('aaa', $collation))", + s"namedstr.f1.f2 $filterType collate('aaa', $collation)", + s"arr $filterType array(collate('aaa', $collation))", + s"map_keys(map1) $filterType array(collate('aaa', $collation))", + s"map_values(map2) $filterType array(collate('aaa', $collation))", + ).foreach { filterString => + + val readback = spark.read + .format(format) + .load(path.getAbsolutePath) + .where(filterString) + .select("c1") + + val pus = getPushedFilters(readback) + getPushedFilters(readback).foreach { filter => + assert(filter === "AlwaysTrue()" || filter.startsWith("IsNotNull")) + } + checkAnswer(readback, filterType._2) + } } } } } } } + + def getPushedFilters(df: DataFrame): Set[String] = { + val explain = df.queryExecution.explainString(ExplainMode.fromString("extended")) + assert(explain.contains("PushedFilters:")) + + // Regular expression to extract text inside the brackets + val pattern = "PushedFilters: \\[(.*?)]".r + + pattern.findFirstMatchIn(explain).get.group(1).split(", ").toSet + } } object TestingUDT 
{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala new file mode 100644 index 000000000000..2b77246a0679 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.collation + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.execution.ExplainMode +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class CollatedFilterPushDownToReadersSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + private val tblName = "tbl" + private val nonCollatedCol = "c0" + private val collatedCol = "c1" + private val collatedStructCol = "c2" + private val collatedStructNestedCol = "f1" + private val collatedStructFieldAccess = s"$collatedStructCol.$collatedStructNestedCol" + private val collatedArrayCol = "c3" + private val collatedMapCol = "c4" + + private val lcaseCollation = "'UTF8_LCASE'" + private val dataSources = Seq("parquet") + + def testV1AndV2PushDown( + filterString: String, + expectedPushedFilters: Seq[String], + expectedRowCount: Int): Unit = { + def testPushDown(dataSource: String, useV1: Boolean): Unit = { + val v1Source = if (useV1) dataSource else "" + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1Source) { + withTestTable(dataSource) { + val df = sql(s"SELECT * FROM $tblName WHERE $filterString") + val actualPushedFilters = getPushedFilters(df) + assert(actualPushedFilters.sorted === expectedPushedFilters.sorted) + assert(df.count() === expectedRowCount) + } + } + } + + dataSources.foreach { source => + testPushDown(source, useV1 = true) + testPushDown(source, useV1 = false) + } + } + + def withTestTable(dataSource: String)(fn: => Unit): Unit = { + withTable(tblName) { + sql(s""" + |CREATE TABLE $tblName USING $dataSource AS + |SELECT + | c as $nonCollatedCol, + | COLLATE(c, $lcaseCollation) as $collatedCol, + | named_struct('$collatedStructNestedCol', + | COLLATE(c, $lcaseCollation)) as $collatedStructCol, + | array(COLLATE(c, $lcaseCollation)) as $collatedArrayCol, + | map(COLLATE(c, $lcaseCollation), 1) as $collatedMapCol + |FROM VALUES ('aaa'), ('AAA'), ('bbb') + |as data(c) + |""".stripMargin) + + fn + } + } + + def getPushedFilters(df: DataFrame): Seq[String] = { + val explain = df.queryExecution.explainString(ExplainMode.fromString("extended")) + + // Regular expression to extract text inside the brackets + val pattern = "PushedFilters: 
\\[(.*?)\\]".r + + pattern.findFirstMatchIn(explain) match { + case Some(m) => m.group(1).split(", ").toSeq + case None => Seq.empty + } + } + + test("asdf") { + testV1AndV2PushDown( + filterString = s"'aaa' COLLATE UNICODE = 'bbb' COLLATE UNICODE", + expectedPushedFilters = Seq.empty, + expectedRowCount = 0) + + testV1AndV2PushDown( + filterString = s"$collatedCol = 'aaa'", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"$collatedCol = 'aaa' OR $nonCollatedCol = 'aaa'", + expectedPushedFilters = Seq(s"Or(AlwaysTrue(),EqualTo($nonCollatedCol,aaa))"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"$collatedCol != 'aaa'", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedRowCount = 1) + + testV1AndV2PushDown( + filterString = s"NOT($collatedCol == 'aaa')", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedRowCount = 1) + + testV1AndV2PushDown( + filterString = s"$collatedStructFieldAccess = 'aaa'", + expectedPushedFilters = Seq( + "AlwaysTrue()", s"IsNotNull($collatedStructFieldAccess)"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"$collatedStructFieldAccess = 'aaa'", + expectedPushedFilters = Seq( + "AlwaysTrue()", s"IsNotNull($collatedStructFieldAccess)"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"$collatedArrayCol = array(collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedArrayCol)"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"map_keys($collatedMapCol) != array(collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedMapCol)"), + expectedRowCount = 1) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 834225baf070..2dbdd992be57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -321,7 +321,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { */ def testTranslateFilter(catalystFilter: Expression, result: Option[sources.Filter]): Unit = { assertResult(result) { - DataSourceStrategy.translateFilter(catalystFilter, true) + DataSourceStrategy.translateFilter(catalystFilter, true).map(_.filter) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 795e9f46a8d1..ffa4d480f103 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -2345,7 +2345,8 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, scan: ParquetScan, _, _, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true)).toArray + val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true) + .map(_.filter)).toArray 
val pushedFilters = scan.pushedFilters assert(pushedFilters.nonEmpty, "No filter is pushed down") val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) From 45b87ef0ad2384121c0d12ceebd8917130f0c2ab Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Fri, 21 Jun 2024 18:33:58 +0200 Subject: [PATCH 02/16] method rename --- .../sql/execution/datasources/DataSourceStrategy.scala | 7 +++---- .../spark/sql/execution/datasources/DataSourceUtils.scala | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6738ad68069b..227a0909113c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} +import org.apache.spark.sql.connector.expressions.{NullOrdering, SortDirection, SortValue, Expression => V2Expression, SortOrder => V2SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} @@ -53,8 +53,7 @@ import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.{PartitioningUtils => CatalystPartitioningUtils} -import org.apache.spark.sql.util.{CaseInsensitiveStringMap} +import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils, PartitioningUtils => CatalystPartitioningUtils} import org.apache.spark.unsafe.types.UTF8String /** @@ -647,7 +646,7 @@ object DataSourceStrategy // Is null filters on attribute references can always be pushed, also for collated columns. translateAndRecordLeafNodeFilter(isNull) - case p if DataSourceUtils.hasNonUTF8BinaryCollation(p) => + case p if DataSourceUtils.referencesNonUTF8BinaryCollation(p) => // The filter cannot be pushed and we widen it to be AlwaysTrue() and set it // as partially translated so it has to get evaluated by the engine as well. // This is only valid if the result of the filter is not negated by a diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 9f79fa6b3ae0..b9246b1bd2cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -283,7 +283,7 @@ object DataSourceUtils extends PredicateHelper { /** * Determines whether a filter references any columns with non-UTF8 binary collation. 
*/ - def hasNonUTF8BinaryCollation(expression: Expression): Boolean = { + def referencesNonUTF8BinaryCollation(expression: Expression): Boolean = { expression.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) } } From 0a2dbc2cd312cb46af1abab1b68e84fbca444db1 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 23 Jun 2024 00:24:41 +0200 Subject: [PATCH 03/16] fix import error --- .../spark/sql/execution/datasources/DataSourceStrategy.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 227a0909113c..ab74b125b92a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{GeneratedColumn, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{NullOrdering, SortDirection, SortValue, Expression => V2Expression, SortOrder => V2SortOrder} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} @@ -53,7 +53,7 @@ import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils, PartitioningUtils => CatalystPartitioningUtils} +import org.apache.spark.sql.util.{CaseInsensitiveStringMap, PartitioningUtils => CatalystPartitioningUtils} import org.apache.spark.unsafe.types.UTF8String /** From aabb1c180797f56a5824faf4e235e77939b25fa7 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 23 Jun 2024 00:31:41 +0200 Subject: [PATCH 04/16] add replace part to create table --- .../sql/collation/CollatedFilterPushDownToReadersSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala index 2b77246a0679..15fe6d96271b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala @@ -64,7 +64,7 @@ class CollatedFilterPushDownToReadersSuite extends QueryTest def withTestTable(dataSource: String)(fn: => Unit): Unit = { withTable(tblName) { sql(s""" - |CREATE TABLE $tblName USING $dataSource AS + |CREATE OR REPLACE TABLE $tblName USING $dataSource AS |SELECT | c as $nonCollatedCol, | COLLATE(c, $lcaseCollation) as $collatedCol, From 11b97c411a625091458ac98daf3c2a11797995b5 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Sun, 23 Jun 2024 03:57:51 +0200 Subject: [PATCH 
05/16] fix syntax error in sql --- .../spark/sql/FileBasedDataSourceSuite.scala | 71 +------------------ ...CollatedFilterPushDownToReadersSuite.scala | 2 +- 2 files changed, 2 insertions(+), 71 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index a72de62f93f9..229677d20813 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterTha import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt} import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.execution.{ExplainMode, FileSourceScanLike, SimpleMode} +import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.FilePartition import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan} @@ -1242,75 +1242,6 @@ class FileBasedDataSourceSuite extends QueryTest } } } - - test("disable filter pushdown for collated strings") { - Seq("parquet").foreach { format => - Seq(format, "").foreach { conf => - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> conf) { - withTempPath { path => - val collation = "'UTF8_LCASE'" - val df = sql( - s"""SELECT - | COLLATE(c, $collation) as c1, - | struct(COLLATE(c, $collation)) as str, - | named_struct('f1', named_struct('f2', - | COLLATE(c, $collation), 'f3', 1)) as namedstr, - | array(COLLATE(c, $collation)) as arr, - | map(COLLATE(c, $collation), 1) as map1, - | map(1, COLLATE(c, $collation)) as map2 - |FROM VALUES ('aaa'), ('AAA'), ('bbb') - |as data(c) - |""".stripMargin) - - df.write.format(format).save(path.getAbsolutePath) - - // filter and expected result - val filterTypes = Seq( - ("==", Seq(Row("aaa"), Row("AAA"))), - ("!=", Seq(Row("bbb"))), - ("<", Seq()), - ("<=", Seq(Row("aaa"), Row("AAA"))), - (">", Seq(Row("bbb"))), - (">=", Seq(Row("aaa"), Row("AAA"), Row("bbb")))) - - filterTypes.foreach { filterType => - Seq( - s"c1 $filterType collate('aaa', $collation)", - s"str $filterType struct(collate('aaa', $collation))", - s"namedstr.f1.f2 $filterType collate('aaa', $collation)", - s"arr $filterType array(collate('aaa', $collation))", - s"map_keys(map1) $filterType array(collate('aaa', $collation))", - s"map_values(map2) $filterType array(collate('aaa', $collation))", - ).foreach { filterString => - - val readback = spark.read - .format(format) - .load(path.getAbsolutePath) - .where(filterString) - .select("c1") - - val pus = getPushedFilters(readback) - getPushedFilters(readback).foreach { filter => - assert(filter === "AlwaysTrue()" || filter.startsWith("IsNotNull")) - } - checkAnswer(readback, filterType._2) - } - } - } - } - } - } - } - - def getPushedFilters(df: DataFrame): Set[String] = { - val explain = df.queryExecution.explainString(ExplainMode.fromString("extended")) - assert(explain.contains("PushedFilters:")) - - // Regular expression to extract text inside the brackets - val pattern = "PushedFilters: \\[(.*?)]".r - - pattern.findFirstMatchIn(explain).get.group(1).split(", ").toSet - } } object TestingUDT { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala index 15fe6d96271b..2b77246a0679 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala @@ -64,7 +64,7 @@ class CollatedFilterPushDownToReadersSuite extends QueryTest def withTestTable(dataSource: String)(fn: => Unit): Unit = { withTable(tblName) { sql(s""" - |CREATE OR REPLACE TABLE $tblName USING $dataSource AS + |CREATE TABLE $tblName USING $dataSource AS |SELECT | c as $nonCollatedCol, | COLLATE(c, $lcaseCollation) as $collatedCol, From b8594088a3b1c437f084ad03d084046ba35317c5 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Mon, 24 Jun 2024 13:45:20 +0200 Subject: [PATCH 06/16] clean up tests a bit --- .../apache/spark/sql/sources/filters.scala | 23 ++-- .../datasources/DataSourceStrategy.scala | 3 +- ...CollatedFilterPushDownToReadersSuite.scala | 104 +++++++++--------- 3 files changed, 62 insertions(+), 68 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 2c6b2dadd81f..5db2cca4906f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -352,18 +352,6 @@ case class StringContains(attribute: String, value: String) extends Filter { Array(toV2Column(attribute), LiteralValue(UTF8String.fromString(value), StringType))) } -/** - * A. - * @param filter a. - * @param fullyTranslated a. - */ -@Evolving -case class TranslatedFilter(filter: Filter, fullyTranslated: Boolean) { - def withFilter(newFilter: Filter): TranslatedFilter = { - copy(filter = newFilter) - } -} - /** * A filter that always evaluates to `true`. * @@ -393,3 +381,14 @@ case class AlwaysFalse() extends Filter { @Evolving object AlwaysFalse extends AlwaysFalse { } + +/** + * Filter that can be translated partially. It can be pushed down but if it is not fully translated + * then the original expression needs to be evaluated as well. + */ +@Evolving +case class TranslatedFilter(filter: Filter, fullyTranslated: Boolean) { + def withFilter(newFilter: Filter): TranslatedFilter = { + copy(filter = newFilter) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index ab74b125b92a..7e6500b6a651 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -716,7 +716,8 @@ object DataSourceStrategy // Data source filters that cannot be handled by `relation`. An unhandled filter means // the data source cannot guarantee the rows returned can pass the filter. // As a result we must return it so Spark can plan an extra filter operator. 
- val unhandledFilters = relation.unhandledFilters(pushedFilters.toArray).toSet + val unhandledFilters = relation.unhandledFilters( + translatedMap.values.map(_.filter).toArray).toSet val unhandledPredicates = translatedMap.filter { case (p, f) => unhandledFilters.contains(f.filter) }.keys diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala index 2b77246a0679..e00b6c68031f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala @@ -44,13 +44,15 @@ class CollatedFilterPushDownToReadersSuite extends QueryTest expectedPushedFilters: Seq[String], expectedRowCount: Int): Unit = { def testPushDown(dataSource: String, useV1: Boolean): Unit = { - val v1Source = if (useV1) dataSource else "" - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1Source) { - withTestTable(dataSource) { - val df = sql(s"SELECT * FROM $tblName WHERE $filterString") - val actualPushedFilters = getPushedFilters(df) - assert(actualPushedFilters.sorted === expectedPushedFilters.sorted) - assert(df.count() === expectedRowCount) + test(s"collation push down filter: $filterString, source: $dataSource, isV1: $useV1") { + val v1Source = if (useV1) dataSource else "" + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1Source) { + withTestTable(dataSource) { + val df = sql(s"SELECT * FROM $tblName WHERE $filterString") + val actualPushedFilters = getPushedFilters(df) + assert(actualPushedFilters.sorted === expectedPushedFilters.sorted) + assert(df.count() === expectedRowCount) + } } } } @@ -92,52 +94,44 @@ class CollatedFilterPushDownToReadersSuite extends QueryTest } } - test("asdf") { - testV1AndV2PushDown( - filterString = s"'aaa' COLLATE UNICODE = 'bbb' COLLATE UNICODE", - expectedPushedFilters = Seq.empty, - expectedRowCount = 0) - - testV1AndV2PushDown( - filterString = s"$collatedCol = 'aaa'", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"$collatedCol = 'aaa' OR $nonCollatedCol = 'aaa'", - expectedPushedFilters = Seq(s"Or(AlwaysTrue(),EqualTo($nonCollatedCol,aaa))"), - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"$collatedCol != 'aaa'", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), - expectedRowCount = 1) - - testV1AndV2PushDown( - filterString = s"NOT($collatedCol == 'aaa')", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), - expectedRowCount = 1) - - testV1AndV2PushDown( - filterString = s"$collatedStructFieldAccess = 'aaa'", - expectedPushedFilters = Seq( - "AlwaysTrue()", s"IsNotNull($collatedStructFieldAccess)"), - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"$collatedStructFieldAccess = 'aaa'", - expectedPushedFilters = Seq( - "AlwaysTrue()", s"IsNotNull($collatedStructFieldAccess)"), - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"$collatedArrayCol = array(collate('aaa', $lcaseCollation))", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedArrayCol)"), - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"map_keys($collatedMapCol) != array(collate('aaa', $lcaseCollation))", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedMapCol)"), - expectedRowCount 
= 1) - } + testV1AndV2PushDown( + filterString = s"'aaa' COLLATE UNICODE = 'bbb' COLLATE UNICODE", + expectedPushedFilters = Seq.empty, + expectedRowCount = 0) + + testV1AndV2PushDown( + filterString = s"$collatedCol = 'aaa'", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"$collatedCol = 'aaa' OR $nonCollatedCol = 'aaa'", + expectedPushedFilters = Seq(s"Or(AlwaysTrue(),EqualTo($nonCollatedCol,aaa))"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"$collatedCol != 'aaa'", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedRowCount = 1) + + testV1AndV2PushDown( + filterString = s"NOT($collatedCol == 'aaa')", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedRowCount = 1) + + testV1AndV2PushDown( + filterString = s"$collatedStructFieldAccess = 'aaa'", + expectedPushedFilters = Seq( + "AlwaysTrue()", s"IsNotNull($collatedStructFieldAccess)"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"$collatedArrayCol = array(collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedArrayCol)"), + expectedRowCount = 2) + + testV1AndV2PushDown( + filterString = s"map_keys($collatedMapCol) != array(collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedMapCol)"), + expectedRowCount = 1) } From 410a26cf9b15327890120f9a2e3c47ac569b9b32 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 25 Jun 2024 15:37:25 +0200 Subject: [PATCH 07/16] remove translated filter --- .../apache/spark/sql/sources/filters.scala | 11 ---- .../sql/execution/DataSourceScanExec.scala | 2 +- .../datasources/DataSourceStrategy.scala | 62 ++++++++----------- .../datasources/FileSourceStrategy.scala | 1 - .../datasources/v2/FileScanBuilder.scala | 2 +- .../datasources/v2/PushDownUtils.scala | 5 +- .../datasources/DataSourceStrategySuite.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 3 +- 8 files changed, 30 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 5db2cca4906f..a52bca106605 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -381,14 +381,3 @@ case class AlwaysFalse() extends Filter { @Evolving object AlwaysFalse extends AlwaysFalse { } - -/** - * Filter that can be translated partially. It can be pushed down but if it is not fully translated - * then the original expression needs to be evaluated as well. 
- */ -@Evolving -case class TranslatedFilter(filter: Filter, fullyTranslated: Boolean) { - def withFilter(newFilter: Filter): TranslatedFilter = { - copy(filter = newFilter) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 220742338b05..2ebbb9664f67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -413,7 +413,7 @@ trait FileSourceScanLike extends DataSourceScanExec { scalarSubqueryReplaced.filterNot(_.references.exists { case FileSourceConstantMetadataAttribute(_) => true case _ => false - }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown).map(_.filter)) + }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) } // This field may execute subquery expressions and should not be accessed during planning. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7e6500b6a651..7cda347ce581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -53,7 +53,8 @@ import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.{CaseInsensitiveStringMap, PartitioningUtils => CatalystPartitioningUtils} +import org.apache.spark.sql.util.{PartitioningUtils => CatalystPartitioningUtils} +import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} import org.apache.spark.unsafe.types.UTF8String /** @@ -572,11 +573,10 @@ object DataSourceStrategy /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * - * @return a `Some[TranslatedFilter]` if the input [[Expression]] is at least partially - * convertible, otherwise a `None`. + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. */ protected[sql] def translateFilter( - predicate: Expression, supportNestedPredicatePushdown: Boolean): Option[TranslatedFilter] = { + predicate: Expression, supportNestedPredicatePushdown: Boolean): Option[Filter] = { translateFilterWithMapping(predicate, None, supportNestedPredicatePushdown) } @@ -588,24 +588,19 @@ object DataSourceStrategy * translated [[Filter]]. The map is used for rebuilding * [[Expression]] from [[Filter]]. * @param nestedPredicatePushdownEnabled Whether nested predicate pushdown is enabled. - * @return a `Some[TranslatedFilter]` if the input [[Expression]] is at least partially - * convertible, otherwise a `None`. + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. 
*/ protected[sql] def translateFilterWithMapping( predicate: Expression, - translatedFilterToExpr: Option[mutable.HashMap[Filter, Expression]], + translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]], nestedPredicatePushdownEnabled: Boolean) - : Option[TranslatedFilter] = { + : Option[Filter] = { - def translateAndRecordLeafNodeFilter(filter: Expression, canBeFullyTranslated: Boolean = true) - : Option[TranslatedFilter] = { + def translateAndRecordLeafNodeFilter(filter: Expression): Option[Filter] = { val translatedFilter = - translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled)) match { - case Some(f) => Some(TranslatedFilter(f, canBeFullyTranslated)) - case None => None - } + translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled)) if (translatedFilter.isDefined && translatedFilterToExpr.isDefined) { - translatedFilterToExpr.get(translatedFilter.get.filter) = predicate + translatedFilterToExpr.get(translatedFilter.get) = predicate } translatedFilter } @@ -626,8 +621,7 @@ object DataSourceStrategy left, translatedFilterToExpr, nestedPredicatePushdownEnabled) rightFilter <- translateFilterWithMapping( right, translatedFilterToExpr, nestedPredicatePushdownEnabled) - } yield TranslatedFilter(sources.And(leftFilter.filter, rightFilter.filter), - leftFilter.fullyTranslated && rightFilter.fullyTranslated) + } yield sources.And(leftFilter, rightFilter) case expressions.Or(left, right) => for { @@ -635,28 +629,24 @@ object DataSourceStrategy left, translatedFilterToExpr, nestedPredicatePushdownEnabled) rightFilter <- translateFilterWithMapping( right, translatedFilterToExpr, nestedPredicatePushdownEnabled) - } yield TranslatedFilter(sources.Or(leftFilter.filter, rightFilter.filter), - leftFilter.fullyTranslated && rightFilter.fullyTranslated) + } yield sources.Or(leftFilter, rightFilter) - case notNull @ expressions.IsNotNull(_: AttributeReference | _: GetStructField) => + case notNull @ expressions.IsNotNull(_: AttributeReference) => // Not null filters on attribute references can always be pushed, also for collated columns. translateAndRecordLeafNodeFilter(notNull) - case isNull @ expressions.IsNull(_: AttributeReference | _: GetStructField) => + case isNull @ expressions.IsNull(_: AttributeReference) => // Is null filters on attribute references can always be pushed, also for collated columns. translateAndRecordLeafNodeFilter(isNull) - case p if DataSourceUtils.referencesNonUTF8BinaryCollation(p) => - // The filter cannot be pushed and we widen it to be AlwaysTrue() and set it - // as partially translated so it has to get evaluated by the engine as well. - // This is only valid if the result of the filter is not negated by a - // Not expression it is wrapped in. - translateAndRecordLeafNodeFilter(Literal.TrueLiteral, canBeFullyTranslated = false) + case p if p.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) => + // The filter cannot be pushed and we widen it to be AlwaysTrue(). This is only valid if + // the result of the filter is not negated by a Not expression it is wrapped in. 
+ translateAndRecordLeafNodeFilter(Literal.TrueLiteral) case expressions.Not(child) => translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) - .map(translatedFilter => - translatedFilter.withFilter(sources.Not(translatedFilter.filter))) + .map(sources.Not) case other => translateAndRecordLeafNodeFilter(other) @@ -703,23 +693,21 @@ object DataSourceStrategy // If a predicate is not in this map, it means it cannot be pushed down. val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) // SPARK-41636: we keep the order of the predicates to avoid CodeGenerator cache misses - val translatedMap: Map[Expression, TranslatedFilter] = ListMap(predicates.flatMap { p => + val translatedMap: Map[Expression, Filter] = ListMap(predicates.flatMap { p => translateFilter(p, supportNestedPredicatePushdown).map(f => p -> f) }: _*) - val pushedFilters: Seq[Filter] = translatedMap.values.map(_.filter).toSeq + val pushedFilters: Seq[Filter] = translatedMap.values.toSeq - // Catalyst predicate expressions that cannot be fully converted to data source filters. - val nonconvertiblePredicates = predicates.filter(predicate => - !translatedMap.contains(predicate) || !translatedMap(predicate).fullyTranslated) + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonconvertiblePredicates = predicates.filterNot(translatedMap.contains) // Data source filters that cannot be handled by `relation`. An unhandled filter means // the data source cannot guarantee the rows returned can pass the filter. // As a result we must return it so Spark can plan an extra filter operator. - val unhandledFilters = relation.unhandledFilters( - translatedMap.values.map(_.filter).toArray).toSet + val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet val unhandledPredicates = translatedMap.filter { case (p, f) => - unhandledFilters.contains(f.filter) + unhandledFilters.contains(f) }.keys val handledFilters = pushedFilters.toSet -- unhandledFilters diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 921872620a63..27019ab047ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -203,7 +203,6 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { DataSourceUtils.supportNestedPredicatePushdown(fsRelation) val pushedFilters = dataFilters .flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) - .map(_.filter) logInfo(log"Pushed Filters: ${MDC(PUSHED_FILTERS, pushedFilters.mkString(","))}") // Predicates with both partition keys and attributes need to be evaluated after the scan. 
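The restored comment above ('the result of the filter is not negated by a Not expression it is wrapped in') is the safety argument behind the widening: a pushed filter may only over-approximate the original predicate, because Spark re-applies the exact predicate after the scan, whereas negating a widened AlwaysTrue() would turn it into an under-approximation and silently drop rows. A small sketch of that invariant, using a plain Scala row model and equalsIgnoreCase as a stand-in for UTF8_LCASE equality:

object WideningSafetySketch {
  // Rows are just values of the collated column here (illustrative model).
  val rows = Seq("aaa", "AAA", "bbb")

  // The exact predicate Spark re-evaluates after the scan; equalsIgnoreCase
  // stands in for UTF8_LCASE equality.
  val exact: String => Boolean = _.equalsIgnoreCase("aaa")

  // What the reader applies once the predicate is widened to AlwaysTrue().
  val widened: String => Boolean = _ => true

  def main(args: Array[String]): Unit = {
    // Safe: widening keeps a superset of the matching rows, and the exact
    // predicate applied on top still yields the correct result.
    val scanned = rows.filter(widened)
    assert(rows.filter(exact).toSet.subsetOf(scanned.toSet))
    assert(scanned.filter(exact) == rows.filter(exact))

    // Unsafe: if the widened filter sat under a Not, the reader would apply
    // Not(AlwaysTrue()) = AlwaysFalse() and drop every row before Spark could
    // re-check them, even though NOT(exact) should keep 'bbb'. That is why the
    // widening case is matched before the Not case in translateFilterWithMapping.
    val negatedWidened: String => Boolean = s => !widened(s)
    assert(rows.filter(negatedWidened).isEmpty)
    assert(rows.filterNot(exact).nonEmpty)
  }
}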
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 342ec79c2077..447a36fe622c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -79,7 +79,7 @@ abstract class FileScanBuilder( for (filterExpr <- dataFilters) { val translated = DataSourceStrategy.translateFilter(filterExpr, true) if (translated.nonEmpty) { - translatedFilters += translated.get.filter + translatedFilters += translated.get } } pushedDataFilters = pushDataFilters(translatedFilters.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index dca35fe21c4e..34a1adcb6e09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -58,10 +58,7 @@ object PushDownUtils { if (translated.isEmpty) { untranslatableExprs += filterExpr } else { - translatedFilters += translated.get.filter - if (!translated.get.fullyTranslated) { - untranslatableExprs += filterExpr - } + translatedFilters += translated.get } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 2dbdd992be57..834225baf070 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -321,7 +321,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { */ def testTranslateFilter(catalystFilter: Expression, result: Option[sources.Filter]): Unit = { assertResult(result) { - DataSourceStrategy.translateFilter(catalystFilter, true).map(_.filter) + DataSourceStrategy.translateFilter(catalystFilter, true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index ffa4d480f103..795e9f46a8d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -2345,8 +2345,7 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { case PhysicalOperation(_, filters, DataSourceV2ScanRelation(_, scan: ParquetScan, _, _, _)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true) - .map(_.filter)).toArray + val sourceFilters = filters.flatMap(DataSourceStrategy.translateFilter(_, true)).toArray val pushedFilters = scan.pushedFilters assert(pushedFilters.nonEmpty, "No filter is pushed down") val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) From b958fc0c78993886785c90690868efc6044e8810 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 26 Jun 2024 16:52:13 +0200 Subject: [PATCH 08/16] somewhat working with new ds1 filters --- 
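This patch moves to dedicated collation-aware DSv1 filters (CollatedEqualTo and friends in the diff below): translation keeps the comparison but records the collated DataType, and later phases decide whether a given source can use it. A rough standalone illustration of the wrapping rule, with toy types; the real code keys off SchemaUtils.hasNonUTF8BinaryCollation and Spark's DataType:

    object CollationAwareFilterSketch {
      sealed trait Filter
      final case class EqualTo(attribute: String, value: Any) extends Filter
      // Stand-in for the column's DataType; the real CollatedEqualTo carries e.dataType.
      final case class Collation(name: String) { def isUtf8Binary: Boolean = name == "UTF8_BINARY" }
      final case class CollatedEqualTo(attribute: String, value: Any, collation: Collation)
          extends Filter {
        // As in the CollatedFilter classes below: the plain filter this one corresponds to.
        def correspondingFilter: Filter = EqualTo(attribute, value)
      }

      // Mirrors collationAwareFilter: wrap only when the column uses a non-default collation.
      def collationAware(attribute: String, value: Any, collation: Collation): Filter =
        if (collation.isUtf8Binary) EqualTo(attribute, value)
        else CollatedEqualTo(attribute, value, collation)

      def main(args: Array[String]): Unit = {
        println(collationAware("c0", "aaa", Collation("UTF8_BINARY"))) // plain EqualTo(c0,aaa)
        println(collationAware("c1", "aaa", Collation("UTF8_LCASE")))  // CollatedEqualTo keeps the collation
      }
    }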
.../spark/sql/catalyst/StructFilters.scala | 2 + .../apache/spark/sql/sources/filters.scala | 62 +++++++- .../sql/execution/DataSourceScanExec.scala | 4 +- .../datasources/DataSourceStrategy.scala | 137 ++++++++++-------- .../datasources/DataSourceUtils.scala | 57 +++++++- ...CollatedFilterPushDownToReadersSuite.scala | 26 ++-- 6 files changed, 209 insertions(+), 79 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala index 4ac62b987b15..1b2013d87eed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala @@ -156,6 +156,8 @@ object StructFilters { Some(Literal(true, BooleanType)) case sources.AlwaysFalse() => Some(Literal(false, BooleanType)) + case _: sources.CollatedFilter => + None } translate(filter) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index a52bca106605..ae188002733a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate} -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -381,3 +381,63 @@ case class AlwaysFalse() extends Filter { @Evolving object AlwaysFalse extends AlwaysFalse { } + +// COLLATION AWARE FILTERS + +abstract class CollatedFilter() extends Filter { + def correspondingFilter: Filter + def dataType: DataType + + override def references: Array[String] = correspondingFilter.references + override def toV2: Predicate = correspondingFilter.toV2 +} + +case class CollatedEqualTo(attribute: String, value: Any, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = EqualTo(attribute, value) +} + +case class CollatedEqualNullSafe(attribute: String, value: Any, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = EqualNullSafe(attribute, value) +} + +case class CollatedGreaterThan(attribute: String, value: Any, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = GreaterThan(attribute, value) +} + +case class CollatedGreaterThanOrEqual(attribute: String, value: Any, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = GreaterThanOrEqual(attribute, value) +} + +case class CollatedLessThan(attribute: String, value: Any, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = LessThan(attribute, value) +} + +case class CollatedLessThanOrEqual(attribute: String, value: Any, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = LessThanOrEqual(attribute, value) +} + +case class 
CollatedIn(attribute: String, values: Array[Any], dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = In(attribute, values) +} + +case class CollatedStringStartsWith(attribute: String, value: String, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = StringStartsWith(attribute, value) +} + +case class CollatedStringEndsWith(attribute: String, value: String, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = StringEndsWith(attribute, value) +} + +case class CollatedStringContains(attribute: String, value: String, dataType: DataType) + extends CollatedFilter { + override def correspondingFilter: Filter = StringContains(attribute, value) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 2ebbb9664f67..33a2c5601c7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.ConstantColumnVector import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{BaseRelation, Filter} +import org.apache.spark.sql.sources.{AlwaysTrue, BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ArrayImplicits._ @@ -414,6 +414,8 @@ trait FileSourceScanLike extends DataSourceScanExec { case FileSourceConstantMetadataAttribute(_) => true case _ => false }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + .map(DataSourceUtils.removeColl) + .filterNot(filter => filter == AlwaysTrue()) } // This field may execute subquery expressions and should not be accessed during planning. 
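The pushedDownFilters change above strips collated predicates before handing filters to the file source; the rule it relies on is DataSourceUtils.removeColl, added further down in this patch. A standalone sketch of that pruning with toy filter types: a collated leaf widens to AlwaysTrue, an And keeps its surviving side, and an Or containing a collated branch is dropped entirely, since the source could otherwise filter out rows that only the collated branch would have matched.

    object RemoveCollSketch {
      sealed trait F
      case object AlwaysTrue extends F
      final case class Plain(attribute: String) extends F      // any non-collated leaf filter
      final case class Collated(attribute: String) extends F   // stands in for any CollatedFilter leaf
      final case class And(left: F, right: F) extends F
      final case class Or(left: F, right: F) extends F

      // Mirrors DataSourceUtils.removeColl as added later in this patch.
      def strip(filter: F): F = filter match {
        case And(l, r) => (strip(l), strip(r)) match {
          case (AlwaysTrue, right) => right
          case (left, AlwaysTrue)  => left
          case (left, right)       => And(left, right)
        }
        case Or(l, r) => (strip(l), strip(r)) match {
          case (AlwaysTrue, _) | (_, AlwaysTrue) => AlwaysTrue
          case (left, right)                     => Or(left, right)
        }
        case _: Collated => AlwaysTrue
        case other       => other
      }

      def main(args: Array[String]): Unit = {
        println(strip(And(Collated("c1"), Plain("c0"))))  // Plain(c0): the non-collated side survives
        println(strip(Or(Collated("c1"), Plain("c0"))))   // AlwaysTrue: nothing is pushed for the Or
      }
    }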
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7cda347ce581..12376146ea4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -500,42 +500,76 @@ object DataSourceStrategy } } + /** + * Creates a collation aware filter if the input data type is string with non-default collation + */ + private def collationAwareFilter(filter: sources.Filter, dataType: DataType): Option[Filter] = { + if (!SchemaUtils.hasNonUTF8BinaryCollation(dataType)) { + return Some(filter) + } + + filter match { + case sources.EqualTo(attribute, value) => + Some(CollatedEqualTo(attribute, value, dataType)) + case sources.EqualNullSafe(attribute, value) => + Some(CollatedEqualNullSafe(attribute, value, dataType)) + case sources.GreaterThan(attribute, value) => + Some(CollatedGreaterThan(attribute, value, dataType)) + case sources.GreaterThanOrEqual(attribute, value) => + Some(CollatedGreaterThanOrEqual(attribute, value, dataType)) + case sources.LessThan(attribute, value) => + Some(CollatedLessThan(attribute, value, dataType)) + case sources.LessThanOrEqual(attribute, value) => + Some(CollatedLessThanOrEqual(attribute, value, dataType)) + case sources.In(attribute, values) => + Some(CollatedIn(attribute, values, dataType)) + case sources.StringStartsWith(attribute, value) => + Some(CollatedStringStartsWith(attribute, value, dataType)) + case sources.StringEndsWith(attribute, value) => + Some(CollatedStringEndsWith(attribute, value, dataType)) + case sources.StringContains(attribute, value) => + Some(CollatedStringContains(attribute, value, dataType)) + case other => + Some(other) + } + } + private def translateLeafNodeFilter( predicate: Expression, pushableColumn: PushableColumnBase): Option[Filter] = predicate match { - case expressions.EqualTo(pushableColumn(name), Literal(v, t)) => - Some(sources.EqualTo(name, convertToScala(v, t))) - case expressions.EqualTo(Literal(v, t), pushableColumn(name)) => - Some(sources.EqualTo(name, convertToScala(v, t))) - - case expressions.EqualNullSafe(pushableColumn(name), Literal(v, t)) => - Some(sources.EqualNullSafe(name, convertToScala(v, t))) - case expressions.EqualNullSafe(Literal(v, t), pushableColumn(name)) => - Some(sources.EqualNullSafe(name, convertToScala(v, t))) - - case expressions.GreaterThan(pushableColumn(name), Literal(v, t)) => - Some(sources.GreaterThan(name, convertToScala(v, t))) - case expressions.GreaterThan(Literal(v, t), pushableColumn(name)) => - Some(sources.LessThan(name, convertToScala(v, t))) - - case expressions.LessThan(pushableColumn(name), Literal(v, t)) => - Some(sources.LessThan(name, convertToScala(v, t))) - case expressions.LessThan(Literal(v, t), pushableColumn(name)) => - Some(sources.GreaterThan(name, convertToScala(v, t))) - - case expressions.GreaterThanOrEqual(pushableColumn(name), Literal(v, t)) => - Some(sources.GreaterThanOrEqual(name, convertToScala(v, t))) - case expressions.GreaterThanOrEqual(Literal(v, t), pushableColumn(name)) => - Some(sources.LessThanOrEqual(name, convertToScala(v, t))) - - case expressions.LessThanOrEqual(pushableColumn(name), Literal(v, t)) => - Some(sources.LessThanOrEqual(name, convertToScala(v, t))) - case expressions.LessThanOrEqual(Literal(v, t), pushableColumn(name)) => - 
Some(sources.GreaterThanOrEqual(name, convertToScala(v, t))) + case expressions.EqualTo(e @ pushableColumn(name), Literal(v, t)) => + collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType) + case expressions.EqualTo(Literal(v, t), e @ pushableColumn(name)) => + collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType) + + case expressions.EqualNullSafe(e @ pushableColumn(name), Literal(v, t)) => + collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType) + case expressions.EqualNullSafe(Literal(v, t), e @ pushableColumn(name)) => + collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType) + + case expressions.GreaterThan(e @ pushableColumn(name), Literal(v, t)) => + collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType) + case expressions.GreaterThan(Literal(v, t), e @ pushableColumn(name)) => + collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType) + + case expressions.LessThan(e @ pushableColumn(name), Literal(v, t)) => + collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType) + case expressions.LessThan(Literal(v, t), e @ pushableColumn(name)) => + collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType) + + case expressions.GreaterThanOrEqual(e @ pushableColumn(name), Literal(v, t)) => + collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType) + case expressions.GreaterThanOrEqual(Literal(v, t), e @ pushableColumn(name)) => + collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType) + + case expressions.LessThanOrEqual(e @ pushableColumn(name), Literal(v, t)) => + collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType) + case expressions.LessThanOrEqual(Literal(v, t), e @ pushableColumn(name)) => + collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType) case expressions.InSet(e @ pushableColumn(name), set) => val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) - Some(sources.In(name, set.toArray.map(toScala))) + collationAwareFilter(sources.In(name, set.toArray.map(toScala)), e.dataType) // Because we only convert In to InSet in Optimizer when there are more than certain // items. 
So it is possible we still get an In expression here that needs to be pushed @@ -543,20 +577,20 @@ object DataSourceStrategy case expressions.In(e @ pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) => val hSet = list.map(_.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) - Some(sources.In(name, hSet.toArray.map(toScala))) + collationAwareFilter(sources.In(name, hSet.toArray.map(toScala)), e.dataType) case expressions.IsNull(pushableColumn(name)) => Some(sources.IsNull(name)) case expressions.IsNotNull(pushableColumn(name)) => Some(sources.IsNotNull(name)) - case expressions.StartsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(sources.StringStartsWith(name, v.toString)) + case expressions.StartsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => + collationAwareFilter(sources.StringStartsWith(name, v.toString), e.dataType) - case expressions.EndsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(sources.StringEndsWith(name, v.toString)) + case expressions.EndsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => + collationAwareFilter(sources.StringEndsWith(name, v.toString), e.dataType) - case expressions.Contains(pushableColumn(name), Literal(v: UTF8String, StringType)) => - Some(sources.StringContains(name, v.toString)) + case expressions.Contains(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => + collationAwareFilter(sources.StringContains(name, v.toString), e.dataType) case expressions.Literal(true, BooleanType) => Some(sources.AlwaysTrue) @@ -595,16 +629,6 @@ object DataSourceStrategy translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]], nestedPredicatePushdownEnabled: Boolean) : Option[Filter] = { - - def translateAndRecordLeafNodeFilter(filter: Expression): Option[Filter] = { - val translatedFilter = - translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled)) - if (translatedFilter.isDefined && translatedFilterToExpr.isDefined) { - translatedFilterToExpr.get(translatedFilter.get) = predicate - } - translatedFilter - } - predicate match { case expressions.And(left, right) => // See SPARK-12218 for detailed discussion @@ -631,25 +655,16 @@ object DataSourceStrategy right, translatedFilterToExpr, nestedPredicatePushdownEnabled) } yield sources.Or(leftFilter, rightFilter) - case notNull @ expressions.IsNotNull(_: AttributeReference) => - // Not null filters on attribute references can always be pushed, also for collated columns. - translateAndRecordLeafNodeFilter(notNull) - - case isNull @ expressions.IsNull(_: AttributeReference) => - // Is null filters on attribute references can always be pushed, also for collated columns. - translateAndRecordLeafNodeFilter(isNull) - - case p if p.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) => - // The filter cannot be pushed and we widen it to be AlwaysTrue(). This is only valid if - // the result of the filter is not negated by a Not expression it is wrapped in. 
- translateAndRecordLeafNodeFilter(Literal.TrueLiteral) - case expressions.Not(child) => translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) .map(sources.Not) case other => - translateAndRecordLeafNodeFilter(other) + val filter = translateLeafNodeFilter(other, PushableColumn(nestedPredicatePushdownEnabled)) + if (filter.isDefined && translatedFilterToExpr.isDefined) { + translatedFilterToExpr.get(filter.get) = predicate + } + filter } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index b9246b1bd2cd..5c392d7a684b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -26,7 +26,7 @@ import org.json4s.{Formats, NoTypeHints} import org.json4s.jackson.Serialization import org.apache.spark.{SparkException, SparkUpgradeException} -import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY} +import org.apache.spark.sql.{sources, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util.RebaseDateTime @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils @@ -280,10 +280,53 @@ object DataSourceUtils extends PredicateHelper { (ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters) } - /** - * Determines whether a filter references any columns with non-UTF8 binary collation. 
- */ - def referencesNonUTF8BinaryCollation(expression: Expression): Boolean = { - expression.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) + def containsFiltersWithCollation(filter: sources.Filter): Boolean = { + filter match { + case sources.And(left, right) => + containsFiltersWithCollation(left) || containsFiltersWithCollation(right) + case sources.Or(left, right) => + containsFiltersWithCollation(left) || containsFiltersWithCollation(right) + case sources.Not(child) => + containsFiltersWithCollation(child) + case _: sources.CollatedFilter => true + case _ => false + } + } + + def removeColl(filter: sources.Filter): sources.Filter = { + filter match { + case sources.And(left, right) => + val newLeft = removeColl(left) + val newRight = removeColl(right) + if (newLeft == sources.AlwaysTrue()) { + newRight + } else if (newRight == sources.AlwaysTrue()) { + newLeft + } else { + sources.And(newLeft, newRight) + } + + case sources.Or(left, right) => + val newLeft = removeColl(left) + if (newLeft == sources.AlwaysTrue()) { + return sources.AlwaysTrue() + } + val newRight = removeColl(right) + if (newRight == sources.AlwaysTrue()) { + sources.AlwaysTrue() + } else { + sources.Or(newLeft, newRight) + } + + case _: sources.IsNull | _: sources.IsNotNull => + filter + + case other => + if (containsFiltersWithCollation(other)) { + sources.AlwaysTrue() + } else { + other + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala index e00b6c68031f..6388904eaae9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala @@ -89,7 +89,10 @@ class CollatedFilterPushDownToReadersSuite extends QueryTest val pattern = "PushedFilters: \\[(.*?)\\]".r pattern.findFirstMatchIn(explain) match { - case Some(m) => m.group(1).split(", ").toSeq + case Some(m) => + m.group(1) + .split(", ") + .toSeq.filterNot(_ == "") case None => Seq.empty } } @@ -101,37 +104,42 @@ class CollatedFilterPushDownToReadersSuite extends QueryTest testV1AndV2PushDown( filterString = s"$collatedCol = 'aaa'", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedPushedFilters = Seq(s"IsNotNull($collatedCol)"), expectedRowCount = 2) + testV1AndV2PushDown( + filterString = s"$collatedCol = 'aaa' AND $nonCollatedCol = 'aaa'", + expectedPushedFilters = Seq( + s"IsNotNull($collatedCol)", s"IsNotNull($nonCollatedCol)", s"EqualTo($nonCollatedCol,aaa)"), + expectedRowCount = 1) + testV1AndV2PushDown( filterString = s"$collatedCol = 'aaa' OR $nonCollatedCol = 'aaa'", - expectedPushedFilters = Seq(s"Or(AlwaysTrue(),EqualTo($nonCollatedCol,aaa))"), + expectedPushedFilters = Seq.empty, expectedRowCount = 2) testV1AndV2PushDown( filterString = s"$collatedCol != 'aaa'", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedPushedFilters = Seq(s"IsNotNull($collatedCol)"), expectedRowCount = 1) testV1AndV2PushDown( filterString = s"NOT($collatedCol == 'aaa')", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedCol)"), + expectedPushedFilters = Seq(s"IsNotNull($collatedCol)"), expectedRowCount = 1) testV1AndV2PushDown( filterString = s"$collatedStructFieldAccess = 'aaa'", - expectedPushedFilters = Seq( - "AlwaysTrue()", 
s"IsNotNull($collatedStructFieldAccess)"), + expectedPushedFilters = Seq(s"IsNotNull($collatedStructFieldAccess)"), expectedRowCount = 2) testV1AndV2PushDown( filterString = s"$collatedArrayCol = array(collate('aaa', $lcaseCollation))", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedArrayCol)"), + expectedPushedFilters = Seq(s"IsNotNull($collatedArrayCol)"), expectedRowCount = 2) testV1AndV2PushDown( filterString = s"map_keys($collatedMapCol) != array(collate('aaa', $lcaseCollation))", - expectedPushedFilters = Seq("AlwaysTrue()", s"IsNotNull($collatedMapCol)"), + expectedPushedFilters = Seq(s"IsNotNull($collatedMapCol)"), expectedRowCount = 1) } From f3aeb0b1681dcb4d3efb91376b84195185064676 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 27 Jun 2024 13:57:38 +0200 Subject: [PATCH 09/16] clean up tests --- ...CollatedFilterPushDownToParquetSuite.scala | 237 ++++++++++++++++++ ...CollatedFilterPushDownToReadersSuite.scala | 145 ----------- 2 files changed, 237 insertions(+), 145 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala new file mode 100644 index 000000000000..ff7a33726cec --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.collation + +import org.apache.parquet.schema.MessageType + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.sources.{EqualTo, Filter, IsNotNull} +import org.apache.spark.sql.test.SharedSparkSession + +abstract class CollatedFilterPushDownToParquetSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + val dataSource = "parquet" + val nonCollatedCol = "c0" + val collatedCol = "c1" + val collatedStructCol = "c2" + val collatedStructNestedCol = "f1" + val collatedStructFieldAccess = s"$collatedStructCol.$collatedStructNestedCol" + val collatedArrayCol = "c3" + val collatedMapCol = "c4" + + val lcaseCollation = "'UTF8_LCASE'" + + def getPushedDownFilters(query: DataFrame): Seq[Filter] + + protected def createParquetFilters(schema: MessageType): ParquetFilters = + new ParquetFilters(schema, conf.parquetFilterPushDownDate, conf.parquetFilterPushDownTimestamp, + conf.parquetFilterPushDownDecimal, conf.parquetFilterPushDownStringPredicate, + conf.parquetFilterPushDownInFilterThreshold, + conf.caseSensitiveAnalysis, + RebaseSpec(LegacyBehaviorPolicy.CORRECTED)) + + def testPushDown( + filterString: String, + expectedPushedFilters: Seq[Filter], + expectedRowCount: Int): Unit = { + withTempPath { path => + val df = sql( + s""" + |SELECT + | c as $nonCollatedCol, + | COLLATE(c, $lcaseCollation) as $collatedCol, + | named_struct('$collatedStructNestedCol', + | COLLATE(c, $lcaseCollation)) as $collatedStructCol, + | array(COLLATE(c, $lcaseCollation)) as $collatedArrayCol, + | map(COLLATE(c, $lcaseCollation), 1) as $collatedMapCol + |FROM VALUES ('aaa'), ('AAA'), ('bbb') + |as data(c) + |""".stripMargin) + + df.write.format(dataSource).save(path.getAbsolutePath) + + val query = spark.read.format(dataSource).load(path.getAbsolutePath) + .filter(filterString) + + val actualPushedFilters = getPushedDownFilters(query) + assert(actualPushedFilters.toSet === expectedPushedFilters.toSet) + assert(query.count() === expectedRowCount) + } + } + + test("do not push down anything for literal comparison") { + testPushDown( + filterString = s"'aaa' COLLATE UNICODE = 'bbb' COLLATE UNICODE", + expectedPushedFilters = Seq.empty, + expectedRowCount = 0) + } + + test("push down null check for collated column") { + testPushDown( + filterString = s"$collatedCol = 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("push down null check for non-equality check") { + testPushDown( + filterString = s"$collatedCol != 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 1) + } + + test("push down null check for greater than check") { + testPushDown( + filterString = s"$collatedCol > 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 1) + } + + test("push 
down null check for gte check") { + testPushDown( + filterString = s"$collatedCol >= 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 3) + } + + test("push down null check for less than check") { + testPushDown( + filterString = s"$collatedCol < 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 0) + } + + test("push down null check for lte check") { + testPushDown( + filterString = s"$collatedCol <= 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("push down null check for equality for non-collated column in AND") { + testPushDown( + filterString = s"$collatedCol = 'aaa' AND $nonCollatedCol = 'aaa'", + expectedPushedFilters = + Seq(IsNotNull(collatedCol), IsNotNull(nonCollatedCol), EqualTo(nonCollatedCol, "aaa")), + expectedRowCount = 1) + } + + test("for OR do not push down anything") { + testPushDown( + filterString = s"$collatedCol = 'aaa' OR $nonCollatedCol = 'aaa'", + expectedPushedFilters = Seq.empty, + expectedRowCount = 2) + } + + test("mix or and and") { + testPushDown( + filterString = s"$collatedCol = 'aaa' AND ($nonCollatedCol = 'aaa' OR $collatedCol = 'aaa')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("negate check on collated column") { + testPushDown( + filterString = s"NOT($collatedCol == 'aaa')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 1) + } + + test("compare entire struct - parquet does not support null check on complex types") { + testPushDown( + filterString = s"$collatedStructCol = " + + s"named_struct('$collatedStructNestedCol', collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq.empty, + expectedRowCount = 2) + } + + test("inner struct field access") { + testPushDown( + filterString = s"$collatedStructFieldAccess = 'aaa'", + expectedPushedFilters = Seq(IsNotNull(collatedStructFieldAccess)), + expectedRowCount = 2) + } + + test("array - parquet does not support null check on complex types") { + testPushDown( + filterString = s"$collatedArrayCol = array(collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq.empty, + expectedRowCount = 2) + } + + test("map - parquet does not support null check on complex types") { + testPushDown( + filterString = s"map_keys($collatedMapCol) != array(collate('aaa', $lcaseCollation))", + expectedPushedFilters = Seq.empty, + expectedRowCount = 1) + } +} + +class V1CollatedFilterPushDownToParquetSuite extends CollatedFilterPushDownToParquetSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, dataSource) + + override def getPushedDownFilters(query: DataFrame): Seq[Filter] = { + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, + LogicalRelation(relation: HadoopFsRelation, _, _, _)) => + maybeRelation = Some(relation) + filters + }.flatten + + if (maybeAnalyzedPredicate.isEmpty) { + return Seq.empty + } + + val (_, selectedFilters, _) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate) + + val schema = new SparkToParquetSchemaConverter(conf).convert(query.schema) + val parquetFilters = createParquetFilters(schema) + parquetFilters.convertibleFilters(selectedFilters) + } +} + +class V2CollatedFilterPushDownToParquetSuite extends CollatedFilterPushDownToParquetSuite { + override protected def sparkConf: SparkConf = + 
super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") + + override def getPushedDownFilters(query: DataFrame): Seq[Filter] = { + query.queryExecution.optimizedPlan.collectFirst { + case PhysicalOperation(_, _, + DataSourceV2ScanRelation(_, scan: ParquetScan, _, _, _)) => + scan.pushedFilters.toSeq + }.getOrElse(Seq.empty) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala deleted file mode 100644 index 6388904eaae9..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToReadersSuite.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.collation - -import org.apache.spark.sql.{DataFrame, QueryTest} -import org.apache.spark.sql.execution.ExplainMode -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession - -class CollatedFilterPushDownToReadersSuite extends QueryTest - with SharedSparkSession - with AdaptiveSparkPlanHelper { - - private val tblName = "tbl" - private val nonCollatedCol = "c0" - private val collatedCol = "c1" - private val collatedStructCol = "c2" - private val collatedStructNestedCol = "f1" - private val collatedStructFieldAccess = s"$collatedStructCol.$collatedStructNestedCol" - private val collatedArrayCol = "c3" - private val collatedMapCol = "c4" - - private val lcaseCollation = "'UTF8_LCASE'" - private val dataSources = Seq("parquet") - - def testV1AndV2PushDown( - filterString: String, - expectedPushedFilters: Seq[String], - expectedRowCount: Int): Unit = { - def testPushDown(dataSource: String, useV1: Boolean): Unit = { - test(s"collation push down filter: $filterString, source: $dataSource, isV1: $useV1") { - val v1Source = if (useV1) dataSource else "" - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1Source) { - withTestTable(dataSource) { - val df = sql(s"SELECT * FROM $tblName WHERE $filterString") - val actualPushedFilters = getPushedFilters(df) - assert(actualPushedFilters.sorted === expectedPushedFilters.sorted) - assert(df.count() === expectedRowCount) - } - } - } - } - - dataSources.foreach { source => - testPushDown(source, useV1 = true) - testPushDown(source, useV1 = false) - } - } - - def withTestTable(dataSource: String)(fn: => Unit): Unit = { - withTable(tblName) { - sql(s""" - |CREATE TABLE $tblName USING $dataSource AS - |SELECT - | c as $nonCollatedCol, - | COLLATE(c, $lcaseCollation) as $collatedCol, - | named_struct('$collatedStructNestedCol', - | COLLATE(c, $lcaseCollation)) as $collatedStructCol, - | 
array(COLLATE(c, $lcaseCollation)) as $collatedArrayCol, - | map(COLLATE(c, $lcaseCollation), 1) as $collatedMapCol - |FROM VALUES ('aaa'), ('AAA'), ('bbb') - |as data(c) - |""".stripMargin) - - fn - } - } - - def getPushedFilters(df: DataFrame): Seq[String] = { - val explain = df.queryExecution.explainString(ExplainMode.fromString("extended")) - - // Regular expression to extract text inside the brackets - val pattern = "PushedFilters: \\[(.*?)\\]".r - - pattern.findFirstMatchIn(explain) match { - case Some(m) => - m.group(1) - .split(", ") - .toSeq.filterNot(_ == "") - case None => Seq.empty - } - } - - testV1AndV2PushDown( - filterString = s"'aaa' COLLATE UNICODE = 'bbb' COLLATE UNICODE", - expectedPushedFilters = Seq.empty, - expectedRowCount = 0) - - testV1AndV2PushDown( - filterString = s"$collatedCol = 'aaa'", - expectedPushedFilters = Seq(s"IsNotNull($collatedCol)"), - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"$collatedCol = 'aaa' AND $nonCollatedCol = 'aaa'", - expectedPushedFilters = Seq( - s"IsNotNull($collatedCol)", s"IsNotNull($nonCollatedCol)", s"EqualTo($nonCollatedCol,aaa)"), - expectedRowCount = 1) - - testV1AndV2PushDown( - filterString = s"$collatedCol = 'aaa' OR $nonCollatedCol = 'aaa'", - expectedPushedFilters = Seq.empty, - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"$collatedCol != 'aaa'", - expectedPushedFilters = Seq(s"IsNotNull($collatedCol)"), - expectedRowCount = 1) - - testV1AndV2PushDown( - filterString = s"NOT($collatedCol == 'aaa')", - expectedPushedFilters = Seq(s"IsNotNull($collatedCol)"), - expectedRowCount = 1) - - testV1AndV2PushDown( - filterString = s"$collatedStructFieldAccess = 'aaa'", - expectedPushedFilters = Seq(s"IsNotNull($collatedStructFieldAccess)"), - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"$collatedArrayCol = array(collate('aaa', $lcaseCollation))", - expectedPushedFilters = Seq(s"IsNotNull($collatedArrayCol)"), - expectedRowCount = 2) - - testV1AndV2PushDown( - filterString = s"map_keys($collatedMapCol) != array(collate('aaa', $lcaseCollation))", - expectedPushedFilters = Seq(s"IsNotNull($collatedMapCol)"), - expectedRowCount = 1) -} From a4458b6935ac5626f9130167e8a78ab65db37209 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 27 Jun 2024 14:03:13 +0200 Subject: [PATCH 10/16] add docstring for new filters --- .../apache/spark/sql/sources/filters.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index ae188002733a..65162cca005f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference} -import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate} +import org.apache.spark.sql.connector.expressions.filter.{Predicate, AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or} import org.apache.spark.sql.types.{DataType, StringType} 
import org.apache.spark.unsafe.types.UTF8String @@ -382,9 +382,12 @@ case class AlwaysFalse() extends Filter { object AlwaysFalse extends AlwaysFalse { } -// COLLATION AWARE FILTERS - +/** + * Base class for collation aware string filters. + */ abstract class CollatedFilter() extends Filter { + + /** The corresponding non-collation aware filter. */ def correspondingFilter: Filter def dataType: DataType @@ -392,51 +395,61 @@ abstract class CollatedFilter() extends Filter { override def toV2: Predicate = correspondingFilter.toV2 } +/** Collation aware equivalent of [[EqualTo]]. */ case class CollatedEqualTo(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = EqualTo(attribute, value) } +/** Collation aware equivalent of [[EqualNullSafe]]. */ case class CollatedEqualNullSafe(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = EqualNullSafe(attribute, value) } +/** Collation aware equivalent of [[GreaterThan]]. */ case class CollatedGreaterThan(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = GreaterThan(attribute, value) } +/** Collation aware equivalent of [[GreaterThanOrEqual]]. */ case class CollatedGreaterThanOrEqual(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = GreaterThanOrEqual(attribute, value) } +/** Collation aware equivalent of [[LessThan]]. */ case class CollatedLessThan(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = LessThan(attribute, value) } +/** Collation aware equivalent of [[LessThanOrEqual]]. */ case class CollatedLessThanOrEqual(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = LessThanOrEqual(attribute, value) } +/** Collation aware equivalent of [[In]]. */ case class CollatedIn(attribute: String, values: Array[Any], dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = In(attribute, values) } +/** Collation aware equivalent of [[StringStartsWith]]. */ case class CollatedStringStartsWith(attribute: String, value: String, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = StringStartsWith(attribute, value) } +/** Collation aware equivalent of [[StringEndsWith]]. */ case class CollatedStringEndsWith(attribute: String, value: String, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = StringEndsWith(attribute, value) } +/** Collation aware equivalent of [[StringContains]]. 
*/ case class CollatedStringContains(attribute: String, value: String, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = StringContains(attribute, value) From becd8dc358eff36fc6044035e1ff9dc87364ac25 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 27 Jun 2024 14:06:36 +0200 Subject: [PATCH 11/16] delete unused method --- .../sql/execution/DataSourceScanExec.scala | 4 +- .../datasources/DataSourceUtils.scala | 37 ------------------- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 33a2c5601c7b..2ebbb9664f67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.ConstantColumnVector import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{AlwaysTrue, BaseRelation, Filter} +import org.apache.spark.sql.sources.{BaseRelation, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ArrayImplicits._ @@ -414,8 +414,6 @@ trait FileSourceScanLike extends DataSourceScanExec { case FileSourceConstantMetadataAttribute(_) => true case _ => false }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) - .map(DataSourceUtils.removeColl) - .filterNot(filter => filter == AlwaysTrue()) } // This field may execute subquery expressions and should not be accessed during planning. 
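The hunk above reverts the removeColl call in pushedDownFilters, and the next hunk deletes removeColl itself, so collated filters now reach the sources unchanged. What remains is containsFiltersWithCollation in DataSourceUtils (plus the `case _: sources.CollatedFilter => None` added to StructFilters earlier in the series), which a reader without collation support can use to skip such filters. A toy version of that kind of reader-side check, not taken from the patch:

    object ReaderGuardSketch {
      sealed trait Filter
      final case class IsNotNull(attribute: String) extends Filter
      final case class Collated(attribute: String) extends Filter  // stands in for any CollatedFilter leaf
      final case class And(left: Filter, right: Filter) extends Filter

      // Same shape as DataSourceUtils.containsFiltersWithCollation: recurse through the
      // connectives and report whether a collated leaf is present anywhere.
      def containsCollated(filter: Filter): Boolean = filter match {
        case And(left, right) => containsCollated(left) || containsCollated(right)
        case _: Collated      => true
        case _                => false
      }

      // A reader with no collation support keeps only the filters it can evaluate safely.
      def usable(pushed: Seq[Filter]): Seq[Filter] = pushed.filterNot(containsCollated)

      def main(args: Array[String]): Unit = {
        // Keeps IsNotNull(c1); drops the conjunction that contains a collated comparison.
        println(usable(Seq(IsNotNull("c1"), And(Collated("c1"), IsNotNull("c0")))))
      }
    }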
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 5c392d7a684b..c80dc8307967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -292,41 +292,4 @@ object DataSourceUtils extends PredicateHelper { case _ => false } } - - def removeColl(filter: sources.Filter): sources.Filter = { - filter match { - case sources.And(left, right) => - val newLeft = removeColl(left) - val newRight = removeColl(right) - if (newLeft == sources.AlwaysTrue()) { - newRight - } else if (newRight == sources.AlwaysTrue()) { - newLeft - } else { - sources.And(newLeft, newRight) - } - - case sources.Or(left, right) => - val newLeft = removeColl(left) - if (newLeft == sources.AlwaysTrue()) { - return sources.AlwaysTrue() - } - val newRight = removeColl(right) - if (newRight == sources.AlwaysTrue()) { - sources.AlwaysTrue() - } else { - sources.Or(newLeft, newRight) - } - - case _: sources.IsNull | _: sources.IsNotNull => - filter - - case other => - if (containsFiltersWithCollation(other)) { - sources.AlwaysTrue() - } else { - other - } - } - } } From 72c53f0a85ecdab04ee706a4fa1f9706920522d8 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 27 Jun 2024 14:20:54 +0200 Subject: [PATCH 12/16] add more tests --- ...CollatedFilterPushDownToParquetSuite.scala | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala index ff7a33726cec..05d1e2948ffc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala @@ -134,6 +134,34 @@ abstract class CollatedFilterPushDownToParquetSuite extends QueryTest expectedRowCount = 2) } + test("push down null check for STARTSWITH") { + testPushDown( + filterString = s"STARTSWITH($collatedCol, 'a')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("push down null check for ENDSWITH") { + testPushDown( + filterString = s"ENDSWITH($collatedCol, 'a')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("push down null check for CONTAINS") { + testPushDown( + filterString = s"CONTAINS($collatedCol, 'a')", + expectedPushedFilters = Seq(IsNotNull(collatedCol)), + expectedRowCount = 2) + } + + test("no push down for IN") { + testPushDown( + filterString = s"$collatedCol IN ('aaa', 'bbb')", + expectedPushedFilters = Seq.empty, + expectedRowCount = 3) + } + test("push down null check for equality for non-collated column in AND") { testPushDown( filterString = s"$collatedCol = 'aaa' AND $nonCollatedCol = 'aaa'", From e4a29b90ca9a78b86ac3b1ec9f4342d92cd2b45f Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 27 Jun 2024 14:23:28 +0200 Subject: [PATCH 13/16] small refactor --- .../datasources/DataSourceStrategy.scala | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 12376146ea4b..5d2310c13070 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -503,34 +503,34 @@ object DataSourceStrategy /** * Creates a collation aware filter if the input data type is string with non-default collation */ - private def collationAwareFilter(filter: sources.Filter, dataType: DataType): Option[Filter] = { + private def collationAwareFilter(filter: sources.Filter, dataType: DataType): Filter = { if (!SchemaUtils.hasNonUTF8BinaryCollation(dataType)) { - return Some(filter) + return filter } filter match { case sources.EqualTo(attribute, value) => - Some(CollatedEqualTo(attribute, value, dataType)) + CollatedEqualTo(attribute, value, dataType) case sources.EqualNullSafe(attribute, value) => - Some(CollatedEqualNullSafe(attribute, value, dataType)) + CollatedEqualNullSafe(attribute, value, dataType) case sources.GreaterThan(attribute, value) => - Some(CollatedGreaterThan(attribute, value, dataType)) + CollatedGreaterThan(attribute, value, dataType) case sources.GreaterThanOrEqual(attribute, value) => - Some(CollatedGreaterThanOrEqual(attribute, value, dataType)) + CollatedGreaterThanOrEqual(attribute, value, dataType) case sources.LessThan(attribute, value) => - Some(CollatedLessThan(attribute, value, dataType)) + CollatedLessThan(attribute, value, dataType) case sources.LessThanOrEqual(attribute, value) => - Some(CollatedLessThanOrEqual(attribute, value, dataType)) + CollatedLessThanOrEqual(attribute, value, dataType) case sources.In(attribute, values) => - Some(CollatedIn(attribute, values, dataType)) + CollatedIn(attribute, values, dataType) case sources.StringStartsWith(attribute, value) => - Some(CollatedStringStartsWith(attribute, value, dataType)) + CollatedStringStartsWith(attribute, value, dataType) case sources.StringEndsWith(attribute, value) => - Some(CollatedStringEndsWith(attribute, value, dataType)) + CollatedStringEndsWith(attribute, value, dataType) case sources.StringContains(attribute, value) => - Some(CollatedStringContains(attribute, value, dataType)) + CollatedStringContains(attribute, value, dataType) case other => - Some(other) + other } } @@ -538,38 +538,38 @@ object DataSourceStrategy predicate: Expression, pushableColumn: PushableColumnBase): Option[Filter] = predicate match { case expressions.EqualTo(e @ pushableColumn(name), Literal(v, t)) => - collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType)) case expressions.EqualTo(Literal(v, t), e @ pushableColumn(name)) => - collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.EqualTo(name, convertToScala(v, t)), e.dataType)) case expressions.EqualNullSafe(e @ pushableColumn(name), Literal(v, t)) => - collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType)) case expressions.EqualNullSafe(Literal(v, t), e @ pushableColumn(name)) => - collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.EqualNullSafe(name, convertToScala(v, t)), e.dataType)) case expressions.GreaterThan(e @ pushableColumn(name), Literal(v, 
t)) => - collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType)) case expressions.GreaterThan(Literal(v, t), e @ pushableColumn(name)) => - collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType)) case expressions.LessThan(e @ pushableColumn(name), Literal(v, t)) => - collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.LessThan(name, convertToScala(v, t)), e.dataType)) case expressions.LessThan(Literal(v, t), e @ pushableColumn(name)) => - collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.GreaterThan(name, convertToScala(v, t)), e.dataType)) case expressions.GreaterThanOrEqual(e @ pushableColumn(name), Literal(v, t)) => - collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType)) case expressions.GreaterThanOrEqual(Literal(v, t), e @ pushableColumn(name)) => - collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType)) case expressions.LessThanOrEqual(e @ pushableColumn(name), Literal(v, t)) => - collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.LessThanOrEqual(name, convertToScala(v, t)), e.dataType)) case expressions.LessThanOrEqual(Literal(v, t), e @ pushableColumn(name)) => - collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType) + Some(collationAwareFilter(sources.GreaterThanOrEqual(name, convertToScala(v, t)), e.dataType)) case expressions.InSet(e @ pushableColumn(name), set) => val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) - collationAwareFilter(sources.In(name, set.toArray.map(toScala)), e.dataType) + Some(collationAwareFilter(sources.In(name, set.toArray.map(toScala)), e.dataType)) // Because we only convert In to InSet in Optimizer when there are more than certain // items. 
So it is possible we still get an In expression here that needs to be pushed @@ -577,20 +577,20 @@ object DataSourceStrategy case expressions.In(e @ pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) => val hSet = list.map(_.eval(EmptyRow)) val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) - collationAwareFilter(sources.In(name, hSet.toArray.map(toScala)), e.dataType) + Some(collationAwareFilter(sources.In(name, hSet.toArray.map(toScala)), e.dataType)) case expressions.IsNull(pushableColumn(name)) => Some(sources.IsNull(name)) case expressions.IsNotNull(pushableColumn(name)) => Some(sources.IsNotNull(name)) case expressions.StartsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => - collationAwareFilter(sources.StringStartsWith(name, v.toString), e.dataType) + Some(collationAwareFilter(sources.StringStartsWith(name, v.toString), e.dataType)) case expressions.EndsWith(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => - collationAwareFilter(sources.StringEndsWith(name, v.toString), e.dataType) + Some(collationAwareFilter(sources.StringEndsWith(name, v.toString), e.dataType)) case expressions.Contains(e @ pushableColumn(name), Literal(v: UTF8String, StringType)) => - collationAwareFilter(sources.StringContains(name, v.toString), e.dataType) + Some(collationAwareFilter(sources.StringContains(name, v.toString), e.dataType)) case expressions.Literal(true, BooleanType) => Some(sources.AlwaysTrue) From 31421215432ecc90117c2b031e33af353b44ce2f Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 27 Jun 2024 14:27:49 +0200 Subject: [PATCH 14/16] fix scalastyle --- .../src/main/scala/org/apache/spark/sql/sources/filters.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 65162cca005f..05e19dec50d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference} -import org.apache.spark.sql.connector.expressions.filter.{Predicate, AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String From 0d3b2f97db29748e1ec263989af700bbc79b03eb Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 27 Jun 2024 17:36:25 +0200 Subject: [PATCH 15/16] delete old tests --- .../datasources/DataSourceStrategySuite.scala | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index 834225baf070..9f0396ab60e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -357,36 +357,4 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { Some(sources.GreaterThanOrEqual("col", "value"))) testTranslateFilter(IsNotNull(colAttr), Some(sources.IsNotNull("col"))) } - - for (collation <- Seq("UTF8_LCASE", "UNICODE")) { - test(s"SPARK-48431: Filter pushdown on columns with $collation collation") { - val colAttr = $"col".string(collation) - - // No pushdown for all comparison based filters. - testTranslateFilter(EqualTo(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(LessThan(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(LessThan(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(LessThanOrEqual(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(GreaterThan(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - testTranslateFilter(GreaterThanOrEqual(colAttr, Literal("value")), Some(sources.AlwaysTrue)) - - // Allow pushdown of Is(Not)Null filter. - testTranslateFilter(IsNotNull(colAttr), Some(sources.IsNotNull("col"))) - testTranslateFilter(IsNull(colAttr), Some(sources.IsNull("col"))) - - // Top level filter splitting at And and Or. - testTranslateFilter(And(EqualTo(colAttr, Literal("value")), IsNotNull(colAttr)), - Some(sources.And(sources.AlwaysTrue, sources.IsNotNull("col")))) - testTranslateFilter(Or(EqualTo(colAttr, Literal("value")), IsNotNull(colAttr)), - Some(sources.Or(sources.AlwaysTrue, sources.IsNotNull("col")))) - - // Different cases involving Not. - testTranslateFilter(Not(EqualTo(colAttr, Literal("value"))), Some(sources.AlwaysTrue)) - testTranslateFilter(And(Not(EqualTo(colAttr, Literal("value"))), IsNotNull(colAttr)), - Some(sources.And(sources.AlwaysTrue, sources.IsNotNull("col")))) - // This filter would work, but we want to keep the translation logic simple. - testTranslateFilter(And(EqualTo(colAttr, Literal("value")), Not(IsNotNull(colAttr))), - Some(sources.And(sources.AlwaysTrue, sources.AlwaysTrue))) - } - } } From 92845559d001eccc92c0885a41faba5c47efa106 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Mon, 1 Jul 2024 11:25:39 +0200 Subject: [PATCH 16/16] add evolving to all apis --- .../scala/org/apache/spark/sql/sources/filters.scala | 11 +++++++++++ .../CollatedFilterPushDownToParquetSuite.scala | 6 +++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 05e19dec50d4..88f556130bfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -385,6 +385,7 @@ object AlwaysFalse extends AlwaysFalse { /** * Base class for collation aware string filters. */ +@Evolving abstract class CollatedFilter() extends Filter { /** The corresponding non-collation aware filter. */ @@ -396,60 +397,70 @@ abstract class CollatedFilter() extends Filter { } /** Collation aware equivalent of [[EqualTo]]. */ +@Evolving case class CollatedEqualTo(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = EqualTo(attribute, value) } /** Collation aware equivalent of [[EqualNullSafe]]. 
*/ +@Evolving case class CollatedEqualNullSafe(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = EqualNullSafe(attribute, value) } /** Collation aware equivalent of [[GreaterThan]]. */ +@Evolving case class CollatedGreaterThan(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = GreaterThan(attribute, value) } /** Collation aware equivalent of [[GreaterThanOrEqual]]. */ +@Evolving case class CollatedGreaterThanOrEqual(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = GreaterThanOrEqual(attribute, value) } /** Collation aware equivalent of [[LessThan]]. */ +@Evolving case class CollatedLessThan(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = LessThan(attribute, value) } /** Collation aware equivalent of [[LessThanOrEqual]]. */ +@Evolving case class CollatedLessThanOrEqual(attribute: String, value: Any, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = LessThanOrEqual(attribute, value) } /** Collation aware equivalent of [[In]]. */ +@Evolving case class CollatedIn(attribute: String, values: Array[Any], dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = In(attribute, values) } /** Collation aware equivalent of [[StringStartsWith]]. */ +@Evolving case class CollatedStringStartsWith(attribute: String, value: String, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = StringStartsWith(attribute, value) } /** Collation aware equivalent of [[StringEndsWith]]. */ +@Evolving case class CollatedStringEndsWith(attribute: String, value: String, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = StringEndsWith(attribute, value) } /** Collation aware equivalent of [[StringContains]]. 
*/ +@Evolving case class CollatedStringContains(attribute: String, value: String, dataType: DataType) extends CollatedFilter { override def correspondingFilter: Filter = StringContains(attribute, value) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala index 05d1e2948ffc..ab8e82162ce1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollatedFilterPushDownToParquetSuite.scala @@ -177,7 +177,7 @@ abstract class CollatedFilterPushDownToParquetSuite extends QueryTest expectedRowCount = 2) } - test("mix or and and") { + test("mix OR and AND") { testPushDown( filterString = s"$collatedCol = 'aaa' AND ($nonCollatedCol = 'aaa' OR $collatedCol = 'aaa')", expectedPushedFilters = Seq(IsNotNull(collatedCol)), @@ -221,7 +221,7 @@ abstract class CollatedFilterPushDownToParquetSuite extends QueryTest } } -class V1CollatedFilterPushDownToParquetSuite extends CollatedFilterPushDownToParquetSuite { +class CollatedFilterPushDownToParquetV1Suite extends CollatedFilterPushDownToParquetSuite { override protected def sparkConf: SparkConf = super .sparkConf @@ -249,7 +249,7 @@ class V1CollatedFilterPushDownToParquetSuite extends CollatedFilterPushDownToPar } } -class V2CollatedFilterPushDownToParquetSuite extends CollatedFilterPushDownToParquetSuite { +class CollatedFilterPushDownToParquetV2Suite extends CollatedFilterPushDownToParquetSuite { override protected def sparkConf: SparkConf = super .sparkConf
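Note on the translation hunks earlier in this series: every comparison case now returns Some(collationAwareFilter(...)) instead of dropping the predicate, so comparisons over collated string columns survive translation. The body of collationAwareFilter is not part of this excerpt; the Scala sketch below is only an inferred outline of what such a helper plausibly does, namely substitute the Collated* variant (which carries the column's DataType) when the column uses a non-default collation and pass the filter through unchanged otherwise. The isUTF8BinaryCollation check and the exact signature are assumptions; only the Collated* classes come from this patch.

```scala
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StringType}

object CollationAwareFilterSketch {
  // Inferred sketch only -- the real helper lives in DataSourceStrategy and is not
  // shown in this excerpt. Filters on collated string columns are rewritten into
  // their Collated* counterparts so that readers can see the collation (via dataType)
  // and decide whether they are able to evaluate the filter.
  def collationAwareFilter(filter: Filter, dataType: DataType): Filter = dataType match {
    // `isUTF8BinaryCollation` is assumed here; any check that identifies the default
    // binary collation would serve the same purpose.
    case st: StringType if !st.isUTF8BinaryCollation =>
      filter match {
        case EqualTo(attr, value)         => CollatedEqualTo(attr, value, st)
        case EqualNullSafe(attr, value)   => CollatedEqualNullSafe(attr, value, st)
        case GreaterThan(attr, value)     => CollatedGreaterThan(attr, value, st)
        case GreaterThanOrEqual(attr, v)  => CollatedGreaterThanOrEqual(attr, v, st)
        case LessThan(attr, value)        => CollatedLessThan(attr, value, st)
        case LessThanOrEqual(attr, value) => CollatedLessThanOrEqual(attr, value, st)
        case In(attr, values)             => CollatedIn(attr, values, st)
        case StringStartsWith(attr, v)    => CollatedStringStartsWith(attr, v, st)
        case StringEndsWith(attr, v)      => CollatedStringEndsWith(attr, v, st)
        case StringContains(attr, v)      => CollatedStringContains(attr, v, st)
        case other                        => other
      }
    case _ => filter
  }
}
```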
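On the consumer side, a source that receives pushed V1 filters now has to recognise the new @Evolving CollatedFilter hierarchy. The sketch below is illustrative only: canEvaluateCollation is a placeholder capability check, not an API from this patch. It shows one defensive way a reader could treat a pushed CollatedFilter, namely evaluate it only when it understands the collation and otherwise skip it, relying on Spark to re-apply the original predicate after the scan; correspondingFilter remains handy when only the filter's shape (for example, the referenced column) matters.

```scala
import org.apache.spark.sql.sources._

object CollatedFilterHandlingSketch {
  // Placeholder -- not part of the patch. A real reader would test whether it can
  // compare strings under the collation carried by the filter.
  private def canEvaluateCollation(filter: CollatedFilter): Boolean = false

  // Decide what to do with a filter pushed into the reader.
  def toEvaluableFilter(filter: Filter): Option[Filter] = filter match {
    case cf: CollatedFilter if canEvaluateCollation(cf) =>
      Some(cf)        // safe: the reader honours the collation's comparison semantics
    case _: CollatedFilter =>
      // Not safe to evaluate (for example, UTF8_LCASE on a binary-only reader).
      // Skipping the filter is always correct because Spark still evaluates the
      // original predicate on the rows returned by the scan.
      None
    case other =>
      Some(other)     // non-collated filters behave exactly as before
  }

  // When only the structure matters (for instance, which column a filter touches),
  // the plain equivalent is available through correspondingFilter.
  def referencedColumns(filter: Filter): Array[String] = filter match {
    case cf: CollatedFilter => cf.correspondingFilter.references
    case other              => other.references
  }
}
```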