Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,14 @@ case class FilterExec(condition: Expression, child: SparkPlan)

// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true
case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
case _ => false
}

// An expression is null intolerant iff it and all of its descendants are null intolerant
private def isNullIntolerant(expr: Expression): Boolean = expr match {
case e: NullIntolerant =>
if (e.isInstanceOf[LeafExpression]) true else e.children.forall(isNullIntolerant)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

This change is too conservative. Actually, we only need to worry about a non-NullIntolerant expression when it contains attributes from the output. I think we can be more aggressive, e.g.:

// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
  case IsNotNull(a: NullIntolerant) =>
    isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
  case _ => false
}

private def isNullIntolerant(expr: Expression): Boolean = {
  expr.find { e =>
    !e.isInstanceOf[NullIntolerant] && e.references.subsetOf(child.outputSet)
  }.isEmpty
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized the original code was from your PR. Then, in your code above, why do you still need to keep a.references.subsetOf(child.outputSet)? It looks confusing to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if a passes the isNullIntolerant check (i.e., it contains no non-NullIntolerant expression that wraps output attributes), we don't need it if it doesn't refer to any output attributes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you show me an example?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IsNotNull(Rand() > 0.5)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uh, I see.

First, we definitely need test cases to cover each positive and negative scenario. Previously, we did not have any test case to check the validity of nullability changes. Second, the code needs more comments where the variable/function names cannot explain the code on their own.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: How about something like this for better readability:

  private def isNullIntolerant(expr: Expression): Boolean = expr match {
    case e: NullIntolerant with LeafExpression => true
    case e: NullIntolerant => e.children.forall(isNullIntolerant)
    case _ => false
  }

Copy link
Member Author

@gatorsmile gatorsmile Nov 3, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forall returns true when children is empty. Thus, we can remove the first case. Now it becomes simpler. : )

  private def isNullIntolerant(expr: Expression): Boolean = expr match {
    case e: NullIntolerant => e.children.forall(isNullIntolerant)
    case _ => false
  }

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

case _ => false
}

Expand Down
48 changes: 46 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import org.scalatest.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -1615,4 +1615,48 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
qe.assertAnalyzed()
}
}

/**
 * Checks how `isnotnull(expr)` affects output nullability in the logical and physical plans.
 *
 * Builds a two-column DataFrame ("a", "b") of boxed integers in which each column holds exactly
 * one null, applies `where(isnotnull(expr))` followed by `selectExpr(expr)`, and asserts that:
 *  - every logical `Filter` in the optimized plan still reports all output columns as nullable;
 *  - every `FilterExec` in the executed plan reports all output columns as non-nullable when
 *    `isNullIntolerant` is true, and as nullable otherwise.
 *
 * Both plans must contain at least one matching node; otherwise the check would pass without
 * verifying anything.
 *
 * @param expr SQL expression text over the columns `a` and `b`
 * @param isNullIntolerant whether `FilterExec` is expected to mark its output as non-nullable
 */
private def verifyNullabilityInFilterExec(expr: String, isNullIntolerant: Boolean): Unit = {
  // Typed null so the tuples below are inferred as (java.lang.Integer, java.lang.Integer).
  val missing: java.lang.Integer = null
  // Integer.valueOf replaces the deprecated boxing constructor `new java.lang.Integer(_)`.
  val df = sparkContext.parallelize(Seq(
    missing -> java.lang.Integer.valueOf(3),
    java.lang.Integer.valueOf(1) -> missing,
    java.lang.Integer.valueOf(2) -> java.lang.Integer.valueOf(4))).toDF("a", "b")

  val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)

  // In the logical plan, all the output columns are nullable
  val logicalFilters = dfWithFilter.queryExecution.optimizedPlan.collect { case f: Filter => f }
  assert(logicalFilters.nonEmpty, s"no Filter found in the optimized plan for isnotnull($expr)")
  logicalFilters.foreach { f =>
    assert(f.output.forall(_.nullable))
  }

  // When the child expression in isnotnull is null-intolerant (i.e. any null input will
  // result in null output), the columns are converted to not nullable; Otherwise, no change
  // should be made.
  val filterExecs = dfWithFilter.queryExecution.executedPlan.collect { case f: FilterExec => f }
  assert(filterExecs.nonEmpty, s"no FilterExec found in the executed plan for isnotnull($expr)")
  filterExecs.foreach { f =>
    // nullable must be the opposite of isNullIntolerant for every output column
    assert(f.output.forall(_.nullable != isNullIntolerant))
  }
}

// Expressions for which FilterExec is expected to leave the output columns nullable.
test("SPARK-17957: no change on nullability in FilterExec output") {
  val expressions = Seq(
    "coalesce(a, b)",
    "cast(coalesce(cast(coalesce(a, b) as double), 0.0) as int)")

  // Checked in declaration order; the first failing expression aborts the test.
  expressions.foreach(verifyNullabilityInFilterExec(_, isNullIntolerant = false))
}

// Expressions for which FilterExec is expected to mark the output columns as non-nullable.
test("SPARK-17957: set nullability to false in FilterExec output") {
  val expressions = Seq(
    "a + b * 3",
    "a + b",
    "cast((a + b) as boolean)")

  // Checked in declaration order; the first failing expression aborts the test.
  expressions.foreach(verifyNullabilityInFilterExec(_, isNullIntolerant = true))
}

// Regression scenario: a null introduced by an outer join is replaced via na.fill(0), and the
// filled row must survive a subsequent inner join.
test("SPARK-17957: outer join + na.fill") {
  val left = Seq((1, 2), (2, 3)).toDF("a", "b")
  val right = Seq((2, 5), (3, 4)).toDF("a", "c")

  // a = 3 exists only on the right side, so its "b" is null after the outer join and becomes 0.
  val filled = left.join(right, Seq("a"), "outer").na.fill(0)

  val lookup = Seq((3, 1)).toDF("a", "d")
  val joined = filled.join(lookup, "a")

  checkAnswer(joined, Row(3, 0, 4, 1))
}
}