Skip to content

Commit f2e22ae

Browse files
viiryagatorsmile
authored andcommitted
[SPARK-21835][SQL] RewritePredicateSubquery should not produce unresolved query plans
## What changes were proposed in this pull request? Correlated predicate subqueries are rewritten into `Join` by the rule `RewritePredicateSubquery` during optimization. It is possibly that the two sides of the `Join` have conflicting attributes. The query plans produced by `RewritePredicateSubquery` become unresolved and break structural integrity. We should check if there are conflicting attributes in the `Join` and de-duplicate them by adding a `Project`. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh <[email protected]> Closes #19050 from viirya/SPARK-21835.
1 parent 64936c1 commit f2e22ae

File tree

2 files changed

+98
-4
lines changed

2 files changed

+98
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,33 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
4949
}
5050
}
5151

52+
private def dedupJoin(joinPlan: Join): Join = joinPlan match {
53+
// SPARK-21835: It is possibly that the two sides of the join have conflicting attributes,
54+
// the produced join then becomes unresolved and break structural integrity. We should
55+
// de-duplicate conflicting attributes. We don't use transformation here because we only
56+
// care about the most top join converted from correlated predicate subquery.
57+
case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti), joinCond) =>
58+
val duplicates = right.outputSet.intersect(left.outputSet)
59+
if (duplicates.nonEmpty) {
60+
val aliasMap = AttributeMap(duplicates.map { dup =>
61+
dup -> Alias(dup, dup.toString)()
62+
}.toSeq)
63+
val aliasedExpressions = right.output.map { ref =>
64+
aliasMap.getOrElse(ref, ref)
65+
}
66+
val newRight = Project(aliasedExpressions, right)
67+
val newJoinCond = joinCond.map { condExpr =>
68+
condExpr transform {
69+
case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
70+
}
71+
}
72+
Join(left, newRight, joinType, newJoinCond)
73+
} else {
74+
j
75+
}
76+
case _ => joinPlan
77+
}
78+
5279
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
5380
case Filter(condition, child) =>
5481
val (withSubquery, withoutSubquery) =
@@ -64,14 +91,17 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
6491
withSubquery.foldLeft(newFilter) {
6592
case (p, Exists(sub, conditions, _)) =>
6693
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
67-
Join(outerPlan, sub, LeftSemi, joinCond)
94+
// Deduplicate conflicting attributes if any.
95+
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
6896
case (p, Not(Exists(sub, conditions, _))) =>
6997
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
70-
Join(outerPlan, sub, LeftAnti, joinCond)
98+
// Deduplicate conflicting attributes if any.
99+
dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
71100
case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
72101
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
73102
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
74-
Join(outerPlan, sub, LeftSemi, joinCond)
103+
// Deduplicate conflicting attributes if any.
104+
dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
75105
case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
76106
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
77107
// Construct the condition. A NULL in one of the conditions is regarded as a positive
@@ -93,7 +123,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
93123
// will have the final conditions in the LEFT ANTI as
94124
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2)
95125
val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And)
96-
Join(outerPlan, sub, LeftAnti, Option(pairs))
126+
// Deduplicate conflicting attributes if any.
127+
dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs)))
97128
case (p, predicate) =>
98129
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
99130
Project(p.output, Filter(newCond.get, inputPlan))

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql
1919

20+
import org.apache.spark.sql.catalyst.plans.logical.Join
2021
import org.apache.spark.sql.test.SharedSQLContext
2122

2223
class SubquerySuite extends QueryTest with SharedSQLContext {
@@ -875,4 +876,66 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
875876
assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]"))
876877
}
877878
}
879+
880+
test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 1") {
881+
withTable("t1") {
882+
withTempPath { path =>
883+
Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
884+
sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}'")
885+
886+
val sqlText =
887+
"""
888+
|SELECT * FROM t1
889+
|WHERE
890+
|NOT EXISTS (SELECT * FROM t1)
891+
""".stripMargin
892+
val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
893+
val join = optimizedPlan.collectFirst { case j: Join => j }.get
894+
assert(join.duplicateResolved)
895+
assert(optimizedPlan.resolved)
896+
}
897+
}
898+
}
899+
900+
test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 2") {
901+
withTable("t1", "t2", "t3") {
902+
withTempPath { path =>
903+
val data = Seq((1, 1, 1), (2, 0, 2))
904+
905+
data.toDF("t1a", "t1b", "t1c").write.parquet(path.getCanonicalPath + "/t1")
906+
data.toDF("t2a", "t2b", "t2c").write.parquet(path.getCanonicalPath + "/t2")
907+
data.toDF("t3a", "t3b", "t3c").write.parquet(path.getCanonicalPath + "/t3")
908+
909+
sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}/t1'")
910+
sql(s"CREATE TABLE t2 USING parquet LOCATION '${path.toURI}/t2'")
911+
sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}/t3'")
912+
913+
val sqlText =
914+
s"""
915+
|SELECT *
916+
|FROM (SELECT *
917+
| FROM t2
918+
| WHERE t2c IN (SELECT t1c
919+
| FROM t1
920+
| WHERE t1a = t2a)
921+
| UNION
922+
| SELECT *
923+
| FROM t3
924+
| WHERE t3a IN (SELECT t2a
925+
| FROM t2
926+
| UNION ALL
927+
| SELECT t1a
928+
| FROM t1
929+
| WHERE t1b > 0)) t4
930+
|WHERE t4.t2b IN (SELECT Min(t3b)
931+
| FROM t3
932+
| WHERE t4.t2a = t3a)
933+
""".stripMargin
934+
val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
935+
val joinNodes = optimizedPlan.collect { case j: Join => j }
936+
joinNodes.foreach(j => assert(j.duplicateResolved))
937+
assert(optimizedPlan.resolved)
938+
}
939+
}
940+
}
878941
}

0 commit comments

Comments
 (0)