Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,33 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * SPARK-21835: It is possible that the two sides of a join produced from a correlated
 * predicate subquery expose conflicting attributes. Such a join becomes unresolved and
 * breaks the structural integrity of the plan, so the conflicting attributes on the right
 * side are de-duplicated here.
 *
 * Only the given (top-most) join is inspected; no tree transformation is applied to the
 * children because only the join converted from the predicate subquery is of interest.
 *
 * For a `LeftSemi` or `LeftAnti` join whose right output overlaps the left output:
 *  - every overlapping attribute is wrapped in a fresh `Alias` named after the attribute's
 *    string form,
 *  - the right child is wrapped in a `Project` emitting those aliases (other output
 *    attributes pass through unchanged), and
 *  - every attribute in the join condition that is one of the overlapping attributes is
 *    replaced by the attribute of its alias.
 *
 * NOTE(review): the condition rewrite replaces every occurrence of an overlapping attribute,
 * without distinguishing which side of the join the occurrence was meant to reference.
 *
 * @param joinPlan the join to check for conflicting attributes
 * @return a new join over the aliased right side when conflicts exist; otherwise the input
 *         join itself (also for any join type other than `LeftSemi` / `LeftAnti`)
 */
private def dedupJoin(joinPlan: Join): Join = joinPlan match {
  case Join(outer, subquery, semiOrAnti @ (LeftSemi | LeftAnti), condition) =>
    val conflicting = subquery.outputSet.intersect(outer.outputSet)
    if (conflicting.isEmpty) {
      // Nothing clashes between the two sides: hand back the join untouched.
      joinPlan
    } else {
      // Each conflicting attribute gets its own alias (and therefore a new attribute).
      val renames = AttributeMap(conflicting.map { attr =>
        attr -> Alias(attr, attr.toString)()
      }.toSeq)
      // Emit the aliases in place of the conflicting attributes on the right side.
      val projectList = subquery.output.map(attr => renames.getOrElse(attr, attr))
      // Point the join condition at the aliased attributes.
      val rewrittenCondition = condition.map(_.transform {
        case attr: Attribute => renames.getOrElse(attr, attr).toAttribute
      })
      Join(outer, Project(projectList, subquery), semiOrAnti, rewrittenCondition)
    }
  case _ => joinPlan
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Filter(condition, child) =>
val (withSubquery, withoutSubquery) =
Expand All @@ -64,14 +91,17 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftAnti, joinCond)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
Expand All @@ -93,7 +123,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// will have the final conditions in the LEFT ANTI as
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2)
val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And)
Join(outerPlan, sub, LeftAnti, Option(pairs))
// Deduplicate conflicting attributes if any.
dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs)))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
Expand Down
67 changes: 67 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SharedSQLContext

class SubquerySuite extends QueryTest with SharedSQLContext {
Expand Down Expand Up @@ -875,4 +876,70 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]"))
}
}

test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 1") {
withTable("t1") {
withTempPath { path =>
Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}'")

val sqlText =
"""
|SELECT * FROM t1
|WHERE
|NOT EXISTS (SELECT * FROM t1)
""".stripMargin
val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
val join = optimizedPlan.collect {
case j: Join => j
}.head.asInstanceOf[Join]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        val join = optimizedPlan.collectFirst { case j: Join => j }.get

assert(join.duplicateResolved)
assert(optimizedPlan.resolved)
}
}
}

test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 2") {
withTable("t1", "t2", "t3") {
withTempPath { path =>
val data = Seq((1, 1, 1), (2, 0, 2))

data.toDF("t1a", "t1b", "t1c").write.parquet(path.getCanonicalPath + "/t1")
data.toDF("t2a", "t2b", "t2c").write.parquet(path.getCanonicalPath + "/t2")
data.toDF("t3a", "t3b", "t3c").write.parquet(path.getCanonicalPath + "/t3")

sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}/t1'")
sql(s"CREATE TABLE t2 USING parquet LOCATION '${path.toURI}/t2'")
sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}/t3'")

val sqlText =
s"""
|SELECT *
|FROM (SELECT *
| FROM t2
| WHERE t2c IN (SELECT t1c
| FROM t1
| WHERE t1a = t2a)
| UNION
| SELECT *
| FROM t3
| WHERE t3a IN (SELECT t2a
| FROM t2
| UNION ALL
| SELECT t1a
| FROM t1
| WHERE t1b > 0)) t4
|WHERE t4.t2b IN (SELECT Min(t3b)
| FROM t3
| WHERE t4.t2a = t3a)
""".stripMargin
val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
val joinNodes = optimizedPlan.collect {
case j: Join => j
}.map(_.asInstanceOf[Join])
joinNodes.map(j => assert(j.duplicateResolved))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        val joinNodes = optimizedPlan.collect { case j: Join => j }
        joinNodes.foreach(j => assert(j.duplicateResolved))

assert(optimizedPlan.resolved)
}
}
}
}