Skip to content
Prev Previous commit
Next Next commit
Add explanation / derivation; refine implementation to handle more ca…
…ses.
  • Loading branch information
JoshRosen committed Jun 1, 2019
commit b950474eccea18353f4dee9652c4ea129ef4b0c4
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,46 @@ trait PredicateHelper {
protected def updateAttributeNullabilityFromNonNullConstraints(
attributes: Seq[Attribute],
condition: Expression): Seq[Attribute] = {
val attributeSet = AttributeSet(attributes)

// If one expression and its children are null intolerant, it is null intolerant.
def isNullIntolerant(expr: Expression): Boolean = expr match {
case e: NullIntolerant => e.children.forall(isNullIntolerant)
case _ => false
// This logic is a little tricky, so we'll use an example to build some intuition.
// Consider the expression IsNotNull(f(g(x), y)). By definition, its child is not null:
// f(g(x), y) is not null
// In addition, if `f` is NullIntolerant then it would be null if either child was null:
// g(x) is null => f(g(x), y) is null
// y is null => f(g(x), y) is null
// Via A => B <=> !B || A, we have:
// g(x) is not null || f(g(x), y) is null
// y is not null || f(g(x), y) is null
// Since we know that f(g(x), y) is not null, we must therefore conclude that
// g(x) is not null
// y is not null
// By recursively applying this logic, if g is NullIntolerant then x is not null.
// However, if g is NOT NullIntolerant (e.g. if g(null) is non-null) then we cannot
// conclude anything about x's nullability.
def getNonNullAttributes(isNotNull: IsNotNull): Set[ExprId] = {
def getExprIdIfNamed(expr: Expression): Set[ExprId] = expr match {
case ne: NamedExpression => Set(ne.toAttribute.exprId)
case _ => Set.empty
}
// Recurse through the IsNotNull expression's descendants, stopping
// once we encounter a null-tolerant expression.
def getNotNullDescendants(expr: Expression): Set[ExprId] = {
expr.children.map {
case child: NullIntolerant =>
getExprIdIfNamed(child) ++ getNotNullDescendants(child)
case child =>
getExprIdIfNamed(child)
}.reduce(_ ++ _)
}
getExprIdIfNamed(isNotNull) ++ getNotNullDescendants(isNotNull)
}

// Split out all the IsNotNulls from condition.
val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(attributeSet)
case _ => false
val notNullAttributes: Set[ExprId] = {
splitConjunctivePredicates(condition)
.collect { case isNotNull: IsNotNull => isNotNull }
.map(getNonNullAttributes)
.reduce(_ ++ _)
}

// The columns that will filtered out by `IsNotNull` could be considered as not nullable.
val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)

attributes.map { a =>
if (a.nullable && notNullAttributes.contains(a.exprId)) {
a.withNullability(false)
Expand Down