-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-25121][SQL] Supports multi-part table names for broadcast hint resolution #22198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
24de799
a6e4e40
f021770
d434ba7
c138b81
6a202f2
545148b
bc29a11
59e60d4
5b2b272
b6b9f65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ import java.util.Locale | |
|
|
||
| import org.apache.spark.sql.AnalysisException | ||
| import org.apache.spark.sql.catalyst.IdentifierWithDatabase | ||
| import org.apache.spark.sql.catalyst.catalog.SessionCatalog | ||
| import org.apache.spark.sql.catalyst.expressions.IntegerLiteral | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
|
|
@@ -48,16 +49,25 @@ object ResolveHints { | |
| * | ||
| * This rule must happen before common table expressions. | ||
| */ | ||
| class ResolveBroadcastHints(conf: SQLConf) extends Rule[LogicalPlan] { | ||
| class ResolveBroadcastHints(conf: SQLConf, catalog: SessionCatalog) extends Rule[LogicalPlan] { | ||
| private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") | ||
|
|
||
| def resolver: Resolver = conf.resolver | ||
|
|
||
| private def namePartsWithDatabase(nameParts: Seq[String]): Seq[String] = { | ||
| if (nameParts.size == 1) { | ||
| catalog.getCurrentDatabase +: nameParts | ||
| } else { | ||
| nameParts | ||
| } | ||
| } | ||
|
|
||
| private def matchedTableIdentifier( | ||
| nameParts: Seq[String], | ||
| tableIdent: IdentifierWithDatabase): Boolean = { | ||
| val identifierList = tableIdent.database.map(_ :: Nil).getOrElse(Nil) :+ tableIdent.identifier | ||
| nameParts.corresponds(identifierList)(resolver) | ||
| val identifierList = | ||
| tableIdent.database.getOrElse(catalog.getCurrentDatabase) :: tableIdent.identifier :: Nil | ||
| namePartsWithDatabase(nameParts).corresponds(identifierList)(resolver) | ||
|
||
| } | ||
|
|
||
| private def applyBroadcastHint( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -196,33 +196,42 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { | |
| } | ||
|
|
||
| test("SPARK-25121 Supports multi-part names for broadcast hint resolution") { | ||
| val tableName = "t" | ||
| val (table1Name, table2Name) = ("t1", "t2") | ||
| withTempDatabase { dbName => | ||
| withTable(tableName) { | ||
| withTable(table1Name, table2Name) { | ||
| withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { | ||
| spark.range(100).write.saveAsTable(s"$dbName.$tableName") | ||
| spark.range(50).write.saveAsTable(s"$dbName.$table1Name") | ||
| spark.range(100).write.saveAsTable(s"$dbName.$table2Name") | ||
| // First, makes sure a join is not broadcastable | ||
| val plan1 = spark.range(3) | ||
| .join(spark.table(s"$dbName.$tableName"), "id") | ||
| val plan = sql(s"SELECT * FROM $dbName.$table1Name, $dbName.$table2Name " + | ||
| s"WHERE $table1Name.id = $table2Name.id") | ||
| .queryExecution.executedPlan | ||
| assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) | ||
| assert(plan.collect { case p: BroadcastHashJoinExec => p }.size == 0) | ||
|
|
||
| // Uses multi-part table names for broadcast hints | ||
| val plan2 = spark.range(3) | ||
| .join(spark.table(s"$dbName.$tableName"), "id") | ||
| .hint("broadcast", s"$dbName.$tableName") | ||
| .queryExecution.executedPlan | ||
| val broadcastHashJoin = plan2.collect { case p: BroadcastHashJoinExec => p } | ||
| assert(broadcastHashJoin.size == 1) | ||
| val broadcastExchange = broadcastHashJoin.head.collect { | ||
| case p: BroadcastExchangeExec => p | ||
| } | ||
| assert(broadcastExchange.size == 1) | ||
| val table = broadcastExchange.head.collect { | ||
| case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent | ||
| def checkIfHintApplied(tableName: String, hintTableName: String): Unit = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yea, I'll fix. |
||
| val p = sql(s"SELECT /*+ BROADCASTJOIN($tableName) */ * " + | ||
| s"FROM $tableName, $dbName.$table2Name " + | ||
| s"WHERE $tableName.id = $table2Name.id") | ||
| .queryExecution.executedPlan | ||
| val broadcastHashJoin = p.collect { case p: BroadcastHashJoinExec => p } | ||
| assert(broadcastHashJoin.size == 1) | ||
| val broadcastExchange = broadcastHashJoin.head.collect { | ||
| case p: BroadcastExchangeExec => p | ||
| } | ||
| assert(broadcastExchange.size == 1) | ||
| val table = broadcastExchange.head.collect { | ||
| case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent | ||
| } | ||
| assert(table.size == 1) | ||
| assert(table.head === TableIdentifier(table1Name, Some(dbName))) | ||
| } | ||
| assert(table.size == 1) | ||
| assert(table.head === TableIdentifier(tableName, Some(dbName))) | ||
|
|
||
| sql(s"USE $dbName") | ||
| checkIfHintApplied(table1Name, table1Name) | ||
| checkIfHintApplied(s"$dbName.$table1Name", s"$dbName.$table1Name") | ||
| checkIfHintApplied(table1Name, s"$dbName.$table1Name") | ||
| checkIfHintApplied(s"$dbName.$table1Name", table1Name) | ||
| } | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accordingly, we can use String instead of SessionCatalog.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can't use `String` there because `currentDatabase` might be updatable by others?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can instead use `getCurrentDatabase: () => String`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ya. Right, please ignore this. We need `catalog` to look up `global_temp`, too.