Skip to content

Commit 6ac57fd

Browse files
aokolnychyi authored and gatorsmile committed
[SPARK-21417][SQL] Infer join conditions using propagated constraints
## What changes were proposed in this pull request?

This PR adds an optimization rule that infers join conditions using propagated constraints. For instance, if there is a join, where the left relation has 'a = 1' and the right relation has 'b = 1', then the rule infers 'a = b' as a join predicate. Only semantically new predicates are appended to the existing join condition. Refer to the corresponding ticket and tests for more details.

## How was this patch tested?

This patch comes with a new test suite to cover the implemented logic.

Author: aokolnychyi <[email protected]>

Closes #18692 from aokolnychyi/spark-21417.
1 parent 999ec13 commit 6ac57fd

File tree

6 files changed

+423
-0
lines changed

6 files changed

+423
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import scala.collection.mutable
21+
22+
import org.apache.spark.sql.catalyst.expressions.EquivalentExpressionMap.SemanticallyEqualExpr
23+
24+
/**
 * A map from an expression to the set of expressions recorded as equivalent to it. Keys are
 * compared by their semantic meaning, so cosmetic differences between two expressions are
 * ignored. Values are represented as [[ExpressionSet]]s.
 *
 * The underlying representation of keys depends on the [[Expression.semanticHash]] and
 * [[Expression.semanticEquals]] methods.
 *
 * {{{
 *   val map = new EquivalentExpressionMap()
 *
 *   map.put(1 + 2, a)
 *   map.put(rand(), b)
 *
 *   map.get(2 + 1) => Set(a) // 2 + 1 is semantically equivalent to the stored key 1 + 2
 *   map.get(1 + 2) => Set(a) // the stored key itself
 *   map.get(rand()) => Set() // non-deterministic expressions are not equivalent
 * }}}
 */
class EquivalentExpressionMap {

  private val equivalenceMap = mutable.HashMap.empty[SemanticallyEqualExpr, ExpressionSet]

  /**
   * Records `equivalentExpression` as being equivalent to `expression`.
   *
   * Non-deterministic keys are ignored: as shown in the class-level example, such a key is never
   * considered equivalent to anything (not even to itself), so an entry stored under it could
   * never be looked up again and would only make the map grow.
   *
   * @param expression the key; compared semantically
   * @param equivalentExpression the expression to add to the key's set of equivalents
   */
  def put(expression: Expression, equivalentExpression: Expression): Unit = {
    if (expression.deterministic) {
      val knownEquivalents = equivalenceMap.getOrElse(expression, ExpressionSet.empty)
      equivalenceMap(expression) = knownEquivalents + equivalentExpression
    }
  }

  /**
   * Returns every expression recorded as equivalent to `expression`, or an empty set when
   * nothing has been recorded under a semantically equal key.
   */
  def get(expression: Expression): Set[Expression] =
    equivalenceMap.getOrElse(expression, ExpressionSet.empty)
}
55+
56+
object EquivalentExpressionMap {

  /**
   * Wraps an [[Expression]] so that it can serve as a hash-map key whose identity is defined by
   * [[Expression.semanticHash]] and [[Expression.semanticEquals]] instead of the expression's
   * own `hashCode` and `equals`.
   */
  private implicit class SemanticallyEqualExpr(val expr: Expression) {

    override def hashCode: Int = expr.semanticHash()

    // Only another wrapper can be equal, and only if the wrapped expressions agree semantically.
    override def equals(obj: Any): Boolean = obj match {
      case that: SemanticallyEqualExpr => expr.semanticEquals(that.expr)
      case _ => false
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ object ExpressionSet {
2727
expressions.foreach(set.add)
2828
set
2929
}
30+
31+
val empty: ExpressionSet = ExpressionSet(Nil)
3032
}
3133

3234
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
8787
PushProjectionThroughUnion,
8888
ReorderJoin,
8989
EliminateOuterJoin,
90+
EliminateCrossJoin,
9091
InferFiltersFromConstraints,
9192
BooleanSimplification,
9293
PushPredicateThroughJoin,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import scala.annotation.tailrec
21+
import scala.collection.mutable
2122

2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
@@ -152,3 +153,62 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
152153
if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
153154
}
154155
}
156+
157+
/**
 * Turns CROSS joins into INNER joins whenever a join condition can be inferred from the
 * constraints propagated to the join node.
 *
 * Only CROSS joins without a condition are rewritten. For other join types, adding inferred join
 * conditions would potentially shuffle children, because a child's partitioning that currently
 * satisfies the JOIN node's requirements might no longer do so.
 *
 * Example: if the left child carries the constraint 'a = 1' and the right child carries the
 * constraint 'b = 1', the predicate 'a = b' is inferred and the CROSS join becomes an Inner join
 * on that predicate.
 */
object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper {

  /** Applies the rewrite only when constraint propagation is enabled in the active conf. */
  def apply(plan: LogicalPlan): LogicalPlan =
    if (!SQLConf.get.constraintPropagationEnabled) plan else eliminateCrossJoin(plan)

  /** Replaces every condition-less CROSS join for which at least one predicate is inferred. */
  private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform {
    case crossJoin @ Join(left, right, Cross, None) =>
      val constraints = crossJoin.constraints
      // Split the join's constraints into those that mention only one side's output.
      val fromLeft = constraints.filter(_.references.subsetOf(left.outputSet))
      val fromRight = constraints.filter(_.references.subsetOf(right.outputSet))
      inferJoinPredicates(fromLeft, fromRight).reduceOption(And) match {
        case Some(condition) => Join(left, right, Inner, Some(condition))
        case None => crossJoin
      }
  }

  /**
   * Builds `rightAttr = leftAttr` predicates for every pair of attributes that the two
   * constraint sets equate to a semantically equal expression.
   */
  private def inferJoinPredicates(
      leftConstraints: Set[Expression],
      rightConstraints: Set[Expression]): mutable.Set[EqualTo] = {

    // Index the left side: expression -> left attributes constrained to be equal to it.
    val leftAttrsByExpr = new EquivalentExpressionMap()
    attributeEqualities(leftConstraints).foreach { case (attr, expr) =>
      leftAttrsByExpr.put(expr, attr)
    }

    // Probe the index with the right side and pair up the matching attributes.
    val inferred = mutable.Set.empty[EqualTo]
    attributeEqualities(rightConstraints).foreach { case (attr, expr) =>
      inferred ++= leftAttrsByExpr.get(expr).map(EqualTo(attr, _))
    }
    inferred
  }

  /**
   * Extracts `(attribute, otherSide)` pairs from the equality constraints that have an attribute
   * on at least one side, preferring the left-hand attribute when both sides are attributes.
   * An iterator is used so that constraints are visited one by one in the set's own order.
   */
  private def attributeEqualities(
      constraints: Set[Expression]): Iterator[(Attribute, Expression)] = {
    constraints.iterator.collect {
      case EqualTo(attr: Attribute, expr: Expression) => (attr, expr)
      case EqualTo(expr: Expression, attr: Attribute) => (attr, expr)
    }
  }
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.dsl.expressions._
22+
23+
/** Exercises the semantic-key behaviour of [[EquivalentExpressionMap]]. */
class EquivalentExpressionMapSuite extends SparkFunSuite {

  // Two spellings of the same sum, plus a non-deterministic expression.
  private val onePlusTwo = Literal(1) + Literal(2)
  private val twoPlusOne = Literal(2) + Literal(1)
  private val rand = Rand(10)

  test("behaviour of the equivalent expression map") {
    val exprMap = new EquivalentExpressionMap()
    exprMap.put(onePlusTwo, 'a)
    exprMap.put(Literal(1) + Literal(3), 'b)
    exprMap.put(rand, 'c)

    // a lookup by 2 + 1 must find what was stored under 1 + 2
    assertResult(ExpressionSet(Seq('a)))(exprMap.get(twoPlusOne))
    // a non-deterministic key is never found, even though something was put under it
    assertResult(ExpressionSet.empty)(exprMap.get(rand))

    // repeating the same (key, value) pair, under either spelling, must not create duplicates
    exprMap.put(onePlusTwo, 'a)
    exprMap.put(twoPlusOne, 'a)
    assertResult(ExpressionSet(Seq('a)))(exprMap.get(twoPlusOne))

    // a second distinct value under the same key is accumulated
    exprMap.put(onePlusTwo, 'e)
    assertResult(ExpressionSet(Seq('a, 'e)))(exprMap.get(onePlusTwo))
    assertResult(2)(exprMap.get(onePlusTwo).size)

    // putting under the non-deterministic key again still yields nothing on lookup
    exprMap.put(rand, 'd)
    assertResult(ExpressionSet.empty)(exprMap.get(rand))
    assertResult(0)(exprMap.get(rand).size)
  }

}

0 commit comments

Comments
 (0)