@@ -89,7 +89,13 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true
+    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
     case _ => false
   }
 
+  // If one expression and its children are null intolerant, it is null intolerant.
+  private def isNullIntolerant(expr: Expression): Boolean = expr match {
+    case e: NullIntolerant => e.children.forall(isNullIntolerant)
+    case _ => false
+  }
+
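As a rough illustration of the recursion above (not part of the diff), the sketch below models the check on a hypothetical mini expression tree; Expr, Attr, Add, Coalesce, and Rand here are stand-ins rather than the Catalyst classes. Under this model, `_1 + _2` counts as null-intolerant, while `coalesce(_1, 0) + rand()` does not, because Coalesce can absorb a null input.

// Standalone sketch only: a hypothetical mini expression tree, not Spark's Catalyst classes.
sealed trait Expr { def children: Seq[Expr] }
trait NullIntolerant extends Expr                 // marker: a null input forces a null output
case class Attr(name: String) extends Expr with NullIntolerant { val children: Seq[Expr] = Seq.empty }
case class Add(l: Expr, r: Expr) extends Expr with NullIntolerant { val children: Seq[Expr] = Seq(l, r) }
case class Coalesce(children: Seq[Expr]) extends Expr       // tolerates nulls, so not NullIntolerant
case class Rand() extends Expr { val children: Seq[Expr] = Seq.empty }

// Same shape as the new FilterExec helper: null-intolerant only if the whole subtree is.
def isNullIntolerant(expr: Expr): Boolean = expr match {
  case e: NullIntolerant => e.children.forall(isNullIntolerant)
  case _ => false
}

assert(isNullIntolerant(Add(Attr("_1"), Attr("_2"))))              // _1 + _2
assert(!isNullIntolerant(Add(Coalesce(Seq(Attr("_1"))), Rand())))  // coalesce(_1) + rand()

In this model, the conservative fallback (`case _ => false`) is what stops a Coalesce or Rand subtree from marking its input columns as non-nullable.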
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (72 additions, 2 deletions)
@@ -28,8 +28,8 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
+import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.functions._
@@ -1617,6 +1617,76 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
    }
  }

  private def verifyNullabilityInFilterExec(
      df: DataFrame,
      expr: String,
      expectedNonNullableColumns: Seq[String]): Unit = {
    val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
    // In the logical plan, all the output columns of the input DataFrame are nullable
    dfWithFilter.queryExecution.optimizedPlan.collect {
      case e: Filter => assert(e.output.forall(_.nullable))
    }

    dfWithFilter.queryExecution.executedPlan.collect {
      // When the child expression in isnotnull is null-intolerant (i.e., any null input will
      // result in a null output), the involved columns are converted to not nullable;
      // otherwise, no change should be made.
      case e: FilterExec =>
        assert(e.output.forall { o =>
          if (expectedNonNullableColumns.contains(o.name)) !o.nullable else o.nullable
        })
    }
  }

  test("SPARK-17957: no change on nullability in FilterExec output") {
    val df = sparkContext.parallelize(Seq(
      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()

    verifyNullabilityInFilterExec(df,
      expr = "Rand()", expectedNonNullableColumns = Seq.empty[String])
    verifyNullabilityInFilterExec(df,
      expr = "coalesce(_1, _2)", expectedNonNullableColumns = Seq.empty[String])
    verifyNullabilityInFilterExec(df,
      expr = "coalesce(_1, 0) + Rand()", expectedNonNullableColumns = Seq.empty[String])
    verifyNullabilityInFilterExec(df,
      expr = "cast(coalesce(cast(coalesce(_1, _2) as double), 0.0) as int)",
      expectedNonNullableColumns = Seq.empty[String])
  }

  test("SPARK-17957: set nullability to false in FilterExec output") {
    val df = sparkContext.parallelize(Seq(
      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()

    verifyNullabilityInFilterExec(df,
      expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2"))
    verifyNullabilityInFilterExec(df,
      expr = "_1 + _2", expectedNonNullableColumns = Seq("_1", "_2"))
    verifyNullabilityInFilterExec(df,
      expr = "_1", expectedNonNullableColumns = Seq("_1"))
    // `constructIsNotNullConstraints` infers IsNotNull(_2) from IsNotNull(_2 + Rand()).
    // Thus, we are able to set the nullability of _2 to false.
    // If IsNotNull(_2) is not provided by `constructIsNotNullConstraints`, the implementation of
    // isNullIntolerant in `FilterExec` needs an update for more advanced inference.
    verifyNullabilityInFilterExec(df,
      expr = "_2 + Rand()", expectedNonNullableColumns = Seq("_2"))
Member:
This should not work under the current approach. It works now because we infer redundant IsNotNull constraints. E.g., if a Filter has a constraint IsNotNull(_2 + Rand()), we will infer another IsNotNull(_2) from it. Your approach relies on IsNotNull(_2) to decide that _2 is a non-nullable column, not on IsNotNull(_2 + Rand()).

I submitted another PR #15653 for the redundant IsNotNull constraints. But I am not sure we want to fix that, since it doesn't affect correctness. I'll leave it to @cloud-fan or @hvanhovell to decide.
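For illustration only (not part of the PR), the inference described above can be observed on the optimized plan; a minimal sketch, assuming the same two-column `df` (_1, _2) and test session used in the suite:

val filtered = df.where("isnotnull(_2 + rand())")
filtered.queryExecution.optimizedPlan.collect {
  case f: org.apache.spark.sql.catalyst.plans.logical.Filter =>
    // Per the comment above, constructIsNotNullConstraints infers isnotnull(_2) from
    // isnotnull(_2 + rand()), so a deterministic IsNotNull survives in the constraint set.
    assert(f.constraints.exists(_.isInstanceOf[org.apache.spark.sql.catalyst.expressions.IsNotNull]))
}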

Member Author:
I already explained why my current solution works in my previous statement. Personally, I like simple code, which is easy to understand and maintain, especially when it can cover all the cases. Result correctness and code maintainability are always more important.

If constructIsNotNullConstraints is changed by somebody else (i.e., it no longer provides the expected IsNotNull constraints), the test cases added by this PR will fail. Then we can modify the code.

Member:
Yeah, so I said I will leave that to @cloud-fan or others to decide...

Member:
I agree with the simplicity argument, but can you please add a comment here explaining why this particular case works due to the null-inference rule?

    verifyNullabilityInFilterExec(df,
      expr = "_2 * 3 + coalesce(_1, 0)", expectedNonNullableColumns = Seq("_2"))
    verifyNullabilityInFilterExec(df,
      expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2"))
  }

  test("SPARK-17957: outer join + na.fill") {
    val df1 = Seq((1, 2), (2, 3)).toDF("a", "b")
    val df2 = Seq((2, 5), (3, 4)).toDF("a", "c")
    val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0)
    val df3 = Seq((3, 1)).toDF("a", "d")
    checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1))
  }

  test("SPARK-17123: Performing set operations that combine non-scala native types") {
    val dates = Seq(
      (new Date(0), BigDecimal.valueOf(1), new Timestamp(2)),