Skip to content

Commit 4f769b9

Browse files
committed
[SPARK-17296][SQL] Simplify parser join processing.
## What changes were proposed in this pull request? Join processing in the parser relies on the fact that the grammar produces right-nested trees; for instance, the parse tree for `select * from a join b join c` is expected to produce a tree similar to `JOIN(a, JOIN(b, c))`. However, there are cases in which this invariant is violated, like: ```sql SELECT COUNT(1) FROM test T1 CROSS JOIN test T2 JOIN test T3 ON T3.col = T1.col JOIN test T4 ON T4.col = T1.col ``` In this case the parser returns a tree in which Joins are located on both the left and the right sides of the parent join node. This PR introduces a different grammar rule which does not make this assumption. The new rule takes a relation and searches for zero or more joined relations. As a bonus, processing is much easier. ## How was this patch tested? Existing tests; I have also added a regression test to the plan parser suite. Author: Herman van Hovell <[email protected]> Closes apache#14867 from hvanhovell/SPARK-17296.
1 parent 29cfab3 commit 4f769b9

File tree

4 files changed

+102
-58
lines changed

4 files changed

+102
-58
lines changed

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -374,11 +374,12 @@ setQuantifier
374374
;
375375

376376
relation
377-
: left=relation
378-
(joinType JOIN right=relation joinCriteria?
379-
| NATURAL joinType JOIN right=relation
380-
) #joinRelation
381-
| relationPrimary #relationDefault
377+
: relationPrimary joinRelation*
378+
;
379+
380+
joinRelation
381+
: (joinType) JOIN right=relationPrimary joinCriteria?
382+
| NATURAL joinType JOIN right=relationPrimary
382383
;
383384

384385
joinType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
9292

9393
// Apply CTEs
9494
query.optional(ctx.ctes) {
95-
val ctes = ctx.ctes.namedQuery.asScala.map {
96-
case nCtx =>
97-
val namedQuery = visitNamedQuery(nCtx)
98-
(namedQuery.alias, namedQuery)
95+
val ctes = ctx.ctes.namedQuery.asScala.map { nCtx =>
96+
val namedQuery = visitNamedQuery(nCtx)
97+
(namedQuery.alias, namedQuery)
9998
}
10099
// Check for duplicate names.
101100
checkDuplicateKeys(ctes, ctx)
@@ -401,7 +400,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
401400
* separated) relations here, these get converted into a single plan by condition-less inner join.
402401
*/
403402
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
404-
val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
403+
val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
404+
val right = plan(relation.relationPrimary)
405+
val join = right.optionalMap(left)(Join(_, _, Inner, None))
406+
withJoinRelations(join, relation)
407+
}
405408
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
406409
}
407410

@@ -532,55 +535,53 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
532535
}
533536

534537
/**
535-
* Create a joins between two or more logical plans.
538+
* Create a single relation referenced in a FROM clause. This method is used when a part of the
539+
* join condition is nested, for example:
540+
* {{{
541+
* select * from t1 join (t2 cross join t3) on col1 = col2
542+
* }}}
536543
*/
537-
override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
538-
/** Build a join between two plans. */
539-
def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
540-
val baseJoinType = ctx.joinType match {
541-
case null => Inner
542-
case jt if jt.CROSS != null => Cross
543-
case jt if jt.FULL != null => FullOuter
544-
case jt if jt.SEMI != null => LeftSemi
545-
case jt if jt.ANTI != null => LeftAnti
546-
case jt if jt.LEFT != null => LeftOuter
547-
case jt if jt.RIGHT != null => RightOuter
548-
case _ => Inner
549-
}
544+
override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
545+
withJoinRelations(plan(ctx.relationPrimary), ctx)
546+
}
550547

551-
// Resolve the join type and join condition
552-
val (joinType, condition) = Option(ctx.joinCriteria) match {
553-
case Some(c) if c.USING != null =>
554-
val columns = c.identifier.asScala.map { column =>
555-
UnresolvedAttribute.quoted(column.getText)
556-
}
557-
(UsingJoin(baseJoinType, columns), None)
558-
case Some(c) if c.booleanExpression != null =>
559-
(baseJoinType, Option(expression(c.booleanExpression)))
560-
case None if ctx.NATURAL != null =>
561-
(NaturalJoin(baseJoinType), None)
562-
case None =>
563-
(baseJoinType, None)
564-
}
565-
Join(left, right, joinType, condition)
566-
}
548+
/**
549+
* Join one or more [[LogicalPlan]]s to the current logical plan.
550+
*/
551+
private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
552+
ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
553+
withOrigin(join) {
554+
val baseJoinType = join.joinType match {
555+
case null => Inner
556+
case jt if jt.CROSS != null => Cross
557+
case jt if jt.FULL != null => FullOuter
558+
case jt if jt.SEMI != null => LeftSemi
559+
case jt if jt.ANTI != null => LeftAnti
560+
case jt if jt.LEFT != null => LeftOuter
561+
case jt if jt.RIGHT != null => RightOuter
562+
case _ => Inner
563+
}
567564

568-
// Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the
569-
// first join clause is at the top. However fields of previously referenced tables can be used
570-
// in following join clauses. The tree needs to be reversed in order to make this work.
571-
var result = plan(ctx.left)
572-
var current = ctx
573-
while (current != null) {
574-
current.right match {
575-
case right: JoinRelationContext =>
576-
result = join(current, result, plan(right.left))
577-
current = right
578-
case right =>
579-
result = join(current, result, plan(right))
580-
current = null
565+
// Resolve the join type and join condition
566+
val (joinType, condition) = Option(join.joinCriteria) match {
567+
case Some(c) if c.USING != null =>
568+
val columns = c.identifier.asScala.map { column =>
569+
UnresolvedAttribute.quoted(column.getText)
570+
}
571+
(UsingJoin(baseJoinType, columns), None)
572+
case Some(c) if c.booleanExpression != null =>
573+
(baseJoinType, Option(expression(c.booleanExpression)))
574+
case None if join.NATURAL != null =>
575+
if (baseJoinType == Cross) {
576+
throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
577+
}
578+
(NaturalJoin(baseJoinType), None)
579+
case None =>
580+
(baseJoinType, None)
581+
}
582+
Join(left, plan(join.right), joinType, condition)
581583
}
582584
}
583-
result
584585
}
585586

586587
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser
1818

1919
import scala.collection.mutable.StringBuilder
2020

21-
import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
21+
import org.antlr.v4.runtime.{ParserRuleContext, Token}
2222
import org.antlr.v4.runtime.misc.Interval
2323
import org.antlr.v4.runtime.tree.TerminalNode
2424

@@ -189,9 +189,7 @@ object ParserUtils {
189189
* Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
190190
* passed function. The original plan is returned when the context does not exist.
191191
*/
192-
def optionalMap[C <: ParserRuleContext](
193-
ctx: C)(
194-
f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
192+
def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
195193
if (ctx != null) {
196194
f(ctx, plan)
197195
} else {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,54 @@ class PlanParserSuite extends PlanTest {
360360
test("left anti join", LeftAnti, testExistence)
361361
test("anti join", LeftAnti, testExistence)
362362

363+
// Test natural cross join
364+
intercept("select * from a natural cross join b")
365+
366+
// Test natural join with a condition
367+
intercept("select * from a natural join b on a.id = b.id")
368+
363369
// Test multiple consecutive joins
364370
assertEqual(
365371
"select * from a join b join c right join d",
366372
table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
373+
374+
// SPARK-17296
375+
assertEqual(
376+
"select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id",
377+
table("t1")
378+
.join(table("t2"), Cross)
379+
.join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id")))
380+
.join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id")))
381+
.select(star()))
382+
383+
// Test multiple on clauses.
384+
intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1")
385+
386+
// Parenthesis
387+
assertEqual(
388+
"select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1",
389+
table("t1")
390+
.join(table("t2")
391+
.join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1))
392+
.select(star()))
393+
assertEqual(
394+
"select * from t1 inner join (t2 inner join t3) on col3 = col2",
395+
table("t1")
396+
.join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2))
397+
.select(star()))
398+
assertEqual(
399+
"select * from t1 inner join (t2 inner join t3 on col3 = col2)",
400+
table("t1")
401+
.join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None)
402+
.select(star()))
403+
404+
// Implicit joins.
405+
assertEqual(
406+
"select * from t1, t3 join t2 on t1.col1 = t2.col2",
407+
table("t1")
408+
.join(table("t3"))
409+
.join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2")))
410+
.select(star()))
367411
}
368412

369413
test("sampled relations") {

0 commit comments

Comments
 (0)