Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix Constant Folding Bugs & Add More Unittests
  • Loading branch information
chenghao-intel committed Apr 30, 2014
commit 27ea3d7364017b7f647b04f90f8f3050d236f308
Original file line number Diff line number Diff line change
Expand Up @@ -114,37 +114,37 @@ package object dsl {
def attr = analysis.UnresolvedAttribute(s)

/** Creates a new AttributeReference of type boolean */
def boolean = AttributeReference(s, BooleanType, nullable = false)()
def boolean = AttributeReference(s, BooleanType, nullable = true)()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the rationale for changing all of these?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think "def boolean" / "def string" should return an Attribute with nullability = true, that does not harm for the correctness, otherwise, it may bring a wrong hint for the optimization of rule NullPropagation.

For example:

val row = new GenericRow(Array[Any] ("a", null))
val c1 = 'a.string.at(0) 
val c2 = 'b.string.at(1) // nullable should be true, otherwise does't reflect the real situation.
assert(evaluate(IsNull(c2), row) == true)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.


/** Creates a new AttributeReference of type byte */
def byte = AttributeReference(s, ByteType, nullable = false)()
def byte = AttributeReference(s, ByteType, nullable = true)()

/** Creates a new AttributeReference of type short */
def short = AttributeReference(s, ShortType, nullable = false)()
def short = AttributeReference(s, ShortType, nullable = true)()

/** Creates a new AttributeReference of type int */
def int = AttributeReference(s, IntegerType, nullable = false)()
def int = AttributeReference(s, IntegerType, nullable = true)()

/** Creates a new AttributeReference of type long */
def long = AttributeReference(s, LongType, nullable = false)()
def long = AttributeReference(s, LongType, nullable = true)()

/** Creates a new AttributeReference of type float */
def float = AttributeReference(s, FloatType, nullable = false)()
def float = AttributeReference(s, FloatType, nullable = true)()

/** Creates a new AttributeReference of type double */
def double = AttributeReference(s, DoubleType, nullable = false)()
def double = AttributeReference(s, DoubleType, nullable = true)()

/** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = false)()
def string = AttributeReference(s, StringType, nullable = true)()

/** Creates a new AttributeReference of type decimal */
def decimal = AttributeReference(s, DecimalType, nullable = false)()
def decimal = AttributeReference(s, DecimalType, nullable = true)()

/** Creates a new AttributeReference of type timestamp */
def timestamp = AttributeReference(s, TimestampType, nullable = false)()
def timestamp = AttributeReference(s, TimestampType, nullable = true)()

/** Creates a new AttributeReference of type binary */
def binary = AttributeReference(s, BinaryType, nullable = false)()
def binary = AttributeReference(s, BinaryType, nullable = true)()
}

implicit class DslAttribute(a: AttributeReference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,6 @@ abstract class Expression extends TreeNode[Expression] {
}
}

/**
* Root class for rewritten 2 operands UDF expression. By default, we assume it produces Null if
* either one of its operands is null. Exceptional case requires to update the optimization rule
* at [[optimizer.ConstantFolding ConstantFolding]]
*/
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know I said this was maybe an okay idea, but looking at this closer I'm worried that the semantics here are too subtle, even with the added documentation. For example, I'm pretty sure we are introducing a bunch of bugs by making this existing class the marker for "null if children are null". This is because BinaryPredicate inherits from BinaryExpression but predicates are not always null if one of their children is null.

Instead maybe in the Rule should more explicitly name classes that can be null simplified. For example, I think it is safe to do it on BinaryArithmetic.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true, adding document restrictions for BinaryExpression / UnaryExpression is quite confusing and error-prone. Besides the BinaryArithmetic, also the rewrite UDFs RLike / Cast / BinaryComparison etc.will be considered, too.

Probably enumerate all of the expression types in this rule makes more sense.

self: Product =>

Expand All @@ -243,11 +238,6 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
self: Product =>
}

/**
* Root class for rewritten single operand UDF expression. By default, we assume it produces Null
* if its operand is null. Exceptional case requires to update the optimization rule
* at [[optimizer.ConstantFolding ConstantFolding]]
*/
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same concerns as above.

self: Product =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,33 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
override def toString = s"$child[$ordinal]"

override def eval(input: Row): Any = {
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = child.eval(input).asInstanceOf[Seq[_]]
val o = ordinal.eval(input).asInstanceOf[Int]
if (baseValue == null) {
null
} else if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
val value = child.eval(input)
if(value == null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Space after if, below too.

null
} else {
val baseValue = child.eval(input).asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
if (baseValue == null) {
if(key == null) {
null
} else {
baseValue.get(key).orNull
if (child.dataType.isInstanceOf[ArrayType]) {
val baseValue = value.asInstanceOf[Seq[_]]
val o = key.asInstanceOf[Int]
if (baseValue == null) {
null
} else if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
} else {
val baseValue = value.asInstanceOf[Map[Any, _]]
val key = ordinal.eval(input)
if (baseValue == null) {
null
} else {
baseValue.get(key).orNull
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ object NullPropagation extends Rule[LogicalPlan] {
case q: LogicalPlan => q transformExpressionsUp {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want Up? I believe this means we are going to call evaluate on each foldable node working up instead of just calling it once at the top.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expression.foldable is ok by traveling from top to bottom, while null propagation is opposite. I've put them into different rule objects (ConstantFolding & NullPropagation).

// Skip redundant folding of literals.
case l: Literal => l
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need to skip literals since none of the conditions below can ever match a raw literal.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking if put the literal matching in the beginning, maybe helpful avoid the further pattern matching of the rest rules. Just a tiny performance optimization for Literal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By that logic it would be an optimization to skip any class that won't match the cases below. Why is Literal a special case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same as the rule ConstantFolding, NullPropagation won't do any transformation for Literal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but in the case of ConstantFolding the subsequent pattern will match Literal, since a Literal is technically foldable. Matching the next pattern causes the rule to invoke the expression evaluator and create an identical, wasted object.

In NullPropogation, a Literal will not match any of the later rules. So in essence you are second guessing the code generated by the pattern matcher. While there may be extreme cases where that is required for performance, I don't think this is one of them.

case e @ Count(Literal(null, _)) => Literal(null, e.dataType)
case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
case e @ Sum(Literal(null, _)) => Literal(null, e.dataType)
case e @ Average(Literal(null, _)) => Literal(null, e.dataType)
case e @ IsNull(c @ Rand) => Literal(false, BooleanType)
case e @ IsNotNull(c @ Rand) => Literal(true, BooleanType)
case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
Expand All @@ -122,13 +122,32 @@ object NullPropagation extends Rule[LogicalPlan] {
case Literal(candidate, _) if(candidate == v) => true
case _ => false
})) => Literal(true, BooleanType)
// Put exceptional cases(Unary & Binary Expression if it doesn't produce null with constant
// null operand) before here.
case e: UnaryExpression => e.child match {
case e: UnaryMinus => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: BinaryExpression => e.children match {
case e: Cast => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: Not => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: And => e // leave it for BooleanSimplification
case e: Or => e // leave it for BooleanSimplification
// Put exceptional cases above
case e: BinaryArithmetic => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
case e: BinaryPredicate => e.children match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you match on BinaryComparison instead of BinaryPredicate you won't need to skip And and Or above, which seems a little clearer.

case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
}
case e: StringRegexExpression => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
case _ => e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class ExpressionEvaluationSuite extends FunSuite {

test("LIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType).like("a"), null)
checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null)
checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null)
checkEvaluation("abdef" like "abdef", true)
checkEvaluation("a_%b" like "a\\__b", true)
Expand Down Expand Up @@ -157,9 +158,14 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))

checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%")))
}

test("RLIKE literal Regular Expression") {
checkEvaluation(Literal(null, StringType) rlike "abdef", null)
checkEvaluation("abdef" rlike Literal(null, StringType), null)
checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null)
checkEvaluation("abdef" rlike "abdef", true)
checkEvaluation("abbbbc" rlike "a.*c", true)

Expand Down Expand Up @@ -244,17 +250,19 @@ class ExpressionEvaluationSuite extends FunSuite {

intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}

assert(("abcdef" cast StringType).nullable === false)
assert(("abcdef" cast BinaryType).nullable === false)
assert(("abcdef" cast BooleanType).nullable === false)
assert(("abcdef" cast TimestampType).nullable === true)
assert(("abcdef" cast LongType).nullable === true)
assert(("abcdef" cast IntegerType).nullable === true)
assert(("abcdef" cast ShortType).nullable === true)
assert(("abcdef" cast ByteType).nullable === true)
assert(("abcdef" cast DecimalType).nullable === true)
assert(("abcdef" cast DoubleType).nullable === true)
assert(("abcdef" cast FloatType).nullable === true)
checkEvaluation(("abcdef" cast StringType).nullable, false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure you want to change these. nullable returns a Boolean, not an Expression. This code is only compiling because the Boolean values are getting implicitly converted to a Literal.

checkEvaluation(("abcdef" cast BinaryType).nullable,false)
checkEvaluation(("abcdef" cast BooleanType).nullable, false)
checkEvaluation(("abcdef" cast TimestampType).nullable, true)
checkEvaluation(("abcdef" cast LongType).nullable, true)
checkEvaluation(("abcdef" cast IntegerType).nullable, true)
checkEvaluation(("abcdef" cast ShortType).nullable, true)
checkEvaluation(("abcdef" cast ByteType).nullable, true)
checkEvaluation(("abcdef" cast DecimalType).nullable, true)
checkEvaluation(("abcdef" cast DoubleType).nullable, true)
checkEvaluation(("abcdef" cast FloatType).nullable, true)

checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null)
}

test("timestamp") {
Expand Down Expand Up @@ -285,5 +293,108 @@ class ExpressionEvaluationSuite extends FunSuite {
// A test for higher precision than millis
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
}

test("null checking") {
val row = new GenericRow(Array[Any]("^Ba*n", null, true, null))
val c1 = 'a.string.at(0)
val c2 = 'a.string.at(1)
val c3 = 'a.boolean.at(2)
val c4 = 'a.boolean.at(3)

checkEvaluation(IsNull(c1), false, row)
checkEvaluation(IsNotNull(c1), true, row)

checkEvaluation(IsNull(c2), true, row)
checkEvaluation(IsNotNull(c2), false, row)

checkEvaluation(IsNull(Literal(1, ShortType)), false)
checkEvaluation(IsNotNull(Literal(1, ShortType)), true)

checkEvaluation(IsNull(Literal(null, ShortType)), true)
checkEvaluation(IsNotNull(Literal(null, ShortType)), false)

checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row)
checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)

checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row)
checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row)
checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row)
checkEvaluation(If(Literal(false, BooleanType),
Literal("a", StringType), Literal("b", StringType)), "b", row)

checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: Nil), true, row)
checkEvaluation(In(Literal("^Ba*n", StringType),
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
}

test("complex type") {
val row = new GenericRow(Array[Any](
"^Ba*n", // 0
null.asInstanceOf[String], // 1
new GenericRow(Array[Any]("aa", "bb")), // 2
Map("aa"->"bb"), // 3
Seq("aa", "bb") // 4
))

val typeS = StructType(
StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil
)
val typeMap = MapType(StringType, StringType)
val typeArray = ArrayType(StringType)

checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
Literal("aa")), "bb", row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row)
checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row)
checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()),
Literal(null, StringType)), null, row)

checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
Literal(1)), "bb", row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row)
checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row)
checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()),
Literal(null, IntegerType)), null, row)

checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row)
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
}

test("arithmetic") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(UnaryMinus(c1), -1, row)
checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100)

checkEvaluation(Add(c1, c4), null, row)
checkEvaluation(Add(c1, c2), 3, row)
checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(Add(Literal(null, IntegerType), c2), null, row)
checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
}

test("BinaryComparison") {
val row = new GenericRow(Array[Any](1, 2, 3, null))
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)

checkEvaluation(LessThan(c1, c4), null, row)
checkEvaluation(LessThan(c1, c2), true, row)
checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row)
checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row)
checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row)
}
}