
Commit 66a99f4

gatorsmile authored and hvanhovell committed
[SPARK-17981][SPARK-17957][SQL] Fix Incorrect Nullability Setting to False in FilterExec
### What changes were proposed in this pull request?

When `FilterExec` contains an `isNotNull` predicate, whether inferred and pushed down or specified by the user, we set the nullability of the involved columns to false if the top-level expression is null-intolerant. However, this is not correct: if the top-level expression is not a leaf expression, it can still tolerate nulls when it has null-tolerant child expressions. For example, in `cast(coalesce(a#5, a#15) as double)`, although `cast` is a null-intolerant expression, `coalesce` is obviously null-tolerant and can thus absorb a null.

When the nullability is wrong, we can generate incorrect results in various cases. For example:

```scala
val df1 = Seq((1, 2), (2, 3)).toDF("a", "b")
val df2 = Seq((2, 5), (3, 4)).toDF("a", "c")
val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0)
val df3 = Seq((3, 1)).toDF("a", "d")
joinedDf.join(df3, "a").show
```

The optimized plan looks like:

```
Project [a#29, b#30, c#31, d#42]
+- Join Inner, (a#29 = a#41)
   :- Project [cast(coalesce(cast(coalesce(a#5, a#15) as double), 0.0) as int) AS a#29, cast(coalesce(cast(b#6 as double), 0.0) as int) AS b#30, cast(coalesce(cast(c#16 as double), 0.0) as int) AS c#31]
   :  +- Filter isnotnull(cast(coalesce(cast(coalesce(a#5, a#15) as double), 0.0) as int))
   :     +- Join FullOuter, (a#5 = a#15)
   :        :- LocalRelation [a#5, b#6]
   :        +- LocalRelation [a#15, c#16]
   +- LocalRelation [a#41, d#42]
```

Without the fix, the query returns an empty result. With the fix, it returns the correct answer:

```
+---+---+---+---+
|  a|  b|  c|  d|
+---+---+---+---+
|  3|  0|  4|  1|
+---+---+---+---+
```

### How was this patch tested?

Added test cases to verify the nullability changes in `FilterExec`. Also added a test case verifying the reported incorrect result.

Author: gatorsmile <[email protected]>

Closes #15523 from gatorsmile/nullabilityFilterExec.
1 parent 9dc9f9a · commit 66a99f4
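To make the fix concrete, here is a minimal, self-contained Scala sketch of the recursive check this commit introduces. The `Expr`, `Attr`, `Cast`, and `Coalesce` types are toy stand-ins, not Catalyst's actual classes; only the shape of `isNullIntolerant` mirrors the patched code in `FilterExec`.

```scala
// Toy stand-ins for Catalyst expressions (hypothetical, for illustration only).
sealed trait Expr { def children: Seq[Expr] }
// Marker trait: the node itself returns null whenever any direct input is null.
trait NullIntolerant extends Expr

// An attribute reference is a null-intolerant leaf.
case class Attr(name: String) extends NullIntolerant { def children: Seq[Expr] = Nil }
// `cast` is null-intolerant as a single node...
case class Cast(child: Expr) extends NullIntolerant { def children: Seq[Expr] = Seq(child) }
// ...but `coalesce` is NOT: it can produce a non-null result from null inputs.
case class Coalesce(children: Seq[Expr]) extends Expr

// Before this patch, only the root node was inspected:
def topLevelOnly(e: Expr): Boolean = e.isInstanceOf[NullIntolerant]

// After this patch, the root AND all of its descendants must be null-intolerant:
def isNullIntolerant(e: Expr): Boolean = e match {
  case ni: NullIntolerant => ni.children.forall(isNullIntolerant)
  case _ => false
}

// The expression from the description: cast(coalesce(a#5, a#15) as double)
val e = Cast(Coalesce(Seq(Attr("a#5"), Attr("a#15"))))
assert(topLevelOnly(e))                     // old check: wrongly null-intolerant
assert(!isNullIntolerant(e))                // fixed check: coalesce can eat the null
assert(isNullIntolerant(Cast(Attr("a#5")))) // a genuinely null-intolerant tree
```

Because the recursion uses `children.forall`, it is vacuously true at leaves, which is why a plain `isnotnull(_1)` predicate (see the tests below) can still mark `_1` as non-nullable.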

File tree

2 files changed: +79 −3 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -90,7 +90,13 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true
+    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
+    case _ => false
+  }
+
+  // If one expression and its children are null intolerant, it is null intolerant.
+  private def isNullIntolerant(expr: Expression): Boolean = expr match {
+    case e: NullIntolerant => e.children.forall(isNullIntolerant)
     case _ => false
   }
```

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 72 additions & 2 deletions
```diff
@@ -28,8 +28,8 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
+import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.functions._
@@ -1635,6 +1635,76 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  private def verifyNullabilityInFilterExec(
+      df: DataFrame,
+      expr: String,
+      expectedNonNullableColumns: Seq[String]): Unit = {
+    val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
+    // In the logical plan, all the output columns of input dataframe are nullable
+    dfWithFilter.queryExecution.optimizedPlan.collect {
+      case e: Filter => assert(e.output.forall(_.nullable))
+    }
+
+    dfWithFilter.queryExecution.executedPlan.collect {
+      // When the child expression in isnotnull is null-intolerant (i.e. any null input will
+      // result in null output), the involved columns are converted to not nullable;
+      // otherwise, no change should be made.
+      case e: FilterExec =>
+        assert(e.output.forall { o =>
+          if (expectedNonNullableColumns.contains(o.name)) !o.nullable else o.nullable
+        })
+    }
+  }
+
+  test("SPARK-17957: no change on nullability in FilterExec output") {
+    val df = sparkContext.parallelize(Seq(
+      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
+      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
+      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
+
+    verifyNullabilityInFilterExec(df,
+      expr = "Rand()", expectedNonNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "coalesce(_1, _2)", expectedNonNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "coalesce(_1, 0) + Rand()", expectedNonNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "cast(coalesce(cast(coalesce(_1, _2) as double), 0.0) as int)",
+      expectedNonNullableColumns = Seq.empty[String])
+  }
+
+  test("SPARK-17957: set nullability to false in FilterExec output") {
+    val df = sparkContext.parallelize(Seq(
+      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
+      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
+      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
+
+    verifyNullabilityInFilterExec(df,
+      expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_1 + _2", expectedNonNullableColumns = Seq("_1", "_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_1", expectedNonNullableColumns = Seq("_1"))
+    // `constructIsNotNullConstraints` infers the IsNotNull(_2) from IsNotNull(_2 + Rand())
+    // Thus, we are able to set nullability of _2 to false.
+    // If IsNotNull(_2) is not given from `constructIsNotNullConstraints`, the impl of
+    // isNullIntolerant in `FilterExec` needs an update for more advanced inference.
+    verifyNullabilityInFilterExec(df,
+      expr = "_2 + Rand()", expectedNonNullableColumns = Seq("_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_2 * 3 + coalesce(_1, 0)", expectedNonNullableColumns = Seq("_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2"))
+  }
+
+  test("SPARK-17957: outer join + na.fill") {
+    val df1 = Seq((1, 2), (2, 3)).toDF("a", "b")
+    val df2 = Seq((2, 5), (3, 4)).toDF("a", "c")
+    val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0)
+    val df3 = Seq((3, 1)).toDF("a", "d")
+    checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1))
+  }
+
   test("SPARK-17123: Performing set operations that combine non-scala native types") {
     val dates = Seq(
       (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)),
```
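As a usage sketch (not part of the patch), the snippet below, run in a spark-shell built with this fix, surfaces the behavior the tests assert on: a predicate that is null-intolerant all the way down lets `FilterExec` report its input columns as non-nullable, while a null-tolerant one leaves them unchanged. It reuses the test data above and assumes only Spark 2.x APIs.

```scala
import org.apache.spark.sql.execution.FilterExec

// Same data as in the tests: two nullable Integer columns _1 and _2.
val df = sc.parallelize(Seq(
  null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
  new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
  new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()

// `_1 + _2` is null-intolerant all the way down; `coalesce(_1, _2)` is not.
for (cond <- Seq("isnotnull(_1 + _2)", "isnotnull(coalesce(_1, _2))")) {
  df.where(cond).queryExecution.executedPlan.foreach {
    case f: FilterExec =>
      println(s"$cond -> " +
        f.output.map(a => s"${a.name}: nullable=${a.nullable}").mkString(", "))
    case _ => // ignore other physical operators
  }
}
```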
