Skip to content

Commit ca499e9

Browse files
maropudongjoon-hyun
authored andcommitted
[SPARK-25121][SQL] Supports multi-part table names for broadcast hint resolution
### What changes were proposed in this pull request? This pr fixed code to respect a database name for broadcast table hint resolution. Currently, spark ignores a database name in multi-part names; ``` scala> sql("CREATE DATABASE testDb") scala> spark.range(10).write.saveAsTable("testDb.t") // without this patch scala> spark.range(10).join(spark.table("testDb.t"), "id").hint("broadcast", "testDb.t").explain == Physical Plan == *(2) Project [id#24L] +- *(2) BroadcastHashJoin [id#24L], [id#26L], Inner, BuildLeft :- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false])) : +- *(1) Range (0, 10, step=1, splits=4) +- *(2) Project [id#26L] +- *(2) Filter isnotnull(id#26L) +- *(2) FileScan parquet testdb.t[id#26L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-2.3.1-bin-hadoop2.7/spark-warehouse..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:bigint> // with this patch scala> spark.range(10).join(spark.table("testDb.t"), "id").hint("broadcast", "testDb.t").explain == Physical Plan == *(2) Project [id#3L] +- *(2) BroadcastHashJoin [id#3L], [id#5L], Inner, BuildRight :- *(2) Range (0, 10, step=1, splits=4) +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, true])) +- *(1) Project [id#5L] +- *(1) Filter isnotnull(id#5L) +- *(1) FileScan parquet testdb.t[id#5L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/testdb.db/t], PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:bigint> ``` This PR comes from #22198 ### Why are the changes needed? For better usability. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added unit tests. Closes #27935 from maropu/SPARK-25121-2. Authored-by: Takeshi Yamamuro <yamamuro@apache.org> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent c6a6d5e commit ca499e9

File tree

9 files changed

+230
-27
lines changed

9 files changed

+230
-27
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ import org.apache.spark.sql.catalyst.plans.logical.{HintErrorHandler, HintInfo}
2424
* The hint error handler that logs warnings for each hint error.
2525
*/
2626
object HintErrorLogger extends HintErrorHandler with Logging {
27+
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
2728

2829
override def hintNotRecognized(name: String, parameters: Seq[Any]): Unit = {
2930
logWarning(s"Unrecognized hint: ${hintToPrettyString(name, parameters)}")
3031
}
3132

3233
override def hintRelationsNotFound(
33-
name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit = {
34-
invalidRelations.foreach { n =>
35-
logWarning(s"Count not find relation '$n' specified in hint " +
34+
name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): Unit = {
35+
invalidRelations.foreach { ident =>
36+
logWarning(s"Count not find relation '${ident.quoted}' specified in hint " +
3637
s"'${hintToPrettyString(name, parameters)}'.")
3738
}
3839
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -64,31 +64,59 @@ object ResolveHints {
6464
_.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT))))
6565
}
6666

67+
// This method checks if given multi-part identifiers are matched with each other.
68+
// The [[ResolveJoinStrategyHints]] rule is applied before the resolution batch
69+
// in the analyzer and we cannot semantically compare them at this stage.
70+
// Therefore, we follow a simple rule; they match if an identifier in a hint
71+
// is a tail of an identifier in a relation. This process is independent of a session
72+
// catalog (`currentDb` in [[SessionCatalog]]) and it just compares them literally.
73+
//
74+
// For example,
75+
// * in a query `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN t`,
76+
// the broadcast hint will match both tables, `db1.t` and `t`,
77+
// even when the current db is `db2`.
78+
// * in a query `SELECT /*+ BROADCAST(default.t) */ * FROM default.t JOIN t`,
79+
// the broadcast hint will match the left-side table only, `default.t`.
80+
private def matchedIdentifier(identInHint: Seq[String], identInQuery: Seq[String]): Boolean = {
81+
if (identInHint.length <= identInQuery.length) {
82+
identInHint.zip(identInQuery.takeRight(identInHint.length))
83+
.forall { case (i1, i2) => resolver(i1, i2) }
84+
} else {
85+
false
86+
}
87+
}
88+
89+
private def extractIdentifier(r: SubqueryAlias): Seq[String] = {
90+
r.identifier.qualifier :+ r.identifier.name
91+
}
92+
6793
private def applyJoinStrategyHint(
6894
plan: LogicalPlan,
69-
relations: mutable.HashSet[String],
95+
relationsInHint: Set[Seq[String]],
96+
relationsInHintWithMatch: mutable.HashSet[Seq[String]],
7097
hintName: String): LogicalPlan = {
7198
// Whether to continue recursing down the tree
7299
var recurse = true
73100

101+
def matchedIdentifierInHint(identInQuery: Seq[String]): Boolean = {
102+
relationsInHint.find(matchedIdentifier(_, identInQuery))
103+
.map(relationsInHintWithMatch.add).nonEmpty
104+
}
105+
74106
val newNode = CurrentOrigin.withOrigin(plan.origin) {
75107
plan match {
76108
case ResolvedHint(u @ UnresolvedRelation(ident), hint)
77-
if relations.exists(resolver(_, ident.last)) =>
78-
relations.remove(ident.last)
109+
if matchedIdentifierInHint(ident) =>
79110
ResolvedHint(u, createHintInfo(hintName).merge(hint, hintErrorHandler))
80111

81112
case ResolvedHint(r: SubqueryAlias, hint)
82-
if relations.exists(resolver(_, r.alias)) =>
83-
relations.remove(r.alias)
113+
if matchedIdentifierInHint(extractIdentifier(r)) =>
84114
ResolvedHint(r, createHintInfo(hintName).merge(hint, hintErrorHandler))
85115

86-
case u @ UnresolvedRelation(ident) if relations.exists(resolver(_, ident.last)) =>
87-
relations.remove(ident.last)
116+
case UnresolvedRelation(ident) if matchedIdentifierInHint(ident) =>
88117
ResolvedHint(plan, createHintInfo(hintName))
89118

90-
case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
91-
relations.remove(r.alias)
119+
case r: SubqueryAlias if matchedIdentifierInHint(extractIdentifier(r)) =>
92120
ResolvedHint(plan, createHintInfo(hintName))
93121

94122
case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
@@ -107,7 +135,9 @@ object ResolveHints {
107135
}
108136

109137
if ((plan fastEquals newNode) && recurse) {
110-
newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName))
138+
newNode.mapChildren { child =>
139+
applyJoinStrategyHint(child, relationsInHint, relationsInHintWithMatch, hintName)
140+
}
111141
} else {
112142
newNode
113143
}
@@ -120,17 +150,19 @@ object ResolveHints {
120150
ResolvedHint(h.child, createHintInfo(h.name))
121151
} else {
122152
// Otherwise, find within the subtree query plans to apply the hint.
123-
val relationNames = h.parameters.map {
124-
case tableName: String => tableName
125-
case tableId: UnresolvedAttribute => tableId.name
153+
val relationNamesInHint = h.parameters.map {
154+
case tableName: String => UnresolvedAttribute.parseAttributeName(tableName)
155+
case tableId: UnresolvedAttribute => tableId.nameParts
126156
case unsupported => throw new AnalysisException("Join strategy hint parameter " +
127157
s"should be an identifier or string but was $unsupported (${unsupported.getClass}")
128-
}
129-
val relationNameSet = new mutable.HashSet[String]
130-
relationNames.foreach(relationNameSet.add)
131-
132-
val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name)
133-
hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, relationNameSet.toSet)
158+
}.toSet
159+
val relationsInHintWithMatch = new mutable.HashSet[Seq[String]]
160+
val applied = applyJoinStrategyHint(
161+
h.child, relationNamesInHint, relationsInHintWithMatch, h.name)
162+
163+
// Filters unmatched relation identifiers in the hint
164+
val unmatchedIdents = relationNamesInHint -- relationsInHintWithMatch
165+
hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, unmatchedIdents)
134166
applied
135167
}
136168
}
@@ -246,5 +278,4 @@ object ResolveHints {
246278
h.child
247279
}
248280
}
249-
250281
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ package object expressions {
196196
// For example, consider an example where "cat" is the catalog name, "db1" is the database
197197
// name, "a" is the table name and "b" is the column name and "c" is the struct field name.
198198
// If the name parts is cat.db1.a.b.c, then Attribute will match
199-
// Attribute(b, qualifier("cat", "db1, "a")) and List("c") will be the second element
199+
// Attribute(b, qualifier("cat", "db1", "a")) and List("c") will be the second element
200200
var matches: (Seq[Attribute], Seq[String]) = nameParts match {
201201
case catalogPart +: dbPart +: tblPart +: name +: nestedFields =>
202202
val key = (catalogPart.toLowerCase(Locale.ROOT), dbPart.toLowerCase(Locale.ROOT),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ trait HintErrorHandler {
186186
* @param parameters the hint parameters
187187
* @param invalidRelations the set of relation names that cannot be associated
188188
*/
189-
def hintRelationsNotFound(name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit
189+
def hintRelationsNotFound(
190+
name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): Unit
190191

191192
/**
192193
* Callback for a join hint specified on a relation that is not part of a join.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ trait AnalysisTest extends PlanTest {
4545
catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
4646
catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
4747
catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true)
48+
catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4, overrideIfExists = true)
49+
catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5, overrideIfExists = true)
4850
new Analyzer(catalog, conf) {
4951
override val extendedResolutionRules = EliminateSubqueryAliases +: extendedAnalysisRules
5052
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,4 +241,52 @@ class ResolveHintsSuite extends AnalysisTest {
241241
Project(testRelation.output, testRelation),
242242
caseSensitive = false)
243243
}
244+
245+
test("Supports multi-part table names for broadcast hint resolution") {
246+
// local temp table (single-part identifier case)
247+
checkAnalysis(
248+
UnresolvedHint("MAPJOIN", Seq("table", "table2"),
249+
table("TaBlE").join(table("TaBlE2"))),
250+
Join(
251+
ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
252+
ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))),
253+
Inner,
254+
None,
255+
JoinHint.NONE),
256+
caseSensitive = false)
257+
258+
checkAnalysis(
259+
UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"),
260+
table("TaBlE").join(table("TaBlE2"))),
261+
Join(
262+
ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
263+
testRelation2,
264+
Inner,
265+
None,
266+
JoinHint.NONE),
267+
caseSensitive = true)
268+
269+
// global temp table (multi-part identifier case)
270+
checkAnalysis(
271+
UnresolvedHint("MAPJOIN", Seq("GlOBal_TeMP.table4", "table5"),
272+
table("global_temp", "table4").join(table("global_temp", "table5"))),
273+
Join(
274+
ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))),
275+
ResolvedHint(testRelation5, HintInfo(strategy = Some(BROADCAST))),
276+
Inner,
277+
None,
278+
JoinHint.NONE),
279+
caseSensitive = false)
280+
281+
checkAnalysis(
282+
UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"),
283+
table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))),
284+
Join(
285+
ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))),
286+
testRelation5,
287+
Inner,
288+
None,
289+
JoinHint.NONE),
290+
caseSensitive = true)
291+
}
244292
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ object TestRelations {
4444
AttributeReference("g", StringType)(),
4545
AttributeReference("h", MapType(IntegerType, IntegerType))())
4646

47+
val testRelation5 = LocalRelation(AttributeReference("i", StringType)())
48+
4749
val nestedRelation = LocalRelation(
4850
AttributeReference("top", StructType(
4951
StructField("duplicateField", StringType) ::

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717

1818
package org.apache.spark.sql
1919

20+
import org.apache.spark.sql.catalyst.TableIdentifier
2021
import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, LeftOuter, RightOuter}
21-
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, Project}
22+
import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, HintInfo, Join, JoinHint, LogicalPlan, Project}
23+
import org.apache.spark.sql.connector.catalog.CatalogManager
24+
import org.apache.spark.sql.execution.FileSourceScanExec
2225
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
2326
import org.apache.spark.sql.execution.datasources.LogicalRelation
27+
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
2428
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
2529
import org.apache.spark.sql.functions._
2630
import org.apache.spark.sql.internal.SQLConf
@@ -322,4 +326,96 @@ class DataFrameJoinSuite extends QueryTest
322326
}
323327
}
324328
}
329+
330+
test("Supports multi-part names for broadcast hint resolution") {
331+
val (table1Name, table2Name) = ("t1", "t2")
332+
333+
withTempDatabase { dbName =>
334+
withTable(table1Name, table2Name) {
335+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
336+
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
337+
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
338+
339+
def checkIfHintApplied(df: DataFrame): Unit = {
340+
val sparkPlan = df.queryExecution.executedPlan
341+
val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
342+
assert(broadcastHashJoins.size == 1)
343+
val broadcastExchanges = broadcastHashJoins.head.collect {
344+
case p: BroadcastExchangeExec => p
345+
}
346+
assert(broadcastExchanges.size == 1)
347+
val tables = broadcastExchanges.head.collect {
348+
case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent
349+
}
350+
assert(tables.size == 1)
351+
assert(tables.head === TableIdentifier(table1Name, Some(dbName)))
352+
}
353+
354+
def checkIfHintNotApplied(df: DataFrame): Unit = {
355+
val sparkPlan = df.queryExecution.executedPlan
356+
val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
357+
assert(broadcastHashJoins.isEmpty)
358+
}
359+
360+
def sqlTemplate(tableName: String, hintTableName: String): DataFrame = {
361+
sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
362+
s"FROM $tableName, $dbName.$table2Name " +
363+
s"WHERE $tableName.id = $table2Name.id")
364+
}
365+
366+
def dfTemplate(tableName: String, hintTableName: String): DataFrame = {
367+
spark.table(tableName).join(spark.table(s"$dbName.$table2Name"), "id")
368+
.hint("broadcast", hintTableName)
369+
}
370+
371+
sql(s"USE $dbName")
372+
373+
checkIfHintApplied(sqlTemplate(table1Name, table1Name))
374+
checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", s"$dbName.$table1Name"))
375+
checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", table1Name))
376+
checkIfHintNotApplied(sqlTemplate(table1Name, s"$dbName.$table1Name"))
377+
378+
checkIfHintApplied(dfTemplate(table1Name, table1Name))
379+
checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", s"$dbName.$table1Name"))
380+
checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", table1Name))
381+
checkIfHintApplied(dfTemplate(table1Name, s"$dbName.$table1Name"))
382+
checkIfHintApplied(dfTemplate(table1Name,
383+
s"${CatalogManager.SESSION_CATALOG_NAME}.$dbName.$table1Name"))
384+
385+
withView("tv") {
386+
sql(s"CREATE VIEW tv AS SELECT * FROM $dbName.$table1Name")
387+
checkIfHintApplied(sqlTemplate("tv", "tv"))
388+
checkIfHintNotApplied(sqlTemplate("tv", s"$dbName.tv"))
389+
390+
checkIfHintApplied(dfTemplate("tv", "tv"))
391+
checkIfHintApplied(dfTemplate("tv", s"$dbName.tv"))
392+
}
393+
}
394+
}
395+
}
396+
}
397+
398+
test("The same table name exists in two databases for broadcast hint resolution") {
399+
val (db1Name, db2Name) = ("db1", "db2")
400+
401+
withDatabase(db1Name, db2Name) {
402+
withTable("t") {
403+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
404+
sql(s"CREATE DATABASE $db1Name")
405+
sql(s"CREATE DATABASE $db2Name")
406+
spark.range(1).write.saveAsTable(s"$db1Name.t")
407+
spark.range(1).write.saveAsTable(s"$db2Name.t")
408+
409+
// Checks if a broadcast hint applied in both sides
410+
val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " +
411+
s"WHERE $db1Name.t.id = $db2Name.t.id"
412+
sql(statement).queryExecution.optimizedPlan match {
413+
case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))),
414+
Some(HintInfo(Some(BROADCAST))))) =>
415+
case _ => fail("broadcast hint not found in both tables")
416+
}
417+
}
418+
}
419+
}
420+
}
325421
}

sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution
2020
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
2121
import org.apache.spark.sql.catalog.Table
2222
import org.apache.spark.sql.catalyst.TableIdentifier
23-
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
23+
import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint}
24+
import org.apache.spark.sql.internal.SQLConf
2425
import org.apache.spark.sql.test.SharedSparkSession
2526
import org.apache.spark.sql.types.StructType
2627

@@ -170,4 +171,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSparkSession {
170171
isTemporary = true).toString)
171172
}
172173
}
174+
175+
test("broadcast hint on global temp view") {
176+
withGlobalTempView("v1") {
177+
spark.range(10).createGlobalTempView("v1")
178+
withTempView("v2") {
179+
spark.range(10).createTempView("v2")
180+
181+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
182+
Seq(
183+
"SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id",
184+
"SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id"
185+
).foreach { statement =>
186+
sql(statement).queryExecution.optimizedPlan match {
187+
case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))), None)) =>
188+
case _ => fail("broadcast hint not found in a left-side table")
189+
}
190+
}
191+
}
192+
}
193+
}
194+
}
173195
}

0 commit comments

Comments
 (0)