@@ -49,7 +49,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
/**
* Runs this query returning the result as an array.
*/
def executeCollect(): Array[Row] = execute().collect()
def executeCollect(): Array[Row] = execute().map(_.copy()).collect()

protected def buildRow(values: Seq[Any]): Row =
new GenericRow(values.toArray)
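
For context on why the new executeCollect copies each row: Spark SQL operators may reuse a single mutable Row object for every tuple they emit, so collecting the raw rows can yield an array whose entries all alias the same (last-written) object. Below is a minimal, Catalyst-free sketch of that aliasing pitfall; the MutableRecord class and all names in it are illustrative, not part of the patch.

// Illustrative only: a mutable record reused across an iterator, mimicking how an
// operator may recycle one Row instance for every output tuple.
class MutableRecord(var value: Int)

object CopyOnCollect {
  def main(args: Array[String]): Unit = {
    val reused = new MutableRecord(0)
    // Each call produces a fresh iterator that mutates and returns the same object.
    def records: Iterator[MutableRecord] =
      Iterator.tabulate(3) { i => reused.value = i; reused }

    val aliased = records.toArray                                       // every slot is the same object
    val copied  = records.map(r => new MutableRecord(r.value)).toArray  // defensive copies

    println(aliased.map(_.value).mkString(","))  // 2,2,2 -- stale, aliased results
    println(copied.map(_.value).mkString(","))   // 0,1,2 -- what executeCollect needs
  }
}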
@@ -21,20 +21,55 @@ import org.apache.spark.sql.{SQLContext, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.parquet._

private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>

/**
* Uses the HashFilteredJoin pattern to find joins where at least some of the predicates can be
* evaluated by matching hash keys.
*/
object HashJoin extends Strategy with PredicateHelper {
var broadcastTables: Seq[String] =
sparkContext.conf.get("spark.sql.hints.broadcastTables", "").split(",").toBuffer

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find inner joins where at least some predicates can be evaluated by matching hash keys
// using the HashFilteredJoin pattern.

case HashFilteredJoin(
Inner,
leftKeys,
rightKeys,
condition,
left,
right @ PhysicalOperation(_, _, b: BaseRelation))
if broadcastTables.contains(b.tableName) =>

val hashJoin =
execution.BroadcastHashJoin(
leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))(sparkContext)
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

case HashFilteredJoin(
Inner,
leftKeys,
rightKeys,
condition,
left @ PhysicalOperation(_, _, b: BaseRelation),
right)
if broadcastTables.contains(b.tableName) =>

val hashJoin =
execution.BroadcastHashJoin(
leftKeys, rightKeys, BuildLeft, planLater(left), planLater(right))(sparkContext)
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

case HashFilteredJoin(Inner, leftKeys, rightKeys, condition, left, right) =>
val hashJoin =
execution.HashJoin(leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
execution.ShuffledHashJoin(
leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
case _ => Nil
}
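
The strategy above reads a comma-separated list of table names from the spark.sql.hints.broadcastTables property and plans a BroadcastHashJoin whenever one side of an inner equi-join is a base relation named in that list. A rough sketch of how a user might supply the hint follows, assuming tables named dim_dates and dim_products are registered; everything except the property name is illustrative.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object BroadcastHintExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("broadcast-join-hint")
      // Comma-separated table names whose side of the join should be broadcast.
      .set("spark.sql.hints.broadcastTables", "dim_dates,dim_products")

    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Inner equi-joins against dim_dates or dim_products should now be planned as
    // BroadcastHashJoin rather than ShuffledHashJoin, avoiding a shuffle of the large side.
    sc.stop()
  }
}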
sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala (196 changes: 126 additions & 70 deletions)
@@ -18,13 +18,16 @@
package org.apache.spark.sql.execution

import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.concurrent._
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.SparkContext

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnspecifiedDistribution}

@DeveloperApi
sealed abstract class BuildSide
@@ -35,21 +38,13 @@ case object BuildLeft extends BuildSide
@DeveloperApi
case object BuildRight extends BuildSide

/**
* :: DeveloperApi ::
*/
@DeveloperApi
case class HashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
trait HashJoin {
val leftKeys: Seq[Expression]
val rightKeys: Seq[Expression]
val buildSide: BuildSide
val left: SparkPlan
val right: SparkPlan

val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
@@ -67,79 +62,140 @@ case class HashJoin(
@transient lazy val streamSideKeyGenerator =
() => new MutableProjection(streamedKeys, streamedPlan.output)

def execute() = {

buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
// TODO: Use Spark's HashMap implementation.
val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
var currentRow: Row = null

// Create a mapping of buildKeys -> rows
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
hashTable.put(rowKey, newMatchList)
newMatchList
} else {
existingMatchList
}
matchList += currentRow.copy()
def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
// TODO: Use Spark's HashMap implementation.

val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
var currentRow: Row = null

// Create a mapping of buildKeys -> rows
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
val newMatchList = new ArrayBuffer[Row]()
hashTable.put(rowKey, newMatchList)
newMatchList
} else {
existingMatchList
}
matchList += currentRow.copy()
}
}

new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatches: ArrayBuffer[Row] = _
private[this] var currentMatchPosition: Int = -1
new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatches: ArrayBuffer[Row] = _
private[this] var currentMatchPosition: Int = -1

// Mutable per row objects.
private[this] val joinRow = new JoinedRow
// Mutable per row objects.
private[this] val joinRow = new JoinedRow

private[this] val joinKeys = streamSideKeyGenerator()
private[this] val joinKeys = streamSideKeyGenerator()

override final def hasNext: Boolean =
(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
override final def hasNext: Boolean =
(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
(streamIter.hasNext && fetchNext())

override final def next() = {
val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
currentMatchPosition += 1
ret
}
override final def next() = {
val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
currentMatchPosition += 1
ret
}

/**
* Searches the streamed iterator for the next row that has at least one match in hashtable.
*
* @return true if the search is successful, and false if the streamed iterator runs out of
* tuples.
*/
private final def fetchNext(): Boolean = {
currentHashMatches = null
currentMatchPosition = -1

while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
currentHashMatches = hashTable.get(joinKeys.currentValue)
}
/**
* Searches the streamed iterator for the next row that has at least one match in hashtable.
*
* @return true if the search is successful, and false if the streamed iterator runs out of
* tuples.
*/
private final def fetchNext(): Boolean = {
currentHashMatches = null
currentMatchPosition = -1

while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
currentHashMatches = hashTable.get(joinKeys.currentValue)
}
}

if (currentHashMatches == null) {
false
} else {
currentMatchPosition = 0
true
}
if (currentHashMatches == null) {
false
} else {
currentMatchPosition = 0
true
}
}
}
}
}
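
The probe-side iterator above is a small state machine: hasNext advances the streamed iterator (via fetchNext) until a row whose key has matches is found, and next walks the current match list. Here is a generic, Catalyst-free sketch of the same lookahead pattern, with purely illustrative names:

object ProbeSketch {
  // Probes a streamed iterator against a prebuilt multimap using the same
  // hasNext/fetchNext structure as the HashJoin trait above.
  def probe[K, A, B](stream: Iterator[(K, A)], table: Map[K, Seq[B]]): Iterator[(A, B)] =
    new Iterator[(A, B)] {
      private var current: (A, Seq[B]) = null
      private var pos = -1

      override def hasNext: Boolean =
        (pos != -1 && pos < current._2.size) || (stream.hasNext && fetchNext())

      override def next(): (A, B) = {
        val out = (current._1, current._2(pos))
        pos += 1
        out
      }

      // Advance the streamed side until a row with at least one match is found.
      private def fetchNext(): Boolean = {
        current = null
        pos = -1
        while (current == null && stream.hasNext) {
          val (k, a) = stream.next()
          table.get(k).foreach(matches => current = (a, matches))
        }
        if (current == null) false else { pos = 0; true }
      }
    }

  def main(args: Array[String]): Unit =
    // Key 3 has no match and is skipped: prints List((x,a), (x,b))
    println(probe(Iterator(1 -> "x", 3 -> "z"), Map(1 -> Seq("a", "b"))).toList)
}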

/**
* :: DeveloperApi ::
* Performs an inner hash join of two child relations by first shuffling the data using the join
* keys.
*/
@DeveloperApi
case class ShuffledHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
right: SparkPlan) extends BinaryNode with HashJoin {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil


def execute() = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) {
(buildIter, streamIter) => joinIterators(buildIter, streamIter)
}
}
}
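
ShuffledHashJoin depends on both children satisfying ClusteredDistribution on the join keys, so that corresponding partitions of the two sides contain the same keys and zipPartitions can build and probe a hash table per partition. A standalone sketch of that zipPartitions pattern on plain pair RDDs follows; it is independent of Catalyst, and all data and names are illustrative.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object ZipPartitionsJoinSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("zip-join"))

    // Hash-partition both sides the same way, mirroring ClusteredDistribution on the join keys.
    val partitioner = new HashPartitioner(4)
    val build  = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(partitioner)
    val stream = sc.parallelize(Seq(1 -> "x", 1 -> "y", 3 -> "z")).partitionBy(partitioner)

    // Per partition: build a hash table from one side, then probe it with the other.
    val joined = build.zipPartitions(stream) { (buildIter, streamIter) =>
      val table = buildIter.toMap
      streamIter.flatMap { case (k, v) => table.get(k).map(b => (k, (b, v))) }
    }

    joined.collect().foreach(println)  // (1,(a,x)), (1,(a,y))
    sc.stop()
  }
}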


/**
* :: DeveloperApi ::
* Performs an inner hash join of two child relations. When the operator is constructed, a Spark
* job is asynchronously started to calculate the values for the broadcasted relation. This data
* is then placed in a Spark broadcast variable. The streamed relation is not shuffled.
*/

Review thread on the comment above:

Contributor: The bodies of BroadcastHashJoin and of HashJoin do not strictly reference broadcastFuture, right? If so, the Spark job isn't launched during the constructor?

Contributor Author: It is only run on Line 191 during execute.

Contributor: Yep, we should update the comment "When the operator is constructed" then.

Contributor Author: Yeah, I guess it should be when the RDD is constructed.

@DeveloperApi
case class BroadcastHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
buildSide: BuildSide,
left: SparkPlan,
right: SparkPlan)(@transient sc: SparkContext) extends BinaryNode with HashJoin {

override def otherCopyArgs = sc :: Nil

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil

@transient
lazy val broadcastFuture = future {
sc.broadcast(buildPlan.executeCollect())

}

Review thread on the broadcast above:

Comment: Hi, will you plan to clean up broadcast variables after the operation, or leave them in the context?

Contributor: In Spark 1.0, with the newly added garbage collection mechanism, when the query plan itself goes out of scope, the broadcast variable should also be cleaned automatically.

Another way we can do this is to have some query context object we pass around the entire physical query plan which tracks the stuff we need to clean up.

Comment: Hi Reynold, thanks for the reply. Does Spark have a plan to port this PR into the repo?

Contributor: We definitely want to merge this PR (assuming you are talking about the broadcast hash join PR).

Comment: Yep, the broadcast join. We were experiencing a perf problem when joining a big table with a small table. Looking forward to the merge. Do you know approximately when it will be, assuming it goes into 1.1.0?

Contributor: 1.0 is already going through voting now, so this won't make it into 1.0. It will be in 1.0.1/1.1; however, if you need this functionality, you can just cherry-pick this pull request and do a custom build.

Comment: Good to know. Thanks for the heads-up.

def execute() = {
val broadcastRelation = Await.result(broadcastFuture, 5.minute)

streamedPlan.execute().mapPartitions { streamedIter =>
joinIterators(broadcastRelation.value.iterator, streamedIter)
}
}
}
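
Putting the pieces together: because broadcastFuture is a lazy val, the collect-and-broadcast job only starts the first time execute() references it (the point settled in the first review thread above), and per the second thread the broadcast variable is cleaned up by garbage collection once the plan goes out of scope, or can be unpersisted explicitly. Below is a standalone sketch of the same pattern on plain pair RDDs; collectAsMap, the sample data, and the explicit unpersist call are illustrative choices, not part of the patch.

import scala.concurrent._
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object BroadcastJoinSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("broadcast-join"))

    val small = sc.parallelize(Seq(1 -> "a", 2 -> "b"))            // build side
    val large = sc.parallelize(Seq(1 -> "x", 2 -> "y", 3 -> "z"))  // streamed side

    // Like broadcastFuture above: nothing runs until the lazy val is first referenced.
    lazy val broadcastFuture = future {
      sc.broadcast(small.collectAsMap())
    }

    // The execute() analogue: wait for the broadcast, then probe it per streamed partition.
    val broadcastRelation = Await.result(broadcastFuture, 5.minutes)
    val joined = large.mapPartitions { iter =>
      val table = broadcastRelation.value
      iter.flatMap { case (k, v) => table.get(k).map(b => (k, (b, v))) }
    }
    println(joined.collect().mkString(", "))  // (1,(a,x)), (2,(b,y))

    // Optional explicit cleanup once results are no longer needed; otherwise GC handles it.
    broadcastRelation.unpersist()
    sc.stop()
  }
}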

/**
* :: DeveloperApi ::
*/