Skip to content
Closed
Prev Previous commit
Next Next commit
Use trait for null intolerant expression.
  • Loading branch information
viirya committed Mar 23, 2016
commit 56ca15fa348d0488ca689f9fec2dd912d0625fc4
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ object Cast {
}

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {

override def toString: String = s"cast($child as ${dataType.simpleString})"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ abstract class Expression extends TreeNode[Expression] {

def nullable: Boolean

/**
* Indicates whether this expression is null intolerant. If this is true,
* then any null input will result in null output).
*/
def nullIntolerant: Boolean = true

def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))

/** Returns the result of evaluating this expression on a given input Row */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval


case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class UnaryMinus(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

Expand Down Expand Up @@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def sql: String = s"(-${child.sql})"
}

case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class UnaryPositive(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def prettyName: String = "positive"

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
Expand All @@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
extended = "> SELECT _FUNC_('-1');\n1")
case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
case class Abs(child: Expression)
extends UnaryExpression with ExpectsInputTypes with NullIntolerant {

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

Expand Down Expand Up @@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic {
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

Expand Down Expand Up @@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
case class Subtract(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

Expand Down Expand Up @@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
case class Multiply(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

Expand All @@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}

case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
case class Divide(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

Expand Down Expand Up @@ -255,7 +261,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}

case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
case class Remainder(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType

Expand Down Expand Up @@ -429,7 +436,7 @@ case class MinOf(left: Expression, right: Expression)
override def symbol: String = "min"
}

case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {

override def toString: String = s"pmod($left, $right)"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ case class IsNaN(child: Expression) extends UnaryExpression

override def nullable: Boolean = false

override def nullIntolerant: Boolean = false

override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
Expand Down Expand Up @@ -185,8 +183,6 @@ case class NaNvl(left: Expression, right: Expression)
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false

override def nullIntolerant: Boolean = false

override def eval(input: InternalRow): Any = {
child.eval(input) == null
}
Expand All @@ -208,8 +204,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false

override def nullIntolerant: Boolean = false

override def eval(input: InternalRow): Any = {
child.eval(input) != null
}
Expand All @@ -230,7 +224,6 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
*/
case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
override def nullIntolerant: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,11 @@ package object expressions {
StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
}
}

/**
* When an expression inherits this, meaning the expression is null intolerant (i.e. any null
* input will result in null output). We will use this information during constructing IsNotNull
* constraints.
*/
trait NullIntolerant
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ trait PredicateHelper {


case class Not(child: Expression)
extends UnaryExpression with Predicate with ImplicitCastInputTypes {
extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {

override def toString: String = s"NOT $child"

Expand Down Expand Up @@ -376,7 +376,8 @@ private[sql] object Equality {
}


case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
case class EqualTo(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = AnyDataType

Expand Down Expand Up @@ -408,8 +409,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp

override def nullable: Boolean = false

override def nullIntolerant: Boolean = false

override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
val input2 = right.eval(input)
Expand Down Expand Up @@ -443,7 +442,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}


case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
case class LessThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

Expand All @@ -455,7 +455,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
}


case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
case class LessThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

Expand All @@ -467,7 +468,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
}


case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
case class GreaterThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

Expand All @@ -479,7 +481,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
}


case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
case class GreaterThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
// Currently we only propagate constraints if the condition consists of equality
// and ranges. For all other cases, we return an empty set of constraints
constraints.map(scanNullIntolerantExpr)
.foldLeft(Set.empty[Expression])(_ union _.toSet)
constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
}

private def scanNullIntolerantExpr(expr: Expression): Set[Expression] = expr match {
case a: Attribute => Set(IsNotNull(a))
case IsNotNull(e) =>
// IsNotNull is null tolerant, but we need to explore for the attributes not null.
scanNullIntolerantExpr(e)
case e: Expression if e.nullIntolerant => e.children.flatMap(scanNullIntolerantExpr).toSet
case _ => Set.empty[Expression]
/**
* Recursively explores the expressions which are null intolerant and returns all attributes
* in these expressions.
*/
private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
case a: Attribute => Seq(a)
case _: NullIntolerant | _: IsNotNull => expr.children.flatMap(scanNullIntolerantExpr)
case _ => Seq.empty[Attribute]
}

/**
Expand Down