Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Deduplicate join output for correlated predicate subquery.
  • Loading branch information
viirya committed Aug 28, 2017
commit 82b5bacce80064df1feb087293f7d13ad72334ca
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,30 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}

def dedupJoin(plan: LogicalPlan): LogicalPlan = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a `private def`.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok.

plan transform {
case j @ Join(left, right, joinType, joinCond) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this apply to all join types?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be safe, restrict this to LeftAnti and LeftSemi; it only makes sense for those join types.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure.

// Attributes that appear in the output of both join children.
val duplicates = right.outputSet.intersect(left.outputSet)
if (duplicates.nonEmpty) {
// Map every duplicated attribute to an Alias that wraps it; the alias is
// named after the attribute's string form.
val aliasMap = AttributeMap(duplicates.map { dup =>
dup -> Alias(dup, dup.toString)()
}.toSeq)
// Right-side output with each duplicated attribute swapped for its alias;
// attributes that are not duplicated pass through unchanged.
val aliasedExpressions = right.output.map { ref =>
aliasMap.getOrElse(ref, ref)
}
// Put a Project on top of the right child so that it exposes the aliased outputs.
val newRight = Project(aliasedExpressions, right)
// Rewrite the join condition (if any) so that duplicated attributes refer to
// the attributes produced by the new aliases.
// NOTE(review): this rewrites every matching Attribute in the condition,
// whichever side it was meant to reference — confirm that is intended.
val newJoinCond = joinCond.map { condExpr =>
condExpr transform {
case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
}
}
Join(left, newRight, joinType, newJoinCond)
} else {
// No overlap between the two sides: keep the matched join as it is.
j
}
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Filter(condition, child) =>
val (withSubquery, withoutSubquery) =
Expand All @@ -61,7 +85,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}

// Filter the plan by applying left semi and left anti joins.
withSubquery.foldLeft(newFilter) {
val rewritten = withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
Join(outerPlan, sub, LeftSemi, joinCond)
Expand Down Expand Up @@ -98,6 +122,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
}
dedupJoin(rewritten)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment above this line to explain it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After rethinking it, we can be more conservative. Instead of doing a dedup at the end, we should do it when we convert it to the Join.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point. Will follow it.

}

/**
Expand Down
68 changes: 68 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SharedSQLContext

class SubquerySuite extends QueryTest with SharedSQLContext {
Expand Down Expand Up @@ -875,4 +876,71 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]"))
}
}

test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 1") {
withTable("t1") {
withTempPath { path =>
Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}'")

val sqlText =
"""
|SELECT * FROM t1
|WHERE
|NOT EXISTS (SELECT * FROM t1)
""".stripMargin
val ds = sql(sqlText)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is `ds` unused? It looks like it can be removed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I missed this. I'll remove it.

val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
// Pick the first Join found in the optimized plan. The pattern already narrows
// the element type to Join, so no cast is required, and `collectFirst` avoids
// building the full list just to take its head. `.get` throws (failing the
// test) if the plan contains no Join.
val join = optimizedPlan.collectFirst { case j: Join => j }.get
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        val join = optimizedPlan.collectFirst { case j: Join => j }.get

assert(join.duplicateResolved)
assert(optimizedPlan.resolved)
}
}
}

test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 2") {
withTable("t1", "t2", "t3") {
withTempPath { path =>
val data = Seq((1, 1, 1), (2, 0, 2))

data.toDF("t1a", "t1b", "t1c").write.parquet(path.getCanonicalPath + "/t1")
data.toDF("t2a", "t2b", "t2c").write.parquet(path.getCanonicalPath + "/t2")
data.toDF("t3a", "t3b", "t3c").write.parquet(path.getCanonicalPath + "/t3")

sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}/t1'")
sql(s"CREATE TABLE t2 USING parquet LOCATION '${path.toURI}/t2'")
sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}/t3'")

val sqlText =
s"""
|SELECT *
|FROM (SELECT *
| FROM t2
| WHERE t2c IN (SELECT t1c
| FROM t1
| WHERE t1a = t2a)
| UNION
| SELECT *
| FROM t3
| WHERE t3a IN (SELECT t2a
| FROM t2
| UNION ALL
| SELECT t1a
| FROM t1
| WHERE t1b > 0)) t4
|WHERE t4.t2b IN (SELECT Min(t3b)
| FROM t3
| WHERE t4.t2a = t3a)
""".stripMargin
val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
// Gather every Join in the optimized plan and check each one reports
// `duplicateResolved`. The pattern match already yields Joins, so no cast is
// needed; `foreach` is used because the assertions run only for their side effect.
val joinNodes = optimizedPlan.collect { case j: Join => j }
joinNodes.foreach(j => assert(j.duplicateResolved))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        val joinNodes = optimizedPlan.collect { case j: Join => j }
        joinNodes.foreach(j => assert(j.duplicateResolved))

assert(optimizedPlan.resolved)
}
}
}
}