[SPARK-41162][SQL] Fix anti- and semi-join for self-join with aggregations

Rule `PushDownLeftSemiAntiJoin` should not push an anti-join below an `Aggregate` when the join condition references an attribute that exists both in the join's right plan and in the child of its left plan. This usually happens when the anti-join / semi-join is a self-join and `DeduplicateRelations` cannot deduplicate those attributes (here, because of the projection of `value` to `id`).

This safeguard already exists for `Project` and `Union`, but `Aggregate` lacks it.

Without this change, the optimizer creates an incorrect plan.

The following example fails with `distinct()` (an aggregation) and succeeds without it, even though the two queries are semantically identical:
```scala
import spark.implicits._  // assumed to be in scope, e.g. in spark-shell

val ids = Seq(1, 2, 3).toDF("id").distinct()
val result = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_anti").collect()
assert(result.length == 1)
```
With `distinct()`, rule `PushDownLeftSemiAntiJoin` creates a join condition `(value#907 + 1) = value#907`, which can never be true. This effectively removes the anti-join.
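Concretely (a sketch written against the example above, with the before-the-fix behaviour inferred from the plan below): the left leg holds ids `{2, 3, 4}` after the `+ 1` projection, so the anti-join against `{1, 2, 3}` must keep only id `4`. With the anti-join optimized away, all three left rows survive:
```scala
// expected: {2, 3, 4} anti-joined with {1, 2, 3} keeps only id 4
val rows = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_anti").collect()
assert(rows.map(_.getInt(0)).toSet == Set(4))  // before this PR: returns {2, 3, 4}
```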

**Before this PR:**
The anti-join is fully removed from the plan.
```
== Physical Plan ==
AdaptiveSparkPlan (16)
+- == Final Plan ==
   LocalTableScan (1)

(16) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

This is caused by `PushDownLeftSemiAntiJoin` adding the join condition `(value#907 + 1) = value#907`, which is wrong because `id#910` in `(id#910 + 1) AS id#912` exists both in the right child of the join and in the left grandchild:
```
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin ===
!Join LeftAnti, (id#912 = id#910)                  Aggregate [id#910], [(id#910 + 1) AS id#912]
!:- Aggregate [id#910], [(id#910 + 1) AS id#912]   +- Project [value#907 AS id#910]
!:  +- Project [value#907 AS id#910]                  +- Join LeftAnti, ((value#907 + 1) = value#907)
!:     +- LocalRelation [value#907]                      :- LocalRelation [value#907]
!+- Aggregate [id#910], [id#910]                         +- Aggregate [id#910], [id#910]
!   +- Project [value#914 AS id#910]                        +- Project [value#914 AS id#910]
!      +- LocalRelation [value#914]                            +- LocalRelation [value#914]
```

The right child of the join and the left grandchild become the children of the pushed-down join, which produces the invalid join condition above.

**After this PR:**
The attribute `id#910` referenced by `(id#910 + 1) AS id#912` is recognized as ambiguous, because both legs of the prospective pushed-down join contain `id#910`. Hence, the join is not pushed down, and the rule no longer applies.
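For context, the guard referenced in the diff below can be sketched as follows (a paraphrase for illustration, not the verbatim Spark source): collect the output attributes of the plans the join would be pushed into, and veto the pushdown if the join condition references an attribute that is also produced by the right leg.
```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Sketch of the ambiguity check: the pushdown is unsafe when the join
// condition references an attribute that exists both in `plans` (the children
// the join would be pushed into) and in the right leg of the join.
def canPushThroughCondition(
    plans: Seq[LogicalPlan],
    condition: Option[Expression],
    rightOp: LogicalPlan): Boolean = {
  val attributes = AttributeSet(plans.flatMap(_.output))
  condition.forall { cond =>
    cond.references.intersect(rightOp.outputSet).intersect(attributes).isEmpty
  }
}
```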

The final plan contains the anti-join:
```
== Physical Plan ==
AdaptiveSparkPlan (24)
+- == Final Plan ==
   * BroadcastHashJoin LeftSemi BuildRight (14)
   :- * HashAggregate (7)
   :  +- AQEShuffleRead (6)
   :     +- ShuffleQueryStage (5), Statistics(sizeInBytes=48.0 B, rowCount=3)
   :        +- Exchange (4)
   :           +- * HashAggregate (3)
   :              +- * Project (2)
   :                 +- * LocalTableScan (1)
   +- BroadcastQueryStage (13), Statistics(sizeInBytes=1024.0 KiB, rowCount=3)
      +- BroadcastExchange (12)
         +- * HashAggregate (11)
            +- AQEShuffleRead (10)
               +- ShuffleQueryStage (9), Statistics(sizeInBytes=48.0 B, rowCount=3)
                  +- ReusedExchange (8)

(8) ReusedExchange [Reuses operator id: 4]
Output [1]: [id#898]

(24) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```
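For completeness, the semi-join variant of the same query (exercised by the new `DataFrameJoinSuite` test below) keeps the two rows that do match:
```scala
// {2, 3, 4} semi-joined with {1, 2, 3} keeps ids 2 and 3
val semi = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_semi").collect()
assert(semi.length == 2)
```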

This fixes a correctness issue.

Covered by unit tests in `DataFrameJoinSuite` and `LeftSemiAntiJoinPushDownSuite`.

Closes #39131 from EnricoMi/branch-antijoin-selfjoin-fix.

Authored-by: Enrico Minack <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
EnricoMi committed Jan 5, 2023
commit b8b22a7d1766a6ba984897091214d75b6ea834bc
The change to rule `PushDownLeftSemiAntiJoin`:
```diff
@@ -56,9 +56,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
       }
 
     // LeftSemi/LeftAnti over Aggregate, only push down if join can be planned as broadcast join.
-    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
+    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _)
         if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
           !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
+          canPushThroughCondition(agg.children, joinCond, rightOp) &&
           canPlanAsBroadcastHashJoin(join, conf) =>
       val aliasMap = getAliasMap(agg)
       val canPushDownPredicate = (predicate: Expression) => {
@@ -105,11 +106,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
   }
 
   /**
-   * Check if we can safely push a join through a project or union by making sure that attributes
-   * referred in join condition do not contain the same attributes as the plan they are moved
-   * into. This can happen when both sides of join refers to the same source (self join). This
-   * function makes sure that the join condition refers to attributes that are not ambiguous (i.e
-   * present in both the legs of the join) or else the resultant plan will be invalid.
+   * Check if we can safely push a join through a project, aggregate, or union by making sure that
+   * attributes referred in join condition do not contain the same attributes as the plan they are
+   * moved into. This can happen when both sides of join refers to the same source (self join).
+   * This function makes sure that the join condition refers to attributes that are not ambiguous
+   * (i.e present in both the legs of the join) or else the resultant plan will be invalid.
    */
   private def canPushThroughCondition(
       plans: Seq[LogicalPlan],
```
In `LeftSemiPushdownSuite` (renamed to `LeftSemiAntiJoinPushDownSuite`), test names are made consistent and a regression test is added:
```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType
 
-class LeftSemiPushdownSuite extends PlanTest {
+class LeftSemiAntiJoinPushDownSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest {
   val testRelation1 = LocalRelation('d.int)
   val testRelation2 = LocalRelation('e.int)
 
-  test("Project: LeftSemiAnti join pushdown") {
+  test("Project: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
+  test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") {
     val originalQuery = testRelation
       .select(Rand(1), 'b, 'c)
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Project: LeftSemiAnti join non correlated scalar subq") {
+  test("Project: LeftSemi join pushdown - non-correlated scalar subq") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .select(subq.as("sum"))
@@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+  test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") {
     val testRelation2 = LocalRelation('e.int, 'f.int)
     val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
     val subqExpr = ScalarSubquery(subqPlan)
@@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Aggregate: LeftSemiAnti join pushdown") {
+  test("Aggregate: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
+  test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") {
     val originalQuery = testRelation
       .groupBy('b)('b, Rand(10).as('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("LeftSemiAnti join over aggregate - no pushdown") {
+  test("Aggregate: LeftSemi join no pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c).as('sum))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
@@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+  test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .groupBy('a) ('a, subq.as("sum"))
@@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("LeftSemiAnti join over Window") {
+  test("Window: LeftSemi join pushdown") {
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
 
     val originalQuery = testRelation
@@ -184,7 +184,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Window: LeftSemi partial pushdown") {
+  test("Window: LeftSemi join partial pushdown") {
     // Attributes from join condition which does not refer to the window partition spec
     // are kept up in the plan as a Filter operator above Window.
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
@@ -224,7 +224,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Union: LeftSemiAnti join pushdown") {
+  test("Union: LeftSemi join pushdown") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
 
     val originalQuery = Union(Seq(testRelation, testRelation2))
@@ -240,7 +240,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Union: LeftSemiAnti join pushdown in self join scenario") {
+  test("Union: LeftSemi join pushdown in self join scenario") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
     val attrX = testRelation2.output.head
 
@@ -259,7 +259,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftSemiAnti join pushdown") {
+  test("Unary: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -274,7 +274,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+  test("Unary: LeftSemi join pushdown - empty join condition") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -289,7 +289,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftSemi join pushdown - partial pushdown") {
+  test("Unary: LeftSemi join partial pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -305,7 +305,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftAnti join pushdown - no pushdown") {
+  test("Unary: LeftAnti join no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -315,7 +315,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+  test("Unary: LeftSemi join - no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -325,7 +325,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Unary: LeftSemi join push down through Expand") {
+  test("Unary: LeftSemi join pushdown through Expand") {
     val expand = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)),
       Seq('a, 'b, 'c), testRelation)
     val originalQuery = expand
@@ -431,6 +431,25 @@ class LeftSemiPushdownSuite extends PlanTest {
     }
   }
 
+  Seq(LeftSemi, LeftAnti).foreach { case jt =>
+    test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
+      val aggregation = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)('id, sum('c).as("sum"))
+
+      // reference "b" exists in left leg, and the children of the right leg of the join
+      val originalQuery = aggregation.select(('id + 1).as("id_plus_1"), 'sum)
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)(('id + 1).as("id_plus_1"), sum('c).as("sum"))
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+        .analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
+
   Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
     Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
       test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
```
In `DataFrameJoinSuite`, an end-to-end regression test:
```diff
@@ -288,6 +288,24 @@ class DataFrameJoinSuite extends QueryTest
     }
   }
 
+  Seq("left_semi", "left_anti").foreach { joinType =>
+    test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
+      // aggregated dataframe
+      val ids = Seq(1, 2, 3).toDF("id").distinct()
+
+      // self-joined via joinType
+      val result = ids.withColumn("id", $"id" + 1)
+        .join(ids, usingColumns = Seq("id"), joinType = joinType).collect()
+
+      val expected = joinType match {
+        case "left_semi" => 2
+        case "left_anti" => 1
+        case _ => -1 // unsupported test type, test will always fail
+      }
+      assert(result.length == expected)
+    }
+  }
+
   def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
     case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
     case Filter(_, child) => extractLeftDeepInnerJoins(child)
```