-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-21417][SQL] Infer join conditions using propagated constraints #18692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
e0e6ad3
b69185c
3e090f9
0e5a9f2
9ab91a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -152,3 +152,99 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { | |
| if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A rule that eliminates CROSS joins by inferring join conditions from propagated constraints. | ||
| * | ||
| * The optimization is applicable only to CROSS joins. For other join types, adding inferred join | ||
| * conditions would potentially shuffle children as child node's partitioning won't satisfy the JOIN | ||
| * node's requirements which otherwise could have. | ||
| * | ||
| * For instance, if there is a CROSS join, where the left relation has 'a = 1' and the right | ||
| * relation has 'b = 1', the rule infers 'a = b' as a join predicate. | ||
|
||
| */ | ||
| object EliminateCrossJoin extends Rule[LogicalPlan] with PredicateHelper { | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = { | ||
| if (SQLConf.get.constraintPropagationEnabled) { | ||
| eliminateCrossJoin(plan) | ||
| } else { | ||
| plan | ||
| } | ||
| } | ||
|
|
||
| private def eliminateCrossJoin(plan: LogicalPlan): LogicalPlan = plan transform { | ||
| case join@Join(leftPlan, rightPlan, Cross, None) => | ||
|
||
| val leftConstraints = join.constraints.filter(_.references.subsetOf(leftPlan.outputSet)) | ||
| val rightConstraints = join.constraints.filter(_.references.subsetOf(rightPlan.outputSet)) | ||
| val inferredJoinPredicates = inferJoinPredicates(leftConstraints, rightConstraints) | ||
| val joinConditionOpt = inferredJoinPredicates.reduceOption(And) | ||
| if (joinConditionOpt.isDefined) Join(leftPlan, rightPlan, Inner, joinConditionOpt) else join | ||
| } | ||
|
|
||
| private def inferJoinPredicates( | ||
| leftConstraints: Set[Expression], | ||
| rightConstraints: Set[Expression]): Set[EqualTo] = { | ||
|
|
||
| // iterate through the left constraints and build a hash map that points semantically | ||
| // equivalent expressions into attributes | ||
| val emptyEquivalenceMap = Map.empty[SemanticExpression, Set[Attribute]] | ||
| val equivalenceMap = leftConstraints.foldLeft(emptyEquivalenceMap) { case (map, constraint) => | ||
| constraint match { | ||
| case EqualTo(attr: Attribute, expr: Expression) => | ||
| updateEquivalenceMap(map, attr, expr) | ||
| case EqualTo(expr: Expression, attr: Attribute) => | ||
| updateEquivalenceMap(map, attr, expr) | ||
| case _ => map | ||
| } | ||
| } | ||
|
|
||
| // iterate through the right constraints and infer join conditions using the equivalence map | ||
| rightConstraints.foldLeft(Set.empty[EqualTo]) { case (joinConditions, constraint) => | ||
| constraint match { | ||
| case EqualTo(attr: Attribute, expr: Expression) => | ||
| appendJoinConditions(attr, expr, equivalenceMap, joinConditions) | ||
| case EqualTo(expr: Expression, attr: Attribute) => | ||
| appendJoinConditions(attr, expr, equivalenceMap, joinConditions) | ||
| case _ => joinConditions | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private def updateEquivalenceMap( | ||
| equivalenceMap: Map[SemanticExpression, Set[Attribute]], | ||
| attr: Attribute, | ||
| expr: Expression): Map[SemanticExpression, Set[Attribute]] = { | ||
|
|
||
| val equivalentAttrs = equivalenceMap.getOrElse(expr, Set.empty[Attribute]) | ||
| if (equivalentAttrs.contains(attr)) { | ||
| equivalenceMap | ||
| } else { | ||
| equivalenceMap.updated(expr, equivalentAttrs + attr) | ||
| } | ||
| } | ||
|
|
||
| private def appendJoinConditions( | ||
| attr: Attribute, | ||
| expr: Expression, | ||
| equivalenceMap: Map[SemanticExpression, Set[Attribute]], | ||
| joinConditions: Set[EqualTo]): Set[EqualTo] = { | ||
|
|
||
| equivalenceMap.get(expr) match { | ||
| case Some(equivalentAttrs) => joinConditions ++ equivalentAttrs.map(EqualTo(attr, _)) | ||
| case None => joinConditions | ||
| } | ||
| } | ||
|
|
||
| // the purpose of this class is to treat 'a === 1 and 1 === 'a as the same expressions | ||
| implicit class SemanticExpression(private val expr: Expression) { | ||
|
||
|
|
||
| override def hashCode(): Int = expr.semanticHash() | ||
|
|
||
| override def equals(other: Any): Boolean = other match { | ||
| case other: SemanticExpression => expr.semanticEquals(other.expr) | ||
| case _ => false | ||
| } | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,210 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, Not} | ||
| import org.apache.spark.sql.catalyst.plans.{Cross, Inner, JoinType, PlanTest} | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} | ||
| import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
| import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED | ||
| import org.apache.spark.sql.types.IntegerType | ||
|
|
||
| class EliminateCrossJoinSuite extends PlanTest { | ||
|
|
||
| object Optimize extends RuleExecutor[LogicalPlan] { | ||
| val batches = | ||
| Batch("Eliminate cross joins", FixedPoint(10), | ||
| EliminateCrossJoin, | ||
| PushPredicateThroughJoin) :: Nil | ||
| } | ||
|
|
||
| val testRelation1 = LocalRelation('a.int, 'b.int) | ||
| val testRelation2 = LocalRelation('c.int, 'd.int) | ||
|
|
||
| test("successful elimination of cross joins (1)") { | ||
| checkJoinOptimization( | ||
| originalFilter = 'a === 1 && 'c === 1 && 'd === 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a === 1, | ||
| expectedRightRelationFilter = 'c === 1 && 'd === 1, | ||
| expectedJoinType = Inner, | ||
| expectedJoinCondition = Some('a === 'c && 'a === 'd)) | ||
| } | ||
|
|
||
| test("successful elimination of cross joins (2)") { | ||
| checkJoinOptimization( | ||
| originalFilter = 'a === 1 && 'b === 2 && 'd === 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a === 1 && 'b === 2, | ||
| expectedRightRelationFilter = 'd === 1, | ||
| expectedJoinType = Inner, | ||
| expectedJoinCondition = Some('a === 'd)) | ||
| } | ||
|
|
||
| test("successful elimination of cross joins (3)") { | ||
| // PushPredicateThroughJoin will push 'd === 'a into the join condition | ||
| // EliminateCrossJoin will NOT apply because the condition will be already present | ||
| // therefore, the join type will stay the same (i.e., CROSS) | ||
| checkJoinOptimization( | ||
| originalFilter = 'a === 1 && Literal(1) === 'd && 'd === 'a, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a === 1, | ||
| expectedRightRelationFilter = Literal(1) === 'd, | ||
| expectedJoinType = Cross, | ||
| expectedJoinCondition = Some('a === 'd)) | ||
| } | ||
|
|
||
| test("successful elimination of cross joins (4)") { | ||
| // Literal(1) * Literal(2) and Literal(2) * Literal(1) are semantically equal | ||
| checkJoinOptimization( | ||
| originalFilter = 'a === Literal(1) * Literal(2) && Literal(2) * Literal(1) === 'c, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a === Literal(1) * Literal(2), | ||
| expectedRightRelationFilter = Literal(2) * Literal(1) === 'c, | ||
| expectedJoinType = Inner, | ||
| expectedJoinCondition = Some('a === 'c)) | ||
| } | ||
|
|
||
| test("successful elimination of cross joins (5)") { | ||
| checkJoinOptimization( | ||
| originalFilter = 'a === 1 && Literal(1) === 'a && 'c === 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a, | ||
| expectedRightRelationFilter = 'c === 1, | ||
| expectedJoinType = Inner, | ||
| expectedJoinCondition = Some('a === 'c)) | ||
| } | ||
|
|
||
| test("successful elimination of cross joins (6)") { | ||
| checkJoinOptimization( | ||
| originalFilter = 'a === Cast("1", IntegerType) && 'c === Cast("1", IntegerType) && 'd === 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a === Cast("1", IntegerType), | ||
| expectedRightRelationFilter = 'c === Cast("1", IntegerType) && 'd === 1, | ||
| expectedJoinType = Inner, | ||
| expectedJoinCondition = Some('a === 'c)) | ||
| } | ||
|
|
||
| test("successful elimination of cross joins (7)") { | ||
| // The join condition appears due to PushPredicateThroughJoin | ||
| checkJoinOptimization( | ||
| originalFilter = (('a >= 1 && 'c === 1) || 'd === 10) && 'b === 10 && 'c === 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'b === 10, | ||
| expectedRightRelationFilter = 'c === 1, | ||
| expectedJoinType = Cross, | ||
| expectedJoinCondition = Some(('a >= 1 && 'c === 1) || 'd === 10)) | ||
| } | ||
|
|
||
| test("successful elimination of cross joins (8)") { | ||
| checkJoinOptimization( | ||
| originalFilter = 'a === 1 && 'c === 1 && Literal(1) === 'a && Literal(1) === 'c, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a === 1 && Literal(1) === 'a, | ||
| expectedRightRelationFilter = 'c === 1 && Literal(1) === 'c, | ||
| expectedJoinType = Inner, | ||
| expectedJoinCondition = Some('a === 'c)) | ||
| } | ||
|
|
||
| test("inability to detect join conditions when constant propagation is disabled") { | ||
| withSQLConf(CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { | ||
| checkJoinOptimization( | ||
| originalFilter = 'a === 1 && 'c === 1 && 'd === 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a === 1, | ||
| expectedRightRelationFilter = 'c === 1 && 'd === 1, | ||
| expectedJoinType = Cross, | ||
| expectedJoinCondition = None) | ||
| } | ||
| } | ||
|
|
||
| test("inability to detect join conditions (1)") { | ||
| checkJoinOptimization( | ||
| originalFilter = 'a >= 1 && 'c === 1 && 'd >= 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = 'a >= 1, | ||
| expectedRightRelationFilter = 'c === 1 && 'd >= 1, | ||
| expectedJoinType = Cross, | ||
| expectedJoinCondition = None) | ||
| } | ||
|
|
||
| test("inability to detect join conditions (2)") { | ||
| checkJoinOptimization( | ||
| originalFilter = Literal(1) === 'b && ('c === 1 || 'd === 1), | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = Literal(1) === 'b, | ||
| expectedRightRelationFilter = 'c === 1 || 'd === 1, | ||
| expectedJoinType = Cross, | ||
| expectedJoinCondition = None) | ||
| } | ||
|
|
||
| test("inability to detect join conditions (3)") { | ||
| checkJoinOptimization( | ||
| originalFilter = Literal(1) === 'b && 'c === 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = Some('c === 'b), | ||
| expectedLeftRelationFilter = Literal(1) === 'b, | ||
| expectedRightRelationFilter = 'c === 1, | ||
| expectedJoinType = Cross, | ||
| expectedJoinCondition = Some('c === 'b)) | ||
| } | ||
|
|
||
| test("inability to detect join conditions (4)") { | ||
| checkJoinOptimization( | ||
| originalFilter = Not('a === 1) && 'd === 1, | ||
| originalJoinType = Cross, | ||
| originalJoinCondition = None, | ||
| expectedLeftRelationFilter = Not('a === 1), | ||
| expectedRightRelationFilter = 'd === 1, | ||
| expectedJoinType = Cross, | ||
| expectedJoinCondition = None) | ||
| } | ||
|
|
||
| private def checkJoinOptimization( | ||
| originalFilter: Expression, | ||
| originalJoinType: JoinType, | ||
| originalJoinCondition: Option[Expression], | ||
| expectedLeftRelationFilter: Expression, | ||
| expectedRightRelationFilter: Expression, | ||
| expectedJoinType: JoinType, | ||
| expectedJoinCondition: Option[Expression]): Unit = { | ||
|
|
||
| val originalQuery = testRelation1 | ||
| .join(testRelation2, originalJoinType, originalJoinCondition) | ||
| .where(originalFilter) | ||
| val optimizedQuery = Optimize.execute(originalQuery.analyze) | ||
|
|
||
| val left = testRelation1.where(expectedLeftRelationFilter) | ||
| val right = testRelation2.where(expectedRightRelationFilter) | ||
| val expectedQuery = left.join(right, expectedJoinType, expectedJoinCondition).analyze | ||
| comparePlans(optimizedQuery, expectedQuery) | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we apply this optimization to all joins after #19054?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It sounds promising.