Commits (26)
aa7120e
Initial Broadcast design
hvanhovell Feb 4, 2016
c2b7533
Fix Exchange and initial code gen attempt.
hvanhovell Feb 4, 2016
6a5568a
Move broadcast retrieval to SparkPlan
hvanhovell Feb 6, 2016
9adecdd
Merge remote-tracking branch 'spark/master' into SPARK-13136
hvanhovell Feb 6, 2016
d0194fb
Fix Codegen & Add other broadcast joins.
hvanhovell Feb 6, 2016
02a61b8
Minor touchup
hvanhovell Feb 6, 2016
c12c8e6
Move broadcast relation retrieval.
hvanhovell Feb 7, 2016
c7dd7ae
Remove codegen from broadcast.
hvanhovell Feb 8, 2016
e847383
Merge remote-tracking branch 'spark/master' into SPARK-13136
hvanhovell Feb 10, 2016
d73f11c
Merge remote-tracking branch 'apache-github/master' into SPARK-13136
hvanhovell Feb 12, 2016
9c0f4bf
Remove closure passing.
hvanhovell Feb 14, 2016
da4a966
Merge remote-tracking branch 'apache-github/master' into SPARK-13136
hvanhovell Feb 14, 2016
681f347
Move transform into BroadcastMode
hvanhovell Feb 15, 2016
7db240a
Clean-up
hvanhovell Feb 15, 2016
3ad839d
Code Review.
hvanhovell Feb 16, 2016
1116768
No newline at EOF :(
hvanhovell Feb 16, 2016
a5501cf
Rename exchanges and merge Broadcast.scala into exchange.scala.
hvanhovell Feb 17, 2016
c7429bb
Merge remote-tracking branch 'apache-github/master' into SPARK-13136
hvanhovell Feb 17, 2016
b12bbc2
Merge remote-tracking branch 'apache-github/master' into SPARK-13136
hvanhovell Feb 20, 2016
9d52650
Revert renaming of variables in LeftSemiJoinBNL.
hvanhovell Feb 20, 2016
54b558d
Revert renaming of variables in LeftSemiJoinBNL.
hvanhovell Feb 20, 2016
f33d2cb
Move all exchange related operators into the exchange package.
hvanhovell Feb 21, 2016
28363c8
CR
hvanhovell Feb 21, 2016
f812a31
Merge remote-tracking branch 'apache-github/master' into SPARK-13136
hvanhovell Feb 21, 2016
4b5978b
put broadcast mode in a separate file.
hvanhovell Feb 21, 2016
c8c175e
Fix style in sqlcontext.
hvanhovell Feb 21, 2016
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.physical

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, IntegerType}

@@ -75,6 +76,12 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
def clustering: Set[Expression] = ordering.map(_.child).toSet
}

/**
* Represents data where tuples are broadcast to every node. It is quite common for the
* entire set of tuples to be transformed into a different data structure.
*/
case class BroadcastDistribution(f: Iterable[InternalRow] => Any = identity) extends Distribution
Reviewer comment: I'm thinking maybe it's better to just declare that we want a hashed broadcast distribution, and then not take a closure. The reason it is bad to take a closure is that this won't work if we want to whole-stage codegen the building of the hash table, or if we want to change the internal engine to a push-based model.
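A minimal sketch of the mode-based alternative the reviewer suggests (and that the later "Move transform into BroadcastMode" commit points toward); the names and signatures below are illustrative, not the final API:

    /** Describes how the collected input should be transformed before broadcasting. */
    trait BroadcastMode {
      def transform(rows: Array[InternalRow]): Any
    }

    /** Broadcast the rows unchanged. */
    case object IdentityBroadcastMode extends BroadcastMode {
      override def transform(rows: Array[InternalRow]): Any = rows
    }

    // The required distribution then names a mode instead of carrying an opaque closure,
    // so the planner can inspect the build (e.g. a hashed relation) and keep the door
    // open for whole-stage code generation or a push-based engine.
    case class BroadcastDistribution(mode: BroadcastMode) extends Distribution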


/**
* Describes how an operator's output is split across partitions. The `compatibleWith`,
* `guarantees`, and `satisfies` methods describe relationships between child partitionings,
@@ -213,7 +220,10 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning {
case object SinglePartition extends Partitioning {
val numPartitions = 1

override def satisfies(required: Distribution): Boolean = true
override def satisfies(required: Distribution): Boolean = required match {
case _: BroadcastDistribution => false
Reviewer comment: I think this is OK for now, but technically we don't need to introduce an exchange if both sides of the join have only one partition. I guess this framework does not currently handle that case.

case _ => true
}

override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1

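To make the single-partition remark above concrete, a hedged sketch of the check such an optimization would need (illustrative only; `op` is a hypothetical join operator, and this PR does not implement it):

    // If every child of the join already reports exactly one partition, no data movement
    // is strictly required, so the planner could skip inserting a Broadcast or Exchange.
    val noMovementNeeded = op.children.forall(_.outputPartitioning.numPartitions == 1)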
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import scala.concurrent._
import scala.concurrent.duration._

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.ThreadUtils

/**
* A Broadcast operator collects the result of its child SparkPlan, transforms it with `f`,
* and finally broadcasts the transformed result.
*/
case class Broadcast(f: Iterable[InternalRow] => Any, child: SparkPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output

override private[sql] lazy val metrics = Map(
"numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")
)

val timeout: Duration = {
val timeoutValue = sqlContext.conf.broadcastTimeout
if (timeoutValue < 0) {
Duration.Inf
} else {
timeoutValue.seconds
}
}

@transient
private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
val numBuildRows = longMetric("numRows")

// relationFuture is used in "doExecuteBroadcast". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sparkContext, executionId) {
// Note that we use .execute().collect() because we don't want to convert data to Scala
// types
val input: Array[InternalRow] = child.execute().map { row =>
numBuildRows += 1
row.copy()
}.collect()

// Construct and broadcast the relation.
sparkContext.broadcast(f(input))
}
}(Broadcast.executionContext)
}

override protected def doPrepare(): Unit = {
// Materialize the future.
relationFuture
}

override protected def doExecute(): RDD[InternalRow] = {
child.execute() // TODO throw an Exception here?
Reviewer comment: Throw an UnsupportedOperationException?

}

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
val result = Await.result(relationFuture, timeout)
result.asInstanceOf[broadcast.Broadcast[T]]
}
}

object Broadcast {
private[execution] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("build-broadcast", 128))
}
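A minimal sketch of the change the review comments ask for on doExecute, assuming the operator should only ever be consumed through executeBroadcast (the exception message is illustrative):

    // Hypothetical replacement for Broadcast.doExecute: fail fast on the RDD code path
    // instead of silently executing the child.
    override protected def doExecute(): RDD[InternalRow] = {
      throw new UnsupportedOperationException(
        s"$nodeName does not support doExecute; call executeBroadcast() instead.")
    }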
@@ -395,18 +395,31 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
assert(requiredChildOrderings.length == children.length)

// Ensure that the operator's children satisfy their output distribution requirements:
children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
if (child.outputPartitioning.satisfies(distribution)) {
children = children.zip(requiredChildDistributions).map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
} else {
case (child, BroadcastDistribution(f1)) =>
child match {
// The child already broadcasts using the same transform function: keep the child.
case Broadcast(f2, _) if f1 == f2 => child
// The child broadcasts using a different transform function: replace the broadcast.
case Broadcast(f2, src) => Broadcast(f1, src)
// Create a broadcast on top of the child.
case _ => Broadcast(f1, child)
}
case (child, distribution) =>
Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
}
}

// If the operator has multiple children and specifies child output distributions (e.g. join),
// then the children's output partitionings must be compatible:
def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
case UnspecifiedDistribution => false
case BroadcastDistribution(_) => false
case _ => true
}
if (children.length > 1
&& requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
&& requiredChildDistributions.exists(requireCompatiblePartitioning)
&& !Partitioning.allCompatible(children.map(_.outputPartitioning))) {

// First check if the existing partitions of the children all match. This means they are
@@ -443,8 +456,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[

children.zip(requiredChildDistributions).map {
case (child, distribution) => {
val targetPartitioning =
createPartitioning(distribution, numPartitions)
val targetPartitioning = createPartitioning(distribution, numPartitions)
if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
} else {
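To make the broadcast match arms above concrete, a hedged illustration of the planning outcome (the operator names and the `buildHashedRelation` function are illustrative, not taken from this diff):

    // Required distribution on the build side: BroadcastDistribution(buildHashedRelation)
    //
    // before: BroadcastHashJoin(..., right = LocalTableScan(...))
    // after:  BroadcastHashJoin(..., right = Broadcast(buildHashedRelation, LocalTableScan(...)))
    //
    // A child that is already Broadcast(buildHashedRelation, ...) is kept as-is, while a
    // Broadcast with a different function is replaced by one using the required function.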
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.Logging
import org.apache.spark.broadcast
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -98,14 +99,29 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)

/**
* Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
* after adding query plan information to created RDDs for visualization.
* Concrete implementations of SparkPlan should override doExecute instead.
* Returns the result of this query as an RDD[InternalRow] by delegating to doExecute after
* preparations. Concrete implementations of SparkPlan should override doExecute.
*/
final def execute(): RDD[InternalRow] = {
final def execute(): RDD[InternalRow] = executeQuery {
doExecute()
}

/**
* Returns the result of this query as a broadcast variable by delegating to doBroadcast after
* preparations. Concrete implementations of SparkPlan should override doBroadcast.
*/
final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery {
doExecuteBroadcast()
}

/**
* Execute a query after preparing the query and adding query plan information to created RDDs
* for visualization.
*/
private final def executeQuery[T](query: => T): T = {
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
prepare()
doExecute()
query
}
}

@@ -135,6 +151,14 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
*/
protected def doExecute(): RDD[InternalRow]

/**
* Overridden by concrete implementations of SparkPlan.
* Produces the result of the query as a broadcast variable.
*/
protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
throw new NotImplementedError(s"$nodeName does not implement doExecuteBroadcast")
Reviewer comment: UnsupportedOperationException?

}

/**
* Runs this query returning the result as an array.
*/
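For context, a minimal sketch of how a broadcast join's build side could consume the new executeBroadcast API, assuming the build side was planned as a Broadcast node and that the broadcast transform returned the collected rows (all names here are illustrative):

    // Inside a hypothetical join's doExecute:
    // fetch the build side as a broadcast variable instead of executing it as an RDD.
    val broadcastRelation = buildPlan.executeBroadcast[Array[InternalRow]]()
    streamedPlan.execute().mapPartitions { streamedIter =>
      val buildRows = broadcastRelation.value
      // ... probe buildRows for each streamed row and emit the join output ...
      streamedIter
    }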
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
@@ -360,6 +361,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
// the generated code will be huge if there are too many columns
val haveManyColumns = plan.output.length > 200
!willFallback && !haveManyColumns
// Collapse a broadcast into the stage - it should not contain any code that can be
// code-generated.
case _: Broadcast => true
Reviewer comment: This is also not needed.

case _ => false
}

@@ -370,10 +374,10 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled together with the rest of the stage
case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
b.copy(left = apply(left))
case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
case b @ BroadcastHashJoin(_, _, BuildLeft, _, Broadcast(f, left), _) =>
b.copy(left = Broadcast(f, apply(left)))
case b @ BroadcastHashJoin(_, _, BuildRight, _, _, Broadcast(f, right)) =>
b.copy(right = Broadcast(f, apply(right)))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
inputs += input
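A hedged illustration of the stage boundaries this rule produces for a broadcast hash join (the plan shape is illustrative, not output generated from this diff):

    // WholeStageCodegen
    //   BroadcastHashJoin (BuildRight)
    //     :- ... stream side, fused into the join's generated code ...
    //     +- Broadcast(f, <build child, collapsed separately by the recursive apply>)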