Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,14 @@ class EquivalentExpressions {
}

/**
* Adds or removes only expressions which are common in each of given expressions, in a recursive
* way.
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common
* Adds or removes expressions only when the expressions:
* 1. are common in exprs.
* For example, given two exprs `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common
* expression `(c + 1)` will be added into `equivalenceMap`.
* 2. are common between alwaysEvaluateExprs and exprs.
* For example, given alwaysEvaluateExprs `(a + (b + (c + 1)))` and two exprs `e + f`
* and `(d + (e + (c + 1)))`, the common expression `(c + 1)`
* will be added into `equivalenceMap`.
*
* Note that as we don't know in advance if any child node of an expression will be common across
* all given expressions, we compute local equivalence maps for all given expressions and filter
Expand All @@ -97,6 +101,7 @@ class EquivalentExpressions {
* expressions.
*/
private def updateCommonExprs(
alwaysEvaluateExprs: Seq[Expression],
exprs: Seq[Expression],
map: mutable.HashMap[ExpressionEquals, ExpressionStats],
useCount: Int): Unit = {
Expand All @@ -112,6 +117,21 @@ class EquivalentExpressions {
}
}

if (alwaysEvaluateExprs.length > 0) {
val alwaysEvaluateEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
alwaysEvaluateExprs.foreach(updateExprTree(_, alwaysEvaluateEquivalenceMap))
// check if part of `exprs` have common expressions with `alwaysEvaluateExprs`.
exprs.filter(e => !alwaysEvaluateExprs.exists(_.semanticEquals(e))).foreach { expr =>
val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
updateExprTree(expr, otherLocalEquivalenceMap)
otherLocalEquivalenceMap.foreach { case (key, stats) =>
if (alwaysEvaluateEquivalenceMap.contains(key) && !localEquivalenceMap.contains(key)) {
updateExprTree(stats.expr, localEquivalenceMap)
}
}
}
}

// Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`.
// The remaining highest expression in `localEquivalenceMap` is also common expression so loop
// until `localEquivalenceMap` is not empty.
Expand All @@ -136,9 +156,10 @@ class EquivalentExpressions {

// For some special expressions we cannot just recurse into all of its children, but we can
// recursively add the common expressions shared between all of its children.
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
private def commonChildrenToRecurse(expr: Expression): Seq[(Seq[Expression], Seq[Expression])] =
expr match {
case _: CodegenFallback => Nil
case c: ConditionalExpression => c.branchGroups
case c: ConditionalExpression => c.branchGroups.map((c.alwaysEvaluatedInputs, _))
case _ => Nil
}

Expand Down Expand Up @@ -168,7 +189,11 @@ class EquivalentExpressions {
if (!skip && !updateExprInMap(expr, map, useCount)) {
val uc = useCount.signum
childrenToRecurse(expr).foreach(updateExprTree(_, map, uc))
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc))
commonChildrenToRecurse(expr)
.filter(_._2.nonEmpty)
.foreach { case (alwaysEvaluatedInputs, branchGroups) =>
updateCommonExprs(alwaysEvaluatedInputs, branchGroups, map, uc)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq condition))

// Repeated `add` is only in one branch, so we don't count it.
val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
val ifExpr2 = If(Literal(true), Add(Literal(1), Literal(3)), Add(add, add))
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(ifExpr2)

assert(equivalence2.getAllExprStates(1).isEmpty)
assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 3)
assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 1)

val ifExpr3 = If(condition, ifExpr1, ifExpr1)
val equivalence3 = new EquivalentExpressions
Expand Down Expand Up @@ -461,6 +461,104 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
assert(e2.getCommonSubexpressions.size == 1)
assert(e2.getCommonSubexpressions.head == add)
}

test("SubExpr elimination should work when `branchGroups` " +
  "has overlap with `alwaysEvaluatedInputs`: If") {
  val sum = Add(Literal(1), Literal(2))

  // Two shapes of `If` are exercised with identical expectations:
  //  1. `sum` is in the predicate and in one branch only;
  //  2. `sum` is in the predicate and in both branches.
  // In both shapes `sum` is expected to be reported as the single common subexpression.
  val ifExprs = Seq(
    If(IsNull(sum), Literal(null), KnownNotNull(sum)),
    If(IsNull(sum), sum, KnownNotNull(sum)))

  ifExprs.foreach { ifExpr =>
    val equivalence = new EquivalentExpressions
    equivalence.addExprTree(ifExpr)

    val commonExprs = equivalence.getCommonSubexpressions
    assert(commonExprs.size == 1)
    assert(commonExprs.head == sum)

    // Exactly one tracked expression has a use count of 2, and it is the very same
    // `sum` instance (reference equality, not just structural equality).
    val usedTwice = equivalence.getAllExprStates().filter(_.useCount == 2)
    assert(usedTwice.size == 1)
    assert(usedTwice.head.expr eq sum)
  }
}

test("SubExpr elimination should work when `branchGroups` " +
  "has overlap with `alwaysEvaluatedInputs`: CaseWhen") {
  val add1 = Add(Literal(1), Literal(2))
  val add2 = Add(Literal(2), Literal(3))
  // Only used as a neutral first condition in the third case below, so that neither
  // `add1` nor `add2` appears in the first condition of that `CaseWhen`.
  val add3 = Add(Literal(3), Literal(4))

  // `add2` exists in the first and third condition, it should be a common expression
  val conditions1 = (GreaterThan(add2, Literal(3)), add1) ::
    (GreaterThan(add1, Literal(4)), add1) ::
    (GreaterThan(add2, Literal(5)), add1) :: Nil
  val caseWhenExpr1 = CaseWhen(conditions1, None)
  val equivalence1 = new EquivalentExpressions
  equivalence1.addExprTree(caseWhenExpr1)
  assert(equivalence1.getCommonSubexpressions.size == 1)
  assert(equivalence1.getCommonSubexpressions.head == add2)
  assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1)
  assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2)

  // `add2` exists only in the first condition, it should NOT be a common expression
  val conditions2 = (GreaterThan(add2, Literal(3)), add1) ::
    (GreaterThan(add1, Literal(4)), add1) ::
    (GreaterThan(add1, Literal(5)), add1) :: Nil
  val caseWhenExpr2 = CaseWhen(conditions2, add1)
  val equivalence2 = new EquivalentExpressions
  equivalence2.addExprTree(caseWhenExpr2)
  assert(equivalence2.getCommonSubexpressions.size == 0)

  // `add2` exists only in the last two conditions, it should NOT be a common expression.
  // This case previously reused the exact branches of the case above, so the scenario
  // described here was never exercised. The first condition deliberately uses `add3`:
  // `add1` is already present in every value, and it must not also be counted from the
  // first condition.
  val conditions3 = (GreaterThan(add3, Literal(3)), add1) ::
    (GreaterThan(add2, Literal(4)), add1) ::
    (GreaterThan(add2, Literal(5)), add1) :: Nil
  val caseWhenExpr3 = CaseWhen(conditions3, add1)
  val equivalence3 = new EquivalentExpressions
  equivalence3.addExprTree(caseWhenExpr3)
  assert(equivalence3.getCommonSubexpressions.size == 0)

  // `add2` exists in the first condition and else value, it should be a common expression
  val conditions4 = (GreaterThan(add2, Literal(3)), add1) ::
    (GreaterThan(add1, Literal(4)), add1) ::
    (GreaterThan(add1, Literal(5)), add1) :: Nil
  val caseWhenExpr4 = CaseWhen(conditions4, Some(add2))
  val equivalence4 = new EquivalentExpressions
  equivalence4.addExprTree(caseWhenExpr4)
  assert(equivalence4.getCommonSubexpressions.size == 1)
  assert(equivalence4.getCommonSubexpressions.head == add2)
  assert(equivalence4.getAllExprStates().count(_.useCount == 2) == 1)
  assert(equivalence4.getAllExprStates().filter(_.useCount == 2).head.expr eq add2)
}

// Checks which subexpressions are reported as common when they are shared between
// some, but not all, of the children of a `Coalesce`.
test("SubExpr elimination should work when `branchGroups` " +
"has overlap with `alwaysEvaluatedInputs`: Coalesce") {
val add1 = Add(Literal(1), Literal(2))
val add2 = Add(Literal(2), Literal(3))

// `add2` is in the first and third children, it should be a common expression
val conditions1 = GreaterThan(add2, Literal(3)) ::
GreaterThan(add1, Literal(4)) ::
GreaterThan(add2, Literal(5)) :: Nil
val coalesceExpr1 = Coalesce(conditions1)
val equivalence1 = new EquivalentExpressions
equivalence1.addExprTree(coalesceExpr1)
assert(equivalence1.getCommonSubexpressions.size == 1)
assert(equivalence1.getCommonSubexpressions.head == add2)
// Exactly one tracked expression has a use count of 2, and it must be the very same
// `add2` instance (reference equality via `eq`).
assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1)
assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2)

// `add2` is only in the last two children (not in the first one), it should NOT be a
// common expression
val conditions2 = GreaterThan(add1, Literal(3)) ::
GreaterThan(add2, Literal(4)) ::
GreaterThan(add2, Literal(5)) :: Nil
val coalesceExpr2 = Coalesce(conditions2)
val equivalence2 = new EquivalentExpressions
equivalence2.addExprTree(coalesceExpr2)
assert(equivalence2.getCommonSubexpressions.size == 0)
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down