initial working version
stefankandic committed Jun 25, 2024
commit b698defbb640f11e2a9b9368f1be4e7c97d104f8
@@ -352,6 +352,18 @@ case class StringContains(attribute: String, value: String) extends Filter {
Array(toV2Column(attribute), LiteralValue(UTF8String.fromString(value), StringType)))
}

/**
* The result of translating a Catalyst predicate into a data source [[Filter]].
* @param filter the translated data source filter.
* @param fullyTranslated whether the filter is equivalent to the original predicate; if false,
*                        the original predicate must still be evaluated by Spark after the scan.
*/
@Evolving
case class TranslatedFilter(filter: Filter, fullyTranslated: Boolean) {
def withFilter(newFilter: Filter): TranslatedFilter = {
copy(filter = newFilter)
}
}

/**
* A filter that always evaluates to `true`.
*
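As a quick illustration of the new wrapper (a minimal sketch, not part of the change; the column name and values are invented):

import org.apache.spark.sql.sources.{AlwaysTrue, EqualTo, Not, TranslatedFilter}

// A filter that fully captures the original predicate.
val exact = TranslatedFilter(EqualTo("c1", "aaa"), fullyTranslated = true)

// A predicate over a non-UTF8-binary collated column is widened to AlwaysTrue and marked as
// only partially translated, so Spark must still evaluate the original predicate after the scan.
val widened = TranslatedFilter(AlwaysTrue(), fullyTranslated = false)

// withFilter swaps the wrapped filter while preserving the fullyTranslated flag.
val negated = exact.withFilter(Not(exact.filter))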
@@ -413,7 +413,7 @@ trait FileSourceScanLike extends DataSourceScanExec {
scalarSubqueryReplaced.filterNot(_.references.exists {
case FileSourceConstantMetadataAttribute(_) => true
case _ => false
}).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown).map(_.filter))
}

// This field may execute subquery expressions and should not be accessed during planning.
@@ -54,7 +54,7 @@ import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.{PartitioningUtils => CatalystPartitioningUtils}
import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String

/**
@@ -573,10 +573,11 @@ object DataSourceStrategy
/**
* Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
*
* @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
* @return a `Some[TranslatedFilter]` if the input [[Expression]] is at least partially
* convertible, otherwise a `None`.
*/
protected[sql] def translateFilter(
predicate: Expression, supportNestedPredicatePushdown: Boolean): Option[Filter] = {
predicate: Expression, supportNestedPredicatePushdown: Boolean): Option[TranslatedFilter] = {
translateFilterWithMapping(predicate, None, supportNestedPredicatePushdown)
}

@@ -588,19 +589,24 @@ object DataSourceStrategy
* translated [[Filter]]. The map is used for rebuilding
* [[Expression]] from [[Filter]].
* @param nestedPredicatePushdownEnabled Whether nested predicate pushdown is enabled.
* @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
* @return a `Some[TranslatedFilter]` if the input [[Expression]] is at least partially
* convertible, otherwise a `None`.
*/
protected[sql] def translateFilterWithMapping(
predicate: Expression,
translatedFilterToExpr: Option[mutable.HashMap[sources.Filter, Expression]],
translatedFilterToExpr: Option[mutable.HashMap[Filter, Expression]],
nestedPredicatePushdownEnabled: Boolean)
: Option[Filter] = {
: Option[TranslatedFilter] = {

def translateAndRecordLeafNodeFilter(filter: Expression): Option[Filter] = {
def translateAndRecordLeafNodeFilter(filter: Expression, canBeFullyTranslated: Boolean = true)
: Option[TranslatedFilter] = {
val translatedFilter =
translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled))
translateLeafNodeFilter(filter, PushableColumn(nestedPredicatePushdownEnabled)) match {
case Some(f) => Some(TranslatedFilter(f, canBeFullyTranslated))
case None => None
}
if (translatedFilter.isDefined && translatedFilterToExpr.isDefined) {
translatedFilterToExpr.get(translatedFilter.get) = predicate
translatedFilterToExpr.get(translatedFilter.get.filter) = predicate
}
translatedFilter
}
@@ -621,32 +627,37 @@ object DataSourceStrategy
left, translatedFilterToExpr, nestedPredicatePushdownEnabled)
rightFilter <- translateFilterWithMapping(
right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
} yield sources.And(leftFilter, rightFilter)
} yield TranslatedFilter(sources.And(leftFilter.filter, rightFilter.filter),
leftFilter.fullyTranslated && rightFilter.fullyTranslated)

case expressions.Or(left, right) =>
for {
leftFilter <- translateFilterWithMapping(
left, translatedFilterToExpr, nestedPredicatePushdownEnabled)
rightFilter <- translateFilterWithMapping(
right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
} yield sources.Or(leftFilter, rightFilter)
} yield TranslatedFilter(sources.Or(leftFilter.filter, rightFilter.filter),
leftFilter.fullyTranslated && rightFilter.fullyTranslated)

case notNull @ expressions.IsNotNull(_: AttributeReference) =>
case notNull @ expressions.IsNotNull(_: AttributeReference | _: GetStructField) =>
// Not null filters on attribute references can always be pushed, also for collated columns.
translateAndRecordLeafNodeFilter(notNull)

case isNull @ expressions.IsNull(_: AttributeReference) =>
case isNull @ expressions.IsNull(_: AttributeReference | _: GetStructField) =>
// Is null filters on attribute references can always be pushed, also for collated columns.
translateAndRecordLeafNodeFilter(isNull)

case p if p.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType)) =>
// The filter cannot be pushed and we widen it to be AlwaysTrue(). This is only valid if
// the result of the filter is not negated by a Not expression it is wrapped in.
translateAndRecordLeafNodeFilter(Literal.TrueLiteral)
case p if DataSourceUtils.hasNonUTF8BinaryCollation(p) =>
// The filter cannot be pushed, so we widen it to AlwaysTrue() and mark it as only
// partially translated, which forces the engine to evaluate the original predicate as well.
// This is only valid as long as the result of the filter is not negated by an
// enclosing Not expression.
translateAndRecordLeafNodeFilter(Literal.TrueLiteral, canBeFullyTranslated = false)

case expressions.Not(child) =>
translateFilterWithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled)
.map(sources.Not)
.map(translatedFilter =>
translatedFilter.withFilter(sources.Not(translatedFilter.filter)))

case other =>
translateAndRecordLeafNodeFilter(other)
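Putting the branches above together, a hedged sketch of what the translation now produces for a conjunction whose right-hand side references a non-UTF8-binary collated column (column names are invented):

import org.apache.spark.sql.sources.{AlwaysTrue, And, EqualTo, TranslatedFilter}

// For "c1 = 'x' AND c2 = 'y'", where only c2 carries a non-UTF8-binary collation, the collated
// side is widened to AlwaysTrue and the And node inherits fullyTranslated = false from it, so a
// weaker filter is pushed down while the original predicate is still evaluated by Spark.
val roughResult = TranslatedFilter(
  And(EqualTo("c1", "x"), AlwaysTrue()),
  fullyTranslated = false)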
@@ -693,21 +704,22 @@ object DataSourceStrategy
// If a predicate is not in this map, it means it cannot be pushed down.
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
// SPARK-41636: we keep the order of the predicates to avoid CodeGenerator cache misses
val translatedMap: Map[Expression, Filter] = ListMap(predicates.flatMap { p =>
val translatedMap: Map[Expression, TranslatedFilter] = ListMap(predicates.flatMap { p =>
translateFilter(p, supportNestedPredicatePushdown).map(f => p -> f)
}: _*)

val pushedFilters: Seq[Filter] = translatedMap.values.toSeq
val pushedFilters: Seq[Filter] = translatedMap.values.map(_.filter).toSeq

// Catalyst predicate expressions that cannot be converted to data source filters.
val nonconvertiblePredicates = predicates.filterNot(translatedMap.contains)
// Catalyst predicate expressions that cannot be fully converted to data source filters.
val nonconvertiblePredicates = predicates.filter(predicate =>
!translatedMap.contains(predicate) || !translatedMap(predicate).fullyTranslated)

// Data source filters that cannot be handled by `relation`. An unhandled filter means
// the data source cannot guarantee the rows returned can pass the filter.
// As a result we must return it so Spark can plan an extra filter operator.
val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet
val unhandledFilters = relation.unhandledFilters(pushedFilters.toArray).toSet
val unhandledPredicates = translatedMap.filter { case (p, f) =>
unhandledFilters.contains(f)
unhandledFilters.contains(f.filter)
}.keys
val handledFilters = pushedFilters.toSet -- unhandledFilters

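The convertibility test used above can be read as the following standalone predicate (a hedged sketch; the helper name is invented):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sources.TranslatedFilter

// A predicate counts as fully convertible only when a translation exists and is complete;
// anything else lands in nonconvertiblePredicates and is re-evaluated after the scan.
def isFullyConvertible(p: Expression, translated: Map[Expression, TranslatedFilter]): Boolean =
  translated.get(p).exists(_.fullyTranslated)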
@@ -28,7 +28,7 @@ import org.json4s.jackson.Serialization
import org.apache.spark.{SparkException, SparkUpgradeException}
import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, ExpressionSet, GetStructField, PredicateHelper}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -281,22 +281,9 @@ object DataSourceUtils extends PredicateHelper {
}

/**
* Determines whether a filter should be pushed down to the data source or not.
*
* @param expression The filter expression to be evaluated.
* @param isCollationPushDownSupported Whether the data source supports collation push down.
* @return A boolean indicating whether the filter should be pushed down or not.
* Determines whether a filter references any columns with non-UTF8 binary collation.
*/
def shouldPushFilter(expression: Expression, isCollationPushDownSupported: Boolean): Boolean = {
if (!expression.deterministic) return false

isCollationPushDownSupported || !expression.exists {
case childExpression @ (_: Attribute | _: GetStructField) =>
// don't push down filters for types with non-binary sortable collation
// as it could lead to incorrect results
SchemaUtils.hasNonUTF8BinaryCollation(childExpression.dataType)

case _ => false
}
def hasNonUTF8BinaryCollation(expression: Expression): Boolean = {
expression.references.exists(ref => SchemaUtils.hasNonUTF8BinaryCollation(ref.dataType))
}
}
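A hedged usage sketch of the slimmed-down helper (the wrapper function is hypothetical):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.DataSourceUtils

// The check only inspects the data types of the columns the predicate references, so
// predicates over plain UTF8_BINARY strings or non-string columns are never widened.
def canPushAsIs(predicate: Expression): Boolean =
  !DataSourceUtils.hasNonUTF8BinaryCollation(predicate)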
@@ -223,12 +223,6 @@ trait FileFormat {
*/
def fileConstantMetadataExtractors: Map[String, PartitionedFile => Any] =
FileFormat.BASE_METADATA_EXTRACTORS

/**
* Returns whether the file format supports filter push down
* for non utf8 binary collated columns.
*/
def supportsCollationPushDown: Boolean = false
}

object FileFormat {
@@ -160,11 +160,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
// - filters that need to be evaluated again after the scan
val filterSet = ExpressionSet(filters)

val filtersToPush = filters.filter(f =>
DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown))

val normalizedFilters = DataSourceStrategy.normalizeExprs(
filtersToPush, l.output)
filters.filter(_.deterministic), l.output)

val partitionColumns =
l.resolve(
Expand Down Expand Up @@ -206,6 +203,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
DataSourceUtils.supportNestedPredicatePushdown(fsRelation)
val pushedFilters = dataFilters
.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
.map(_.filter)
logInfo(log"Pushed Filters: ${MDC(PUSHED_FILTERS, pushedFilters.mkString(","))}")

// Predicates with both partition keys and attributes need to be evaluated after the scan.
@@ -63,8 +63,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_))
if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
val normalizedFilters = DataSourceStrategy.normalizeExprs(
filters.filter(f => !SubqueryExpression.hasSubquery(f) &&
DataSourceUtils.shouldPushFilter(f, fsRelation.fileFormat.supportsCollationPushDown)),
filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
logicalRelation.output)
val (partitionKeyFilters, _) = DataSourceUtils
.getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
@@ -70,21 +70,20 @@ abstract class FileScanBuilder(
}

override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
val (filtersToPush, filtersToRemain) = filters.partition(
f => DataSourceUtils.shouldPushFilter(f, supportsCollationPushDown))
val (deterministicFilters, nonDeterministicFilters) = filters.partition(_.deterministic)
val (partitionFilters, dataFilters) =
DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filtersToPush)
DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, deterministicFilters)
this.partitionFilters = partitionFilters
this.dataFilters = dataFilters
val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
for (filterExpr <- dataFilters) {
val translated = DataSourceStrategy.translateFilter(filterExpr, true)
if (translated.nonEmpty) {
translatedFilters += translated.get
translatedFilters += translated.get.filter
}
}
pushedDataFilters = pushDataFilters(translatedFilters.toArray)
dataFilters ++ filtersToRemain
dataFilters ++ nonDeterministicFilters
}

override def pushedFilters: Array[Predicate] = pushedDataFilters.map(_.toV2)
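A hedged, self-contained sketch of the determinism split that replaces the collation-aware partition (the column and predicates are invented):

import org.apache.spark.sql.functions.{col, rand}

// Only deterministic predicates are considered for pushdown; rand() > 0.5 is kept aside and,
// together with every data filter, is returned for evaluation after the scan.
val exprs = Seq(col("c1") === "aaa", rand() > 0.5).map(_.expr)
val (deterministicFilters, nonDeterministicFilters) = exprs.partition(_.deterministic)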
@@ -96,12 +95,6 @@ abstract class FileScanBuilder(
*/
protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter]

/**
* Returns whether the file scan builder supports filter pushdown
* for non utf8 binary collated columns.
*/
protected def supportsCollationPushDown: Boolean = false

private def createRequiredNameSet(): Set[String] =
requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet

@@ -58,7 +58,10 @@ object PushDownUtils {
if (translated.isEmpty) {
untranslatableExprs += filterExpr
} else {
translatedFilters += translated.get
translatedFilters += translated.get.filter
if (!translated.get.fullyTranslated) {
untranslatableExprs += filterExpr
}
}
}

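Restated as a standalone helper (a hedged sketch; the name and signature are invented, mirroring the branch above):

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sources.{Filter, TranslatedFilter}

// A partially translated filter is still pushed (in this patch it is AlwaysTrue, which drops
// nothing), but the original expression is also kept so Spark evaluates it after the scan.
def splitTranslation(
    expr: Expression,
    translated: Option[TranslatedFilter]): (Seq[Filter], Seq[Expression]) = translated match {
  case None => (Nil, Seq(expr))
  case Some(t) if !t.fullyTranslated => (Seq(t.filter), Seq(expr))
  case Some(t) => (Seq(t.filter), Nil)
}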
@@ -1265,36 +1265,52 @@ class FileBasedDataSourceSuite extends QueryTest
df.write.format(format).save(path.getAbsolutePath)

// filter and expected result
val filters = Seq(
val filterTypes = Seq(
("==", Seq(Row("aaa"), Row("AAA"))),
("!=", Seq(Row("bbb"))),
("<", Seq()),
("<=", Seq(Row("aaa"), Row("AAA"))),
(">", Seq(Row("bbb"))),
(">=", Seq(Row("aaa"), Row("AAA"), Row("bbb"))))

filters.foreach { filter =>
val readback = spark.read
.format(format)
.load(path.getAbsolutePath)
.where(s"c1 ${filter._1} collate('aaa', $collation)")
.where(s"str ${filter._1} struct(collate('aaa', $collation))")
.where(s"namedstr.f1.f2 ${filter._1} collate('aaa', $collation)")
.where(s"arr ${filter._1} array(collate('aaa', $collation))")
.where(s"map_keys(map1) ${filter._1} array(collate('aaa', $collation))")
.where(s"map_values(map2) ${filter._1} array(collate('aaa', $collation))")
.select("c1")

val explain = readback.queryExecution.explainString(
ExplainMode.fromString("extended"))
assert(explain.contains("PushedFilters: []"))
checkAnswer(readback, filter._2)
filterTypes.foreach { filterType =>
Seq(
s"c1 $filterType collate('aaa', $collation)",
s"str $filterType struct(collate('aaa', $collation))",
s"namedstr.f1.f2 $filterType collate('aaa', $collation)",
s"arr $filterType array(collate('aaa', $collation))",
s"map_keys(map1) $filterType array(collate('aaa', $collation))",
s"map_values(map2) $filterType array(collate('aaa', $collation))",
).foreach { filterString =>

val readback = spark.read
.format(format)
.load(path.getAbsolutePath)
.where(filterString)
.select("c1")

getPushedFilters(readback).foreach { filter =>
assert(filter === "AlwaysTrue()" || filter.startsWith("IsNotNull"))
}
checkAnswer(readback, filterType._2)
}
}
}
}
}
}
}

def getPushedFilters(df: DataFrame): Set[String] = {
val explain = df.queryExecution.explainString(ExplainMode.fromString("extended"))
assert(explain.contains("PushedFilters:"))

// Regular expression to extract text inside the brackets
val pattern = "PushedFilters: \\[(.*?)]".r

pattern.findFirstMatchIn(explain).get.group(1).split(", ").toSet
}
}

object TestingUDT {