add more test cases
gatorsmile committed Oct 19, 2016
commit ce418f9ff0ddcc2312f338084899fd261a7875ee
58 changes: 40 additions & 18 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1616,40 +1616,62 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  private def verifyNullabilityInFilterExec(expr: String, isNullIntolerant: Boolean): Unit = {
-    val df = sparkContext.parallelize(Seq(
-      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
-      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
-      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF("a", "b")
-
+  private def verifyNullabilityInFilterExec(
+      df: DataFrame,
+      expr: String,
+      expectedNonNullableColumns: Seq[String]): Unit = {
     val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
+    // In the logical plan, all the output columns of the input DataFrame are nullable.
     dfWithFilter.queryExecution.optimizedPlan.collect {
-      // In the logical plan, all the output columns are nullable
       case e: Filter => assert(e.output.forall(_.nullable))
     }
 
     dfWithFilter.queryExecution.executedPlan.collect {
       // When the child expression in isnotnull is null-intolerant (i.e. any null input will
-      // result in null output), the columns are converted to not nullable; Otherwise, no change
-      // should be made.
+      // result in null output), the involved columns are converted to not nullable;
+      // otherwise, no change should be made.
       case e: FilterExec =>
-        assert(e.output.forall(o => if (isNullIntolerant) !o.nullable else o.nullable))
+        assert(e.output.forall { o =>
+          if (expectedNonNullableColumns.contains(o.name)) !o.nullable else o.nullable
+        })
     }
   }
 
   test("SPARK-17957: no change on nullability in FilterExec output") {
-    verifyNullabilityInFilterExec("coalesce(a, b)", isNullIntolerant = false)
+    val df = sparkContext.parallelize(Seq(
+      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
+      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
+      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
 
-    verifyNullabilityInFilterExec(
-      "cast(coalesce(cast(coalesce(a, b) as double), 0.0) as int)", isNullIntolerant = false)
+    verifyNullabilityInFilterExec(df,
+      expr = "Rand()", expectedNonNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "coalesce(_1, _2)", expectedNonNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "coalesce(_1, 0) + Rand()", expectedNonNullableColumns = Seq.empty[String])
+    verifyNullabilityInFilterExec(df,
+      expr = "cast(coalesce(cast(coalesce(_1, _2) as double), 0.0) as int)",
+      expectedNonNullableColumns = Seq.empty[String])
   }
 
   test("SPARK-17957: set nullability to false in FilterExec output") {
-    verifyNullabilityInFilterExec("a + b * 3", isNullIntolerant = true)
-
-    verifyNullabilityInFilterExec("a + b", isNullIntolerant = true)
-
-    verifyNullabilityInFilterExec("cast((a + b) as boolean)", isNullIntolerant = true)
+    val df = sparkContext.parallelize(Seq(
+      null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
+      new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
+      new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()
 
+    verifyNullabilityInFilterExec(df,
+      expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_1 + _2", expectedNonNullableColumns = Seq("_1", "_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_1", expectedNonNullableColumns = Seq("_1"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_2 + Rand()", expectedNonNullableColumns = Seq("_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "_2 * 3 + coalesce(_1, 0)", expectedNonNullableColumns = Seq("_2"))
+    verifyNullabilityInFilterExec(df,
+      expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2"))
   }
 
   test("SPARK-17957: outer join + na.fill") {
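
For readers who want to reproduce the behavior these tests assert, the sketch below builds the same three-row DataFrame and inspects the nullability of the FilterExec output directly. It is a minimal illustration, not part of the patch: it assumes a Spark 2.x-era SparkSession API running locally, the object name NullabilityDemo is hypothetical, and the nullability only flips to false when the SPARK-17957 change is present.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.FilterExec

object NullabilityDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("NullabilityDemo")
      .getOrCreate()
    import spark.implicits._

    // Same data as the test: a null in each column.
    val df = Seq[(java.lang.Integer, java.lang.Integer)](
      (null, 3), (1, null), (2, 4)).toDF()

    // Returns (column name, nullable) for the FilterExec output of
    // df.where(isnotnull(expr)).selectExpr(expr), mirroring the test helper.
    def filterExecNullability(expr: String): Seq[(String, Boolean)] = {
      val plan = df.where(s"isnotnull($expr)").selectExpr(expr)
        .queryExecution.executedPlan
      plan.collect { case f: FilterExec =>
        f.output.map(a => a.name -> a.nullable)
      }.flatten
    }

    // `+` is null-intolerant: isnotnull(_1 + _2) implies both inputs are
    // non-null, so FilterExec can mark _1 and _2 as not nullable.
    println(filterExecNullability("_1 + _2"))           // expect nullable = false

    // coalesce is null-tolerant: isnotnull(coalesce(_1, _2)) does not imply
    // either single input is non-null, so nullability is left unchanged.
    println(filterExecNullability("coalesce(_1, _2)"))  // expect nullable = true

    spark.stop()
  }
}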