Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix basicPhysicalOperators usage.
  • Loading branch information
JoshRosen committed Jun 1, 2019
commit 33b579cc6f22e854c6d03f3f695e05aa56a68efd
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,10 @@ trait PredicateHelper {
}

/**
* Given a sequence of attributes and a filter condition, update attributes' nullability
* when the condition implies that attributes cannot be null.
* Given an IsNotNull expression, returns the IDs of the named expressions whose
* non-nullness is implied by that IsNotNull expression.
*/
protected def updateAttributeNullabilityFromNonNullConstraints(
attributes: Seq[Attribute],
condition: Expression): Seq[Attribute] = {
protected def getImpliedNotNullExprIds(isNotNullExpr: IsNotNull): Set[ExprId] = {
// This logic is a little tricky, so we'll use an example to build some intuition.
// Consider the expression IsNotNull(f(g(x), y)). By definition, its child is not null:
// f(g(x), y) is not null
Expand All @@ -142,38 +140,21 @@ trait PredicateHelper {
// By recursively applying this logic, if g is NullIntolerant then x is not null.
// However, if g is NOT NullIntolerant (e.g. if g(null) is non-null) then we cannot
// conclude anything about x's nullability.
def getNonNullAttributes(isNotNull: IsNotNull): Set[ExprId] = {
def getExprIdIfNamed(expr: Expression): Set[ExprId] = expr match {
case ne: NamedExpression => Set(ne.toAttribute.exprId)
case _ => Set.empty
}
// Recurse through the IsNotNull expression's descendants, stopping
// once we encounter a null-tolerant expression.
def getNotNullDescendants(expr: Expression): Set[ExprId] = {
expr.children.map {
case child: NullIntolerant =>
getExprIdIfNamed(child) ++ getNotNullDescendants(child)
case child =>
getExprIdIfNamed(child)
}.reduce(_ ++ _)
}
getExprIdIfNamed(isNotNull) ++ getNotNullDescendants(isNotNull)
}

val notNullAttributes: Set[ExprId] = {
splitConjunctivePredicates(condition)
.collect { case isNotNull: IsNotNull => isNotNull }
.map(getNonNullAttributes)
.reduce(_ ++ _)
def getExprIdIfNamed(expr: Expression): Set[ExprId] = expr match {
case ne: NamedExpression => Set(ne.toAttribute.exprId)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be AttributeReference? I couldn't remember offhand how to get ExprIds from arbitrary expressions, hence this hack.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use AttributeSet?

case _ => Set.empty
}

attributes.map { a =>
if (a.nullable && notNullAttributes.contains(a.exprId)) {
a.withNullability(false)
} else {
a
}
// Walk the descendants of `expr` and collect the ExprIds of the named expressions
// found along the way. Every direct child contributes its own ExprId (if it is named),
// but the walk only continues below children that are NullIntolerant: once a
// null-tolerant expression is reached, nothing can be concluded about its inputs.
def getNotNullDescendants(expr: Expression): Set[ExprId] = {
  // flatMap + toSet rather than map + reduce: an expression without children must
  // yield the empty set, whereas reduce throws UnsupportedOperationException on an
  // empty collection.
  expr.children.flatMap {
    case child: NullIntolerant =>
      getExprIdIfNamed(child) ++ getNotNullDescendants(child)
    case child =>
      getExprIdIfNamed(child)
  }.toSet
}
getExprIdIfNamed(isNotNullExpr) ++ getNotNullDescendants(isNotNullExpr)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,21 @@ case class Generate(
case class Filter(condition: Expression, child: LogicalPlan)
extends OrderPreservingUnaryNode with PredicateHelper {

// ExprIds whose non-nullness is implied by the IsNotNull predicates among the
// conjuncts returned by splitConjunctivePredicates(condition). The set is empty when
// the condition contains no such predicate; the previous reduce(_ ++ _) threw
// UnsupportedOperationException in that case.
private val impliedNotNullExprIds: Set[ExprId] = {
  splitConjunctivePredicates(condition)
    .collect { case isNotNull: IsNotNull => isNotNull }
    .flatMap(getImpliedNotNullExprIds)
    .toSet
}

override def output: Seq[Attribute] = {
updateAttributeNullabilityFromNonNullConstraints(child.output, condition)
child.output.map { a =>
if (a.nullable && impliedNotNullExprIds.contains(a.exprId)) {
a.withNullability(false)
} else {
a
}
}
}

override def maxRows: Option[Long] = child.maxRows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,23 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// all the variables at the beginning to take advantage of short circuiting.
override def usedInputs: AttributeSet = AttributeSet.empty

// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found the old code here to be slightly confusing because it seemed to be using notNullPreds for two different purposes:

  1. If we see IsNotNull conjuncts in the filter then evaluate them first / earlier because (a) these expressions are cheap to evaluate and may allow for short-circuiting and skipping more expensive expressions, and (b) evaluating these earlier allows other expressions to omit null checks (for example, if we have IsNotNull(x) and x * 100 < 10 then we already implicitly need to null-check x as part of the second expression so we might as well do the explicit null check expression first).
  2. Given that tuples have successfully passed through the filter, we can rely on the presence of IsNotNull checks to default subsequent expressions' null checks to false. For example, let's say we had a .filter().select() which gets compiled into a single whole stage codegen: after tuples have passed through the filter we know that certain fields cannot possibly be null, so we can elide null checks at codegen time by just setting nullable = false in subsequent code.

There might be some subtleties in (1) related to non-deterministic expressions, but I think that's accounted for further down at the place where we're actually generating the checks.

In the old code, the (notNullPreds, otherPreds) on this line was being used for both purposes: for (1) I think we could simply collect all IsNotNull expressions, but the existing implementation of (2) relied on the additional nullIntolerant / a.references checks in order to be correct.

In this PR, I've separated these two usages: the "update nullability for downstream operators" now uses the more precise condition implemented in getImpliedNotNullExprIds, while the "optimize short-circuiting" simply checks for IsNotNull and ignores child attributes.

case IsNotNull(_) => true
case _ => false
}

// ExprIds whose non-nullness is implied by the IsNotNull predicates in notNullPreds.
// Every element of notNullPreds matched IsNotNull(_) in the partition above, so the
// single-case pattern below is total for this input. flatMap + toSet yields the empty
// set when there are no IsNotNull predicates, where reduce(_ ++ _) used to throw
// UnsupportedOperationException.
private val impliedNotNullExprIds: Set[ExprId] =
  notNullPreds.flatMap { case n: IsNotNull => getImpliedNotNullExprIds(n) }.toSet

override def output: Seq[Attribute] = {
updateAttributeNullabilityFromNonNullConstraints(child.output, condition)
child.output.map { a =>
if (a.nullable && impliedNotNullExprIds.contains(a.exprId)) {
a.withNullability(false)
} else {
a
}
}
}

override lazy val metrics = Map(
Expand Down Expand Up @@ -133,7 +148,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)

// To generate the predicates we will follow this algorithm.
// For each predicate that is not IsNotNull, we will generate them one by one loading attributes
// as necessary. For each of both attributes, if there is an IsNotNull predicate we will
// as necessary. For each attribute, if there is an IsNotNull predicate we will
// generate that check *before* the predicate. After all of these predicates, we will generate
// the remaining IsNotNull checks that were not part of other predicates.
// This has the property of not doing redundant IsNotNull checks and taking better advantage of
Expand Down Expand Up @@ -172,7 +187,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// Reset the isNull to false for the not-null columns, then the followed operators could
// generate better code (remove dead branches).
val resultVars = input.zipWithIndex.map { case (ev, i) =>
if (notNullAttributes.contains(child.output(i).exprId)) {
if (impliedNotNullExprIds.contains(child.output(i).exprId)) {
ev.isNull = FalseLiteral
}
ev
Expand Down