-
Notifications
You must be signed in to change notification settings - Fork 29k
[WIP][Spark-SQL] Optimize the Constant Folding for Expression #482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
2645d4f
3c045c7
536c005
9cf0396
543ef9d
9ccefdb
b28e03a
27ea3d7
80f9f18
50444cc
29c8166
68b9fad
2f14b50
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -222,11 +222,6 @@ abstract class Expression extends TreeNode[Expression] { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Root class for rewritten 2 operands UDF expression. By default, we assume it produces Null if | ||
| * either one of its operands is null. Exceptional case requires to update the optimization rule | ||
| * at [[optimizer.ConstantFolding ConstantFolding]] | ||
| */ | ||
| abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { | ||
|
||
| self: Product => | ||
|
|
||
|
|
@@ -243,11 +238,6 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] | |
| self: Product => | ||
| } | ||
|
|
||
| /** | ||
| * Root class for rewritten single operand UDF expression. By default, we assume it produces Null | ||
| * if its operand is null. Exceptional case requires to update the optimization rule | ||
| * at [[optimizer.ConstantFolding ConstantFolding]] | ||
| */ | ||
| abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { | ||
|
||
| self: Product => | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,23 +41,33 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { | |
| override def toString = s"$child[$ordinal]" | ||
|
|
||
| override def eval(input: Row): Any = { | ||
| if (child.dataType.isInstanceOf[ArrayType]) { | ||
| val baseValue = child.eval(input).asInstanceOf[Seq[_]] | ||
| val o = ordinal.eval(input).asInstanceOf[Int] | ||
| if (baseValue == null) { | ||
| null | ||
| } else if (o >= baseValue.size || o < 0) { | ||
| null | ||
| } else { | ||
| baseValue(o) | ||
| } | ||
| val value = child.eval(input) | ||
| if(value == null) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Space after |
||
| null | ||
| } else { | ||
| val baseValue = child.eval(input).asInstanceOf[Map[Any, _]] | ||
| val key = ordinal.eval(input) | ||
| if (baseValue == null) { | ||
| if(key == null) { | ||
| null | ||
| } else { | ||
| baseValue.get(key).orNull | ||
| if (child.dataType.isInstanceOf[ArrayType]) { | ||
| val baseValue = value.asInstanceOf[Seq[_]] | ||
| val o = key.asInstanceOf[Int] | ||
| if (baseValue == null) { | ||
| null | ||
| } else if (o >= baseValue.size || o < 0) { | ||
| null | ||
| } else { | ||
| baseValue(o) | ||
| } | ||
| } else { | ||
| val baseValue = value.asInstanceOf[Map[Any, _]] | ||
| val key = ordinal.eval(input) | ||
| if (baseValue == null) { | ||
| null | ||
| } else { | ||
| baseValue.get(key).orNull | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -96,11 +96,11 @@ object NullPropagation extends Rule[LogicalPlan] { | |
| case q: LogicalPlan => q transformExpressionsUp { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| // Skip redundant folding of literals. | ||
| case l: Literal => l | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no need to skip literals since none of the conditions below can ever match a raw literal.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking if put the literal matching in the beginning, maybe helpful avoid the further pattern matching of the rest rules. Just a tiny performance optimization for Literal.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By that logic it would be an optimization to skip any class that won't match the cases below. Why is Literal a special case?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The same as the rule ConstantFolding, NullPropagation won't do any transformation for Literal.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but in the case of In |
||
| case e @ Count(Literal(null, _)) => Literal(null, e.dataType) | ||
| case e @ Count(Literal(null, _)) => Literal(0, e.dataType) | ||
| case e @ Sum(Literal(null, _)) => Literal(null, e.dataType) | ||
| case e @ Average(Literal(null, _)) => Literal(null, e.dataType) | ||
| case e @ IsNull(c @ Rand) => Literal(false, BooleanType) | ||
| case e @ IsNotNull(c @ Rand) => Literal(true, BooleanType) | ||
| case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType) | ||
| case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType) | ||
| case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) | ||
| case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) | ||
| case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) | ||
|
|
@@ -122,13 +122,32 @@ object NullPropagation extends Rule[LogicalPlan] { | |
| case Literal(candidate, _) if(candidate == v) => true | ||
| case _ => false | ||
| })) => Literal(true, BooleanType) | ||
| // Put exceptional cases(Unary & Binary Expression if it doesn't produce null with constant | ||
| // null operand) before here. | ||
| case e: UnaryExpression => e.child match { | ||
| case e: UnaryMinus => e.child match { | ||
| case Literal(null, _) => Literal(null, e.dataType) | ||
| case _ => e | ||
| } | ||
| case e: BinaryExpression => e.children match { | ||
| case e: Cast => e.child match { | ||
| case Literal(null, _) => Literal(null, e.dataType) | ||
| case _ => e | ||
| } | ||
| case e: Not => e.child match { | ||
| case Literal(null, _) => Literal(null, e.dataType) | ||
| case _ => e | ||
| } | ||
| case e: And => e // leave it for BooleanSimplification | ||
| case e: Or => e // leave it for BooleanSimplification | ||
| // Put exceptional cases above | ||
| case e: BinaryArithmetic => e.children match { | ||
| case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) | ||
| case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) | ||
| case _ => e | ||
| } | ||
| case e: BinaryPredicate => e.children match { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you match on |
||
| case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) | ||
| case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) | ||
| case _ => e | ||
| } | ||
| case e: StringRegexExpression => e.children match { | ||
| case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) | ||
| case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) | ||
| case _ => e | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -129,6 +129,7 @@ class ExpressionEvaluationSuite extends FunSuite { | |
|
|
||
| test("LIKE literal Regular Expression") { | ||
| checkEvaluation(Literal(null, StringType).like("a"), null) | ||
| checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) | ||
| checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null) | ||
| checkEvaluation("abdef" like "abdef", true) | ||
| checkEvaluation("a_%b" like "a\\__b", true) | ||
|
|
@@ -157,9 +158,14 @@ class ExpressionEvaluationSuite extends FunSuite { | |
| checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) | ||
| checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) | ||
| checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) | ||
|
|
||
| checkEvaluation(Literal(null, StringType) like regEx, null, new GenericRow(Array[Any]("bc%"))) | ||
| } | ||
|
|
||
| test("RLIKE literal Regular Expression") { | ||
| checkEvaluation(Literal(null, StringType) rlike "abdef", null) | ||
| checkEvaluation("abdef" rlike Literal(null, StringType), null) | ||
| checkEvaluation(Literal(null, StringType) rlike Literal(null, StringType), null) | ||
| checkEvaluation("abdef" rlike "abdef", true) | ||
| checkEvaluation("abbbbc" rlike "a.*c", true) | ||
|
|
||
|
|
@@ -244,17 +250,19 @@ class ExpressionEvaluationSuite extends FunSuite { | |
|
|
||
| intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} | ||
|
|
||
| assert(("abcdef" cast StringType).nullable === false) | ||
| assert(("abcdef" cast BinaryType).nullable === false) | ||
| assert(("abcdef" cast BooleanType).nullable === false) | ||
| assert(("abcdef" cast TimestampType).nullable === true) | ||
| assert(("abcdef" cast LongType).nullable === true) | ||
| assert(("abcdef" cast IntegerType).nullable === true) | ||
| assert(("abcdef" cast ShortType).nullable === true) | ||
| assert(("abcdef" cast ByteType).nullable === true) | ||
| assert(("abcdef" cast DecimalType).nullable === true) | ||
| assert(("abcdef" cast DoubleType).nullable === true) | ||
| assert(("abcdef" cast FloatType).nullable === true) | ||
| checkEvaluation(("abcdef" cast StringType).nullable, false) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure you want to change these. |
||
| checkEvaluation(("abcdef" cast BinaryType).nullable,false) | ||
| checkEvaluation(("abcdef" cast BooleanType).nullable, false) | ||
| checkEvaluation(("abcdef" cast TimestampType).nullable, true) | ||
| checkEvaluation(("abcdef" cast LongType).nullable, true) | ||
| checkEvaluation(("abcdef" cast IntegerType).nullable, true) | ||
| checkEvaluation(("abcdef" cast ShortType).nullable, true) | ||
| checkEvaluation(("abcdef" cast ByteType).nullable, true) | ||
| checkEvaluation(("abcdef" cast DecimalType).nullable, true) | ||
| checkEvaluation(("abcdef" cast DoubleType).nullable, true) | ||
| checkEvaluation(("abcdef" cast FloatType).nullable, true) | ||
|
|
||
| checkEvaluation(Cast(Literal(null, IntegerType), ShortType), null) | ||
| } | ||
|
|
||
| test("timestamp") { | ||
|
|
@@ -285,5 +293,108 @@ class ExpressionEvaluationSuite extends FunSuite { | |
| // A test for higher precision than millis | ||
| checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) | ||
| } | ||
|
|
||
| test("null checking") { | ||
| val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) | ||
| val c1 = 'a.string.at(0) | ||
| val c2 = 'a.string.at(1) | ||
| val c3 = 'a.boolean.at(2) | ||
| val c4 = 'a.boolean.at(3) | ||
|
|
||
| checkEvaluation(IsNull(c1), false, row) | ||
| checkEvaluation(IsNotNull(c1), true, row) | ||
|
|
||
| checkEvaluation(IsNull(c2), true, row) | ||
| checkEvaluation(IsNotNull(c2), false, row) | ||
|
|
||
| checkEvaluation(IsNull(Literal(1, ShortType)), false) | ||
| checkEvaluation(IsNotNull(Literal(1, ShortType)), true) | ||
|
|
||
| checkEvaluation(IsNull(Literal(null, ShortType)), true) | ||
| checkEvaluation(IsNotNull(Literal(null, ShortType)), false) | ||
|
|
||
| checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row) | ||
| checkEvaluation(Coalesce(Literal(null, StringType) :: Nil), null, row) | ||
| checkEvaluation(Coalesce(Literal(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row) | ||
|
|
||
| checkEvaluation(If(c3, Literal("a", StringType), Literal("b", StringType)), "a", row) | ||
| checkEvaluation(If(c3, c1, c2), "^Ba*n", row) | ||
| checkEvaluation(If(c4, c2, c1), "^Ba*n", row) | ||
| checkEvaluation(If(Literal(null, BooleanType), c2, c1), "^Ba*n", row) | ||
| checkEvaluation(If(Literal(true, BooleanType), c1, c2), "^Ba*n", row) | ||
| checkEvaluation(If(Literal(false, BooleanType), c2, c1), "^Ba*n", row) | ||
| checkEvaluation(If(Literal(false, BooleanType), | ||
| Literal("a", StringType), Literal("b", StringType)), "b", row) | ||
|
|
||
| checkEvaluation(In(c1, c1 :: c2 :: Nil), true, row) | ||
| checkEvaluation(In(Literal("^Ba*n", StringType), | ||
| Literal("^Ba*n", StringType) :: Nil), true, row) | ||
| checkEvaluation(In(Literal("^Ba*n", StringType), | ||
| Literal("^Ba*n", StringType) :: c2 :: Nil), true, row) | ||
| } | ||
|
|
||
| test("complex type") { | ||
| val row = new GenericRow(Array[Any]( | ||
| "^Ba*n", // 0 | ||
| null.asInstanceOf[String], // 1 | ||
| new GenericRow(Array[Any]("aa", "bb")), // 2 | ||
| Map("aa"->"bb"), // 3 | ||
| Seq("aa", "bb") // 4 | ||
| )) | ||
|
|
||
| val typeS = StructType( | ||
| StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil | ||
| ) | ||
| val typeMap = MapType(StringType, StringType) | ||
| val typeArray = ArrayType(StringType) | ||
|
|
||
| checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), | ||
| Literal("aa")), "bb", row) | ||
| checkEvaluation(GetItem(Literal(null, typeMap), Literal("aa")), null, row) | ||
| checkEvaluation(GetItem(Literal(null, typeMap), Literal(null, StringType)), null, row) | ||
| checkEvaluation(GetItem(BoundReference(3, AttributeReference("c", typeMap)()), | ||
| Literal(null, StringType)), null, row) | ||
|
|
||
| checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), | ||
| Literal(1)), "bb", row) | ||
| checkEvaluation(GetItem(Literal(null, typeArray), Literal(1)), null, row) | ||
| checkEvaluation(GetItem(Literal(null, typeArray), Literal(null, IntegerType)), null, row) | ||
| checkEvaluation(GetItem(BoundReference(4, AttributeReference("c", typeArray)()), | ||
| Literal(null, IntegerType)), null, row) | ||
|
|
||
| checkEvaluation(GetField(BoundReference(2, AttributeReference("c", typeS)()), "a"), "aa", row) | ||
| checkEvaluation(GetField(Literal(null, typeS), "a"), null, row) | ||
| } | ||
|
|
||
| test("arithmetic") { | ||
| val row = new GenericRow(Array[Any](1, 2, 3, null)) | ||
| val c1 = 'a.int.at(0) | ||
| val c2 = 'a.int.at(1) | ||
| val c3 = 'a.int.at(2) | ||
| val c4 = 'a.int.at(3) | ||
|
|
||
| checkEvaluation(UnaryMinus(c1), -1, row) | ||
| checkEvaluation(UnaryMinus(Literal(100, IntegerType)), -100) | ||
|
|
||
| checkEvaluation(Add(c1, c4), null, row) | ||
| checkEvaluation(Add(c1, c2), 3, row) | ||
| checkEvaluation(Add(c1, Literal(null, IntegerType)), null, row) | ||
| checkEvaluation(Add(Literal(null, IntegerType), c2), null, row) | ||
| checkEvaluation(Add(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) | ||
| } | ||
|
|
||
| test("BinaryComparison") { | ||
| val row = new GenericRow(Array[Any](1, 2, 3, null)) | ||
| val c1 = 'a.int.at(0) | ||
| val c2 = 'a.int.at(1) | ||
| val c3 = 'a.int.at(2) | ||
| val c4 = 'a.int.at(3) | ||
|
|
||
| checkEvaluation(LessThan(c1, c4), null, row) | ||
| checkEvaluation(LessThan(c1, c2), true, row) | ||
| checkEvaluation(LessThan(c1, Literal(null, IntegerType)), null, row) | ||
| checkEvaluation(LessThan(Literal(null, IntegerType), c2), null, row) | ||
| checkEvaluation(LessThan(Literal(null, IntegerType), Literal(null, IntegerType)), null, row) | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the rationale for changing all of these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think "def boolean" / "def string" should return an Attribute with nullability = true, that does not harm for the correctness, otherwise, it may bring a wrong hint for the optimization of rule NullPropagation.
For example:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point.