Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix
  • Loading branch information
maropu committed Feb 13, 2019
commit b6b9f656ea04d7adb7036a7e276ce997efdf446e
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,23 @@ object ResolveHints {

def resolver: Resolver = conf.resolver

private def namePartsWithDatabase(nameParts: Seq[String], database: String): Seq[String] = {
if (nameParts.size == 1) {
database +: nameParts
} else {
nameParts
}
}

// Name resolution in hints follows three rules below:
//
// 1. table name matches if the hint table name only has one part
// 2. table name and database name both match if the hint table name has two parts
// 3. no match happens if the hint table name has more than three parts
//
// This means, `SELECT /* BROADCAST(t) */ * FROM db1.t JOIN db2.t` will match both tables, and
// `SELECT /* BROADCAST(default.t) */ * FROM t` match no table.
private def matchedTableIdentifier(
nameParts: Seq[String],
tableIdent: IdentifierWithDatabase): Boolean = {
tableIdent.database match {
case Some(db) if resolver(catalog.globalTempViewManager.database, db) =>
val identifierList = db :: tableIdent.identifier :: Nil
namePartsWithDatabase(nameParts, catalog.globalTempViewManager.database)
.corresponds(identifierList)(resolver)
case None if catalog.getTempView(tableIdent.identifier).isDefined =>
nameParts.size == 1 && resolver(nameParts.head, tableIdent.identifier)
case _ =>
val db = tableIdent.database.getOrElse(catalog.getCurrentDatabase)
val identifierList = db :: tableIdent.identifier :: Nil
namePartsWithDatabase(nameParts, catalog.getCurrentDatabase)
.corresponds(identifierList)(resolver)
}
tableIdent: IdentifierWithDatabase): Boolean = nameParts match {
case Seq(tableName) =>
resolver(tableIdent.identifier, tableName)
case Seq(dbName, tableName) if tableIdent.database.isDefined =>
resolver(tableIdent.database.get, dbName) && resolver(tableIdent.identifier, tableName)
case _ =>
false
}

private def applyBroadcastHint(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class ResolveHintsSuite extends AnalysisTest {
Seq(errMsgRepa))
}

test("Supports multi-part table names for broadcast hint resolution") {
test("supports multi-part table names for broadcast hint resolution") {
// local temp table
checkAnalysis(
UnresolvedHint("MAPJOIN", Seq("table", "table2"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ package org.apache.spark.sql

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -195,43 +194,78 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1)
}

test("SPARK-25121 Supports multi-part names for broadcast hint resolution") {
test("SPARK-25121 supports multi-part names for broadcast hint resolution") {
val (table1Name, table2Name) = ("t1", "t2")

withTempDatabase { dbName =>
withTable(table1Name, table2Name) {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")

// First, makes sure a join is not broadcastable
val plan = sql(s"SELECT * FROM $dbName.$table1Name, $dbName.$table2Name " +
s"WHERE $table1Name.id = $table2Name.id")
.queryExecution.executedPlan
assert(plan.collect { case p: BroadcastHashJoinExec => p }.size == 0)
assert(plan.collect { case p: BroadcastHashJoinExec => p }.isEmpty)

// Uses multi-part table names for broadcast hints
def checkIfHintApplied(tableName: String, hintTableName: String): Unit = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hintTableName is never used in this func?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, I'll fix.

val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
s"FROM $tableName, $dbName.$table2Name " +
s"WHERE $tableName.id = $table2Name.id")
.queryExecution.executedPlan
val broadcastHashJoin = p.collect { case p: BroadcastHashJoinExec => p }
assert(broadcastHashJoin.size == 1)
val broadcastExchange = broadcastHashJoin.head.collect {
val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p }
assert(broadcastHashJoins.size == 1)
val broadcastExchanges = broadcastHashJoins.head.collect {
case p: BroadcastExchangeExec => p
}
assert(broadcastExchange.size == 1)
val table = broadcastExchange.head.collect {
assert(broadcastExchanges.size == 1)
val tables = broadcastExchanges.head.collect {
case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent
}
assert(table.size == 1)
assert(table.head === TableIdentifier(table1Name, Some(dbName)))
assert(tables.size == 1)
assert(tables.head === TableIdentifier(table1Name, Some(dbName)))
}

def checkIfHintNotApplied(tableName: String, hintTableName: String): Unit = {
val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
s"FROM $tableName, $dbName.$table2Name " +
s"WHERE $tableName.id = $table2Name.id")
.queryExecution.executedPlan
val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p }
assert(broadcastHashJoins.isEmpty)
}

sql(s"USE $dbName")
checkIfHintApplied(table1Name, table1Name)
checkIfHintApplied(s"$dbName.$table1Name", s"$dbName.$table1Name")
checkIfHintApplied(table1Name, s"$dbName.$table1Name")
checkIfHintApplied(s"$dbName.$table1Name", table1Name)
checkIfHintNotApplied(table1Name, s"$dbName.$table1Name")
checkIfHintNotApplied(s"$dbName.$table1Name", s"$dbName.$table1Name.id")
}
}
}
}

test("SPARK-25121 the same table name exists in two databases for broadcast hint resolution") {
val (db1Name, db2Name) = ("db1", "db2")

withDatabase(db1Name, db2Name) {
withTable("t") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
sql(s"CREATE DATABASE $db1Name")
sql(s"CREATE DATABASE $db2Name")
spark.range(1).write.saveAsTable(s"$db1Name.t")
spark.range(1).write.saveAsTable(s"$db2Name.t")

// Checks if a broadcast hint applied in both sides
val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " +
s"WHERE $db1Name.t.id = $db2Name.t.id"
sql(statement).queryExecution.optimizedPlan match {
case Join(_, _, _, _, JoinHint(Some(leftHint), Some(rightHint))) =>
assert(leftHint.broadcast && rightHint.broadcast)
case _ => fail("broadcast hint not found in both tables")
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalog.Table
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{Join, ResolvedHint}
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -169,9 +169,10 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext {
"SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id",
"SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id"
).foreach { statement =>
val plan = sql(statement).queryExecution.optimizedPlan
assert(plan.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
assert(!plan.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
sql(statement).queryExecution.optimizedPlan match {
case Join(_, _, _, _, JoinHint(Some(leftHint), None)) => assert(leftHint.broadcast)
case _ => fail("broadcast hint not found in a left-side table")
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,8 +738,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
assert(broadcastData.head.identifier === "tv")

val sparkPlan = df.queryExecution.executedPlan
val broadcastHashJoin = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
assert(broadcastHashJoin.size == 1)
val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
assert(broadcastHashJoins.size == 1)
}
}
}
Expand Down