diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 7302b63646d6..97be71b8dbc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -202,3 +202,21 @@ object Unions {
     }
   }
 }
+
+/**
+ * A pattern that finds the original expression underneath a sequence of casts.
+ */
+object Casts {
+  def unapply(expr: Expression): Option[Expression] = expr match {
+    case c: Cast => Some(collectCasts(c))
+    case _ => None
+  }
+
+  private def collectCasts(e: Expression): Expression = {
+    if (e.isInstanceOf[Cast]) {
+      collectCasts(e.children(0))
+    } else {
+      e
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6e2a5aa4f97c..e198e44cbb39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.planning.Casts
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
 import org.apache.spark.sql.types.LongType
@@ -80,12 +81,21 @@ case class Filter(condition: Expression, child: SparkPlan)
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
     case IsNotNull(a) if child.output.contains(a) => true
+    case IsNotNull(a) =>
+      a match {
+        case Casts(inner) if child.output.contains(inner) => true
+        case _ => false
+      }
     case _ => false
   }
 
-  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
+  // The columns that will be filtered out by `IsNotNull` could be considered as not nullable.
   private val notNullAttributes = notNullPreds.flatMap(_.references)
 
+  // Only the attributes that will be filtered by `IsNotNull` must be evaluated before this
+  // plan; evaluation of the other inputs can be deferred until after the nulls are filtered out.
+  override def usedInputs: AttributeSet = AttributeSet(notNullAttributes)
+
   override def output: Seq[Attribute] = {
     child.output.map { a =>
       if (a.nullable && notNullAttributes.contains(a)) {
@@ -110,6 +120,9 @@ case class Filter(condition: Expression, child: SparkPlan)
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val evaluated =
+      evaluateRequiredVariables(child.output, input, references -- usedInputs)
+
     // filter out the nulls
     val filterOutNull = notNullAttributes.map { a =>
       val idx = child.output.indexOf(a)
@@ -142,6 +155,7 @@ case class Filter(condition: Expression, child: SparkPlan)
     }
     s"""
        |$filterOutNull
+       |$evaluated
        |$predicates
        |$numOutput.add(1);
        |${consume(ctx, resultVars)}
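
For illustration, here is a minimal self-contained sketch of how the `Casts` extractor added in patterns.scala behaves. `Expr`, `Attr`, and the toy `Cast` below are stand-ins invented for this sketch, not Catalyst classes; only the shape of `unapply`/`collectCasts` mirrors the patch.

// Toy stand-ins for Catalyst's Expression hierarchy, invented for this sketch.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Cast(child: Expr) extends Expr

object Casts {
  // Match only if the expression is a cast; return the innermost non-cast
  // expression by walking down the chain of casts.
  def unapply(expr: Expr): Option[Expr] = expr match {
    case c: Cast => Some(collectCasts(c))
    case _ => None
  }

  @scala.annotation.tailrec
  private def collectCasts(e: Expr): Expr = e match {
    case Cast(child) => collectCasts(child)
    case other => other
  }
}

object CastsDemo extends App {
  // `CAST(CAST(a AS STRING) AS INT) IS NOT NULL` really constrains column `a`,
  // and that is the expression the extractor recovers.
  Cast(Cast(Attr("a"))) match {
    case Casts(inner) => println(s"null check applies to $inner") // prints Attr(a)
    case _ => println("not a chain of casts")
  }
}

In the Filter change, this is what lets an `IsNotNull` over a cast chain still mark the underlying column as non-nullable, while `usedInputs` and `evaluateRequiredVariables` defer evaluation of all other inputs until the generated null checks have discarded the null rows.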