[SPARK-21835][SQL] RewritePredicateSubquery should not produce unresolved query plans #19050
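Context for the change (summary, not text from the PR): `RewritePredicateSubquery` turns `EXISTS`/`IN` predicates into left semi and left anti joins. When the subquery scans the same relation as the outer query, both sides of the generated join expose the same attribute IDs, so the `Join` fails its `duplicateResolved` check and the optimized plan is left unresolved. A minimal way to observe the symptom, assuming an active `SparkSession` named `spark` and a table `t1` like the one created in the new tests:

```scala
import org.apache.spark.sql.catalyst.plans.logical.Join

// NOT EXISTS over the same table is rewritten into a LeftAnti self-join.
val df = spark.sql("SELECT * FROM t1 WHERE NOT EXISTS (SELECT * FROM t1)")

val optimized = df.queryExecution.optimizedPlan
val join = optimized.collect { case j: Join => j }.head

// Before this patch the self-join kept identical attribute IDs on both sides,
// so these checks came out false; with the dedupJoin step below they hold.
assert(join.duplicateResolved)
assert(optimized.resolved)
```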
Changes from 1 commit
```diff
@@ -49,6 +49,30 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  def dedupJoin(plan: LogicalPlan): LogicalPlan = {
+    plan transform {
+      case j @ Join(left, right, joinType, joinCond) =>
+        val duplicates = right.outputSet.intersect(left.outputSet)
+        if (duplicates.nonEmpty) {
+          val aliasMap = AttributeMap(duplicates.map { dup =>
+            dup -> Alias(dup, dup.toString)()
+          }.toSeq)
+          val aliasedExpressions = right.output.map { ref =>
+            aliasMap.getOrElse(ref, ref)
+          }
+          val newRight = Project(aliasedExpressions, right)
+          val newJoinCond = joinCond.map { condExpr =>
+            condExpr transform {
+              case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
+            }
+          }
+          Join(left, newRight, joinType, newJoinCond)
+        } else {
+          j
+        }
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Filter(condition, child) =>
       val (withSubquery, withoutSubquery) =
@@ -61,7 +85,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       }
 
       // Filter the plan by applying left semi and left anti joins.
-      withSubquery.foldLeft(newFilter) {
+      val rewritten = withSubquery.foldLeft(newFilter) {
         case (p, Exists(sub, conditions, _)) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
           Join(outerPlan, sub, LeftSemi, joinCond)
@@ -98,6 +122,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
           Project(p.output, Filter(newCond.get, inputPlan))
       }
+      dedupJoin(rewritten)
   }
 
   /**
```
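To make the aliasing step concrete, here is a small catalyst-level sketch (my own illustration using catalyst's test DSL, not code from this PR) of why `dedupJoin` restores `duplicateResolved`: a self-join shares attribute IDs on both sides, and wrapping the right side in a `Project` of fresh `Alias`es, which is what the new method does, removes the overlap.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.LeftAnti
import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, Project}

val i = 'i.int          // AttributeReference "i"
val j = 'j.string       // AttributeReference "j"
val t1 = LocalRelation(i, j)

// The shape produced by rewriting NOT EXISTS (SELECT * FROM t1) over t1:
// both sides are the same relation, so their output attribute IDs collide.
val selfJoin = Join(t1, t1, LeftAnti, None)
assert(!selfJoin.duplicateResolved)

// What dedupJoin does: re-alias every conflicting attribute behind a Project,
// which assigns fresh expression IDs to the right side.
val dedupedRight = Project(Seq(Alias(i, i.toString)(), Alias(j, j.toString)()), t1)
val fixedJoin = Join(t1, dedupedRight, LeftAnti, None)
assert(fixedJoin.duplicateResolved)
```

The `AttributeMap` in `dedupJoin` additionally rewrites the join condition so it keeps pointing at the re-aliased right-side attributes. The new regression tests below assert both `duplicateResolved` on every rewritten join and `resolved` on the whole optimized plan.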
```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.test.SharedSQLContext
 
 class SubquerySuite extends QueryTest with SharedSQLContext {
@@ -875,4 +876,71 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]"))
     }
   }
+
+  test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 1") {
+    withTable("t1") {
+      withTempPath { path =>
+        Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
+        sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}'")
+
+        val sqlText =
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t1)
+          """.stripMargin
+        val ds = sql(sqlText)
+
+        val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
+        val join = optimizedPlan.collect {
+          case j: Join => j
+        }.head.asInstanceOf[Join]
+
+        assert(join.duplicateResolved)
+        assert(optimizedPlan.resolved)
+      }
+    }
+  }
+
+  test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 2") {
+    withTable("t1", "t2", "t3") {
+      withTempPath { path =>
+        val data = Seq((1, 1, 1), (2, 0, 2))
+
+        data.toDF("t1a", "t1b", "t1c").write.parquet(path.getCanonicalPath + "/t1")
+        data.toDF("t2a", "t2b", "t2c").write.parquet(path.getCanonicalPath + "/t2")
+        data.toDF("t3a", "t3b", "t3c").write.parquet(path.getCanonicalPath + "/t3")
+
+        sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}/t1'")
+        sql(s"CREATE TABLE t2 USING parquet LOCATION '${path.toURI}/t2'")
+        sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}/t3'")
+
+        val sqlText =
+          s"""
+             |SELECT *
+             |FROM   (SELECT *
+             |        FROM   t2
+             |        WHERE  t2c IN (SELECT t1c
+             |                       FROM   t1
+             |                       WHERE  t1a = t2a)
+             |        UNION
+             |        SELECT *
+             |        FROM   t3
+             |        WHERE  t3a IN (SELECT t2a
+             |                       FROM   t2
+             |                       UNION ALL
+             |                       SELECT t1a
+             |                       FROM   t1
+             |                       WHERE  t1b > 0)) t4
+             |WHERE  t4.t2b IN (SELECT Min(t3b)
+             |                  FROM   t3
+             |                  WHERE  t4.t2a = t3a)
+          """.stripMargin
+        val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
+        val joinNodes = optimizedPlan.collect {
+          case j: Join => j
+        }.map(_.asInstanceOf[Join])
+        joinNodes.map(j => assert(j.duplicateResolved))
+
+        assert(optimizedPlan.resolved)
+      }
+    }
+  }
 }
```
Review comment on the new method: `def dedupJoin` -> `private def dedupJoin`
ok.