Skip to content

Commit 318bc03

Browse files
committed
Merge pull request apache#14426 from dongjoon-hyun/SPARK-16475-HINT
[SPARK-16475][SQL] Broadcast Hint for SQL Queries
2 parents 457850e + 539782d commit 318bc03

File tree

10 files changed

+435
-4
lines changed

10 files changed

+435
-4
lines changed

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ querySpecification
365365
(RECORDREADER recordReader=STRING)?
366366
fromClause?
367367
(WHERE where=booleanExpression)?)
368-
| ((kind=SELECT setQuantifier? namedExpressionSeq fromClause?
368+
| ((kind=SELECT hint? setQuantifier? namedExpressionSeq fromClause?
369369
| fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?)
370370
lateralView*
371371
(WHERE where=booleanExpression)?
@@ -374,6 +374,16 @@ querySpecification
374374
windows?)
375375
;
376376

377+
// An optimizer hint placed right after SELECT, delimited by '/*+' and '*/'.
hint
    : '/*+' hintStatement '*/'
    ;

// The body of a hint. Three shapes are accepted:
//   1. a bare hint name;
//   2. a hint name with exactly two parameters written without a comma between them;
//   3. a hint name with one or more comma-separated parameters.
// An empty parameter list `()` matches none of the alternatives.
hintStatement
    : hintName=identifier
    | hintName=identifier '(' parameters+=identifier parameters+=identifier ')'
    | hintName=identifier '(' parameters+=identifier (',' parameters+=identifier)* ')'
    ;
386+
377387
fromClause
378388
: FROM relation (',' relation)* lateralView*
379389
;
@@ -1002,8 +1012,12 @@ SIMPLE_COMMENT
10021012
: '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN)
10031013
;
10041014

1015+
BRACKETED_EMPTY_COMMENT
1016+
: '/**/' -> channel(HIDDEN)
1017+
;
1018+
10051019
BRACKETED_COMMENT
1006-
: '/*' .*? '*/' -> channel(HIDDEN)
1020+
: '/*' ~[+] .*? '*/' -> channel(HIDDEN)
10071021
;
10081022

10091023
WS

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ class Analyzer(
118118
CTESubstitution,
119119
WindowsSubstitution,
120120
EliminateUnions,
121-
new SubstituteUnresolvedOrdinals(conf)),
121+
new SubstituteUnresolvedOrdinals(conf),
122+
SubstituteHints),
122123
Batch("Resolution", fixedPoint,
123124
ResolveTableValuedFunctions ::
124125
ResolveRelations ::
@@ -2087,6 +2088,63 @@ class Analyzer(
20872088
}
20882089
}
20892090

2091+
/**
 * Substitute Hints.
 * - BROADCAST/BROADCASTJOIN/MAPJOIN match the closest table with the given name parameters
 *   and wrap that relation in a `BroadcastHint`.
 *
 * This rule substitutes `UnresolvedRelation`s in `Substitute` batch before `ResolveRelations`
 * rule is applied. Here are two reasons.
 * - To support `MetastoreRelation` in Hive module.
 * - To reduce the effect of `Hint` on the other rules.
 *
 * Hint names are compared case-insensitively. The comparison upper-cases the name with
 * `Locale.ROOT` so that the result does not depend on the JVM default locale (for example,
 * a Turkish default locale would otherwise upper-case `mapjoin` to `MAPJOİN` and the hint
 * would be dropped as unrecognized).
 *
 * `Hint` nodes whose name is not one of the broadcast hint names are replaced by their child.
 * All new `Hint`s should be transformed into concrete Hint classes `BroadcastHint` here.
 */
object SubstituteHints extends Rule[LogicalPlan] {
  val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN")

  /** Adds `plan` and every node below it to `set`. */
  private def appendAllDescendant(
      set: scala.collection.mutable.Set[LogicalPlan],
      plan: LogicalPlan): Unit = {
    set += plan
    plan.children.foreach { child => appendAllDescendant(set, child) }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case logical: LogicalPlan => logical transformDown {
      case Hint(name, parameters, child)
          if BROADCAST_HINT_NAMES.contains(name.toUpperCase(java.util.Locale.ROOT)) =>
        var resolvedChild = child
        // Each table name in the hint is handled by its own top-down pass over the child.
        for (table <- parameters) {
          // Set once the first matching relation has been handled, so that at most one
          // relation is hinted per parameter.
          var stop = false
          // Nodes that this pass must leave untouched; filled with the subtree of every
          // nested `Project` so that relations inside subqueries are not hinted.
          val skipNodeSet = scala.collection.mutable.Set.empty[LogicalPlan]
          resolvedChild = resolvedChild.transformDown {
            case n if skipNodeSet.contains(n) =>
              skipNodeSet -= n
              n
            // A `Project` other than the hint's direct child: skip everything below it.
            case p @ Project(_, _) if p != resolvedChild =>
              appendAllDescendant(skipNodeSet, p)
              skipNodeSet -= p
              p
            // The relation already carries a broadcast hint: keep it and stop searching.
            case r @ BroadcastHint(UnresolvedRelation(t, _))
                if !stop && resolver(t.table, table) =>
              stop = true
              r
            case r @ UnresolvedRelation(t, alias) if !stop && resolver(t.table, table) =>
              stop = true
              alias match {
                // Keep the alias outside of the hint so the hint wraps the bare relation.
                case Some(aliasName) =>
                  SubqueryAlias(aliasName, BroadcastHint(r.copy(alias = None)), None)
                case None =>
                  BroadcastHint(r)
              }
          }
        }
        resolvedChild

      // Remove unrecognized hints
      case Hint(_, _, child) => child
    }
  }
}
2147+
20902148
/**
20912149
* Check and add proper window frames for all window functions.
20922150
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,10 @@ trait CheckAnalysis extends PredicateHelper {
387387
|in operator ${operator.simpleString}
388388
""".stripMargin)
389389

390+
case Hint(_, _, _) =>
391+
throw new IllegalStateException(
392+
"logical hint operator should have been removed by analyzer")
393+
390394
case _ => // Analysis successful!
391395
}
392396
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
380380
}
381381

382382
// Window
383-
withDistinct.optionalMap(windows)(withWindows)
383+
val withWindow = withDistinct.optionalMap(windows)(withWindows)
384+
385+
// Hint
386+
withWindow.optionalMap(ctx.hint)(withHints)
384387
}
385388
}
386389

@@ -505,6 +508,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
505508
}
506509
}
507510

511+
/**
 * Wrap a logical plan in a [[Hint]] built from the parsed hint statement: the hint name and
 * the text of each of its parameters.
 */
private def withHints(
    ctx: HintContext,
    query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
  val statement = ctx.hintStatement
  val hintName = statement.hintName.getText
  val hintParameters = statement.parameters.asScala.map(_.getText)
  Hint(hintName, hintParameters, query)
}
520+
508521
/**
509522
* Add a [[Generate]] (Lateral View) to a logical plan.
510523
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,29 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
362362
super.computeStats(conf).copy(isBroadcastable = true)
363363
}
364364

365+
/**
 * A general hint for the child, expressed as a pair of (name, parameters).
 * This node is a placeholder that is expected to be eliminated during analysis; it always
 * reports itself as unresolved and passes the child's output through unchanged.
 */
case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) extends UnaryNode {
  // The hint adds no columns of its own.
  override def output: Seq[Attribute] = child.output

  // Never considered resolved, regardless of the state of the child.
  override lazy val resolved: Boolean = false
}
373+
374+
/**
 * Options for writing new data into a table.
 *
 * Construction fails with an assertion error when a specific partition is given while
 * overwriting is not enabled.
 *
 * @param enabled whether to overwrite existing data in the table.
 * @param specificPartition only data in the specified partition will be overwritten.
 */
case class OverwriteOptions(
    enabled: Boolean,
    specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) {
  // Either no partition is targeted, or overwriting must be switched on.
  assert(
    specificPartition.isEmpty || enabled,
    "Overwrite must be enabled when specifying a partition to overwrite.")
}
387+
365388
/**
366389
* Insert some data into a table. Note that this plan is unresolved and has to be replaced by the
367390
* concrete implementations during analysis.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ trait AnalysisTest extends PlanTest {
3232
val conf = new SimpleCatalystConf(caseSensitive)
3333
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
3434
catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
35+
catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
3536
new Analyzer(catalog, conf) {
3637
override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
3738
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.dsl.expressions._
21+
import org.apache.spark.sql.catalyst.dsl.plans._
22+
import org.apache.spark.sql.catalyst.plans.logical._
23+
24+
/**
 * Checks how `Hint` nodes with the MAPJOIN name are rewritten by the analyzer: which
 * relations end up wrapped in `BroadcastHint` and which are left untouched.
 */
class SubstituteHintsSuite extends AnalysisTest {
  import org.apache.spark.sql.catalyst.analysis.TestRelations._

  val a = testRelation.output(0)
  val b = testRelation2.output(0)

  /** Builds a MAPJOIN hint over `child` naming the given tables. */
  private def mapJoin(tables: Seq[String], child: LogicalPlan): LogicalPlan =
    Hint("MAPJOIN", tables, child)

  /** Unresolved reference to the first test table, aliased as "t". */
  private def tableT: LogicalPlan = table("TaBlE").as("t")

  /** Unresolved reference to the second test table, aliased as "u". */
  private def tableU: LogicalPlan = table("TaBlE2").as("u")

  test("case-sensitive or insensitive parameters") {
    // With caseSensitive = false both spellings of the table name select the relation.
    Seq("TaBlE", "table").foreach { parameter =>
      checkAnalysis(
        mapJoin(Seq(parameter), table("TaBlE")),
        BroadcastHint(testRelation),
        caseSensitive = false)
    }

    // Without the flag, only the exact spelling is expected to be hinted.
    checkAnalysis(
      mapJoin(Seq("TaBlE"), table("TaBlE")),
      BroadcastHint(testRelation))

    checkAnalysis(
      mapJoin(Seq("table"), table("TaBlE")),
      testRelation)
  }

  test("single hint") {
    checkAnalysis(
      mapJoin(Seq("TaBlE"), table("TaBlE").select(a)),
      BroadcastHint(testRelation).select(a))

    // Hint on the left side of the join.
    checkAnalysis(
      mapJoin(Seq("TaBlE"), tableT.join(tableU).select(a)),
      BroadcastHint(testRelation).join(testRelation2).select(a))

    // Hint on the right side of the join.
    checkAnalysis(
      mapJoin(Seq("TaBlE2"), tableT.join(tableU).select(a)),
      testRelation.join(BroadcastHint(testRelation2)).select(a))
  }

  test("single hint with multiple parameters") {
    // Naming the same table twice yields a single BroadcastHint.
    checkAnalysis(
      mapJoin(Seq("TaBlE", "TaBlE"), tableT.join(tableU).select(a)),
      BroadcastHint(testRelation).join(testRelation2).select(a))

    // Naming both tables hints both sides.
    checkAnalysis(
      mapJoin(Seq("TaBlE", "TaBlE2"), tableT.join(tableU).select(a)),
      BroadcastHint(testRelation).join(BroadcastHint(testRelation2)).select(a))
  }

  test("duplicated nested hints are transformed into one") {
    val innerLeftHint = mapJoin(Seq("TaBlE"), tableT.select('a))
    checkAnalysis(
      mapJoin(Seq("TaBlE"), innerLeftHint.join(tableU).select(a)),
      BroadcastHint(testRelation).select(a).join(testRelation2).select(a))

    val innerRightHint = mapJoin(Seq("TaBlE2"), tableU.select(b))
    checkAnalysis(
      mapJoin(Seq("TaBlE2"), tableT.select(a).join(innerRightHint).select(a)),
      testRelation.select(a).join(BroadcastHint(testRelation2).select(b)).select(a))
  }

  test("distinct nested two hints are handled separately") {
    val innerLeftHint = mapJoin(Seq("TaBlE"), tableT.select(a))
    checkAnalysis(
      mapJoin(Seq("TaBlE2"), innerLeftHint.join(tableU).select(a)),
      BroadcastHint(testRelation).select(a).join(BroadcastHint(testRelation2)).select(a))

    val innerRightHint = mapJoin(Seq("TaBlE2"), tableU.select(b))
    checkAnalysis(
      mapJoin(Seq("TaBlE"), tableT.join(innerRightHint).select(a)),
      BroadcastHint(testRelation).join(BroadcastHint(testRelation2).select(b)).select(a))
  }

  test("deep self join") {
    // Only the first occurrence of the table is expected to be hinted.
    val selfJoin =
      table("TaBlE").join(table("TaBlE")).join(table("TaBlE")).join(table("TaBlE")).select(a)
    val expected =
      BroadcastHint(testRelation).join(testRelation).join(testRelation).join(testRelation)
        .select(a)
    checkAnalysis(mapJoin(Seq("TaBlE"), selfJoin), expected)
  }

  test("subquery should be ignored") {
    // The occurrence inside the aliased subquery is skipped; the outer one is hinted.
    checkAnalysis(
      mapJoin(Seq("TaBlE"),
        table("TaBlE").select(a).as("x").join(table("TaBlE")).select(a)),
      testRelation.select(a).join(BroadcastHint(testRelation)).select(a))

    // The table only appears inside the subquery, so nothing is hinted.
    checkAnalysis(
      mapJoin(Seq("TaBlE"),
        tableT.select(a).as("x").join(table("TaBlE2").as("t2")).select(a)),
      testRelation.select(a).join(testRelation2).select(a))
  }
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,4 +493,46 @@ class PlanParserSuite extends PlanTest {
493493
assertEqual("select a, b from db.c where x !> 1",
494494
table("db", "c").where('x <= 1).select('a, 'b))
495495
}
496+
497+
test("select hint syntax") {
  // Hive compatibility: a missing parameter and a database-qualified table name are both
  // rejected by the parser.
  Seq(
    "SELECT /*+ HINT() */ * FROM t",
    "SELECT /*+ MAPJOIN(default.t) */ * from default.t"
  ).foreach { sql =>
    val message = intercept[ParseException] {
      parsePlan(sql)
    }.getMessage
    assert(message.contains("no viable alternative at input"))
  }

  // Each statement paired with the plan the parser is expected to produce.
  val accepted = Seq(
    "SELECT /*+ HINT */ * FROM t" ->
      Hint("HINT", Seq.empty, table("t").select(star())),
    "SELECT /*+ BROADCASTJOIN(u) */ * FROM t" ->
      Hint("BROADCASTJOIN", Seq("u"), table("t").select(star())),
    "SELECT /*+ MAPJOIN(u) */ * FROM t" ->
      Hint("MAPJOIN", Seq("u"), table("t").select(star())),
    "SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t" ->
      Hint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star())),
    "SELECT /*+ INDEX(t emp_job_ix) */ * FROM t" ->
      Hint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star())),
    "SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`" ->
      Hint("MAPJOIN", Seq("default.t"), table("default.t").select(star())),
    // The hint wraps the aggregate; the ordering is applied on top of the hint.
    "SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a" ->
      Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)
  )

  accepted.foreach { case (sql, expectedPlan) =>
    comparePlans(parsePlan(sql), expectedPlan)
  }
}
496538
}

0 commit comments

Comments
 (0)