Commit cae0af3
[SPARK-2521] Broadcast RDD object (instead of sending it along with every task).
Currently (as of Spark 1.0.1), Spark sends the RDD object (which contains closures) over Akka along with the task itself to the executors. This is inefficient because all tasks in the same stage use the same RDD object, yet we have to send the RDD object to the executors multiple times. It is especially bad when a closure references a very large variable, and the current design forces users to broadcast large variables explicitly.

This patch uses broadcast to send RDD objects and closures to the executors, and uses Akka to send only a reference to the broadcast RDD/closure along with the partition-specific information for the task. For those of you who know more about the internals, Spark already relies on broadcast to send the Hadoop JobConf every time it reads Hadoop input, because the JobConf is large.

The user-facing impact of the change includes:

1. Users won't need to decide what to broadcast anymore, unless they want to use a large object multiple times in different operations.
2. Task size will get smaller, resulting in faster scheduling and higher task dispatch throughput.

In addition, the change simplifies some internals of Spark, eliminating the need to maintain task caches and the complex logic to broadcast JobConf (which also led to a deadlock recently).

A simple way to test this:

```scala
val a = new Array[Byte](1000*1000); scala.util.Random.nextBytes(a);
sc.parallelize(1 to 1000, 1000).map { x => a; x }.groupBy { x => a; x }.count
```

Numbers on 3 r3.8xlarge instances on EC2:

```
master branch:    5.648436068 s, 4.715361895 s, 5.360161877 s
with this change: 3.416348793 s, 1.477846558 s, 1.553432156 s
```

Author: Reynold Xin <rxin@apache.org>

Closes #1452 from rxin/broadcast-task and squashes the following commits:

762e0be [Reynold Xin] Warn large broadcasts.
ade6eac [Reynold Xin] Log broadcast size.
c3b6f11 [Reynold Xin] Added a unit test for clean up.
754085f [Reynold Xin] Explain why broadcasting serialized copy of the task.
04b17f0 [Reynold Xin] [SPARK-2521] Broadcast RDD object once per TaskSet (instead of sending it for every task).

(cherry picked from commit 7b8cd17)
Signed-off-by: Reynold Xin <rxin@apache.org>
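A minimal sketch of the user-facing difference described above, assuming a spark-shell style session where `sc` is an existing SparkContext (just like the test snippet in the commit message); this is illustrative only, not code from the patch:

```scala
val a = new Array[Byte](1000 * 1000)
scala.util.Random.nextBytes(a)

// Pre-patch idiom: to avoid shipping `a` with every task, users had to
// broadcast it explicitly and reference the broadcast inside the closure.
val aBroadcast = sc.broadcast(a)
sc.parallelize(1 to 1000, 1000).map { x => aBroadcast.value.length; x }.count()

// With this patch, the plain closure below is serialized once as part of the
// broadcast RDD object, so the explicit sc.broadcast above is no longer needed
// to keep tasks small (it still helps if `a` is reused across several jobs).
sc.parallelize(1 to 1000, 1000).map { x => a.length; x }.count()
```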
1 parent 86534d0 commit cae0af3

File tree

8 files changed: +137, -251 lines


core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 18 additions & 10 deletions
```diff
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
  * Base class for dependencies.
  */
 @DeveloperApi
-abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
+abstract class Dependency[T] extends Serializable {
+  def rdd: RDD[T]
+}
 
 
 /**
@@ -36,41 +38,47 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
  * partition of the child RDD. Narrow dependencies allow for pipelined execution.
  */
 @DeveloperApi
-abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
+abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
   /**
    * Get the parent partitions for a child partition.
    * @param partitionId a partition of the child RDD
    * @return the partitions of the parent RDD that the child partition depends upon
    */
   def getParents(partitionId: Int): Seq[Int]
+
+  override def rdd: RDD[T] = _rdd
 }
 
 
 /**
  * :: DeveloperApi ::
- * Represents a dependency on the output of a shuffle stage.
- * @param rdd the parent RDD
+ * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
+ * the RDD is transient since we don't need it on the executor side.
+ *
+ * @param _rdd the parent RDD
  * @param partitioner partitioner used to partition the shuffle output
  * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
  *                   the default serializer, as specified by `spark.serializer` config option, will
  *                   be used.
  */
 @DeveloperApi
 class ShuffleDependency[K, V, C](
-    @transient rdd: RDD[_ <: Product2[K, V]],
+    @transient _rdd: RDD[_ <: Product2[K, V]],
     val partitioner: Partitioner,
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false)
-  extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
+  extends Dependency[Product2[K, V]] {
+
+  override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]
 
-  val shuffleId: Int = rdd.context.newShuffleId()
+  val shuffleId: Int = _rdd.context.newShuffleId()
 
-  val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
-    shuffleId, rdd.partitions.size, this)
+  val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
+    shuffleId, _rdd.partitions.size, this)
 
-  rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+  _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
 }
 
 
```
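The refactoring above turns `rdd` from a constructor field into an abstract method on `Dependency`, with `NarrowDependency` supplying the override. As a hedged illustration of what a custom dependency looks like against the new API (the class name here is made up; Spark already ships `OneToOneDependency` for this behavior), a subclass only needs to implement `getParents`:

```scala
import org.apache.spark.NarrowDependency
import org.apache.spark.rdd.RDD

// Hypothetical example class: each child partition depends on exactly the
// parent partition with the same index. `rdd` is inherited from
// NarrowDependency, which implements Dependency's abstract `def rdd`.
class IdentityDependency[T](parent: RDD[T]) extends NarrowDependency[T](parent) {
  override def getParents(partitionId: Int): Seq[Int] = List(partitionId)
}
```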

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 0 additions & 2 deletions
```diff
@@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging {
     // TODO: Cache.stop()?
     env.stop()
     SparkEnv.set(null)
-    ShuffleMapTask.clearCache()
-    ResultTask.clearCache()
     listenerBus.stop()
     eventLogger.foreach(_.stop())
     logInfo("Successfully stopped SparkContext")
```

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 23 additions & 7 deletions
```diff
@@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
 
@@ -1206,21 +1207,36 @@ abstract class RDD[T: ClassTag](
   /**
    * Return whether this RDD has been checkpointed or not
    */
-  def isCheckpointed: Boolean = {
-    checkpointData.map(_.isCheckpointed).getOrElse(false)
-  }
+  def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
 
   /**
    * Gets the name of the file to which this RDD was checkpointed
    */
-  def getCheckpointFile: Option[String] = {
-    checkpointData.flatMap(_.getCheckpointFile)
-  }
+  def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
 
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================
 
+  /**
+   * Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast
+   * the serialized copy of the RDD and for each task we will deserialize it, which means each
+   * task gets a different copy of the RDD. This provides stronger isolation between tasks that
+   * might modify state of objects referenced in their closures. This is necessary in Hadoop
+   * where the JobConf/Configuration object is not thread-safe.
+   */
+  @transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = {
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val bytes = ser.serialize(this).array()
+    val size = Utils.bytesToString(bytes.length)
+    if (bytes.length > (1L << 20)) {
+      logWarning(s"Broadcasting RDD $id ($size), which contains large objects")
+    } else {
+      logDebug(s"Broadcasting RDD $id ($size)")
+    }
+    sc.broadcast(bytes)
+  }
+
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
```
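The doc comment on `broadcasted` argues that deserializing the broadcast bytes once per task gives each task its own copy of the RDD and its closures. A small self-contained sketch of that isolation argument, using plain JDK serialization in place of Spark's closure serializer (the `Conf` class is a stand-in, not Spark code):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object PerTaskCopySketch {
  // Stand-in for a mutable, non-thread-safe object referenced by a closure,
  // e.g. a Hadoop JobConf.
  class Conf extends Serializable { var value: String = "initial" }

  def serialize(o: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(o)
    oos.close()
    bos.toByteArray
  }

  def deserialize[T](bytes: Array[Byte]): T =
    new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject().asInstanceOf[T]

  def main(args: Array[String]): Unit = {
    val payload = serialize(new Conf)      // the "broadcast" bytes, built once
    val taskA = deserialize[Conf](payload) // each task deserializes its own copy
    val taskB = deserialize[Conf](payload)
    taskA.value = "mutated by task A"
    assert(taskB.value == "initial")       // task B's copy is unaffected
  }
}
```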

core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala

Lines changed: 2 additions & 7 deletions
```diff
@@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       cpRDD = Some(newRDD)
       rdd.markCheckpointed(newRDD)   // Update the RDD's dependencies and partitions
       cpState = Checkpointed
-      RDDCheckpointData.clearTaskCaches()
     }
     logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
   }
@@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   }
 }
 
-private[spark] object RDDCheckpointData {
-  def clearTaskCaches() {
-    ShuffleMapTask.clearCache()
-    ResultTask.clearCache()
-  }
-}
+// Used for synchronization
+private[spark] object RDDCheckpointData
```
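The now-empty `RDDCheckpointData` companion object is kept only as a lock, as the `// Used for synchronization` comment notes. A generic Scala sketch (not Spark code) of that idiom, locking on a bare object to serialize a critical section:

```scala
object SyncSketch {
  // Any object is a JVM monitor, so an otherwise empty object can serve as a lock.
  private object Lock

  private var counter = 0

  def increment(): Unit = Lock.synchronized {
    counter += 1 // guarded by Lock's monitor
  }

  def main(args: Array[String]): Unit = {
    val threads = (1 to 4).map { _ =>
      new Thread(new Runnable {
        override def run(): Unit = (1 to 1000).foreach(_ => increment())
      })
    }
    threads.foreach(_.start())
    threads.foreach(_.join())
    println(counter) // always 4000, because updates are serialized by the lock
  }
}
```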

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 0 additions & 4 deletions
```diff
@@ -361,9 +361,6 @@ class DAGScheduler(
       // data structures based on StageId
       stageIdToStage -= stageId
 
-      ShuffleMapTask.removeStage(stageId)
-      ResultTask.removeStage(stageId)
-
       logDebug("After removal of stage %d, remaining stages = %d"
         .format(stageId, stageIdToStage.size))
     }
@@ -691,7 +688,6 @@ class DAGScheduler(
       }
     }
 
-
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
```

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala

Lines changed: 31 additions & 97 deletions
```diff
@@ -17,134 +17,68 @@
 
 package org.apache.spark.scheduler
 
-import scala.language.existentials
+import java.nio.ByteBuffer
 
 import java.io._
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-
-import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
-
-private[spark] object ResultTask {
-
-  // A simple map between the stage id to the serialized byte array of a task.
-  // Served as a cache for task serialization because serialization can be
-  // expensive on the master node if it needs to launch thousands of tasks.
-  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
-
-  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
-  {
-    synchronized {
-      val old = serializedInfoCache.get(stageId).orNull
-      if (old != null) {
-        old
-      } else {
-        val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance()
-        val objOut = ser.serializeStream(new GZIPOutputStream(out))
-        objOut.writeObject(rdd)
-        objOut.writeObject(func)
-        objOut.close()
-        val bytes = out.toByteArray
-        serializedInfoCache.put(stageId, bytes)
-        bytes
-      }
-    }
-  }
-
-  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
-  {
-    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val objIn = ser.deserializeStream(in)
-    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
-    val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
-    (rdd, func)
-  }
-
-  def removeStage(stageId: Int) {
-    serializedInfoCache.remove(stageId)
-  }
-
-  def clearCache() {
-    synchronized {
-      serializedInfoCache.clear()
-    }
-  }
-}
-
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
 
 /**
  * A task that sends back the output to the driver application.
  *
- * See [[org.apache.spark.scheduler.Task]] for more information.
+ * See [[Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param rdd input to func
+ * @param rddBinary broadcast version of the serialized RDD
  * @param func a function to apply on a partition of the RDD
- * @param _partitionId index of the number in the RDD
+ * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
  * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
  *                 input RDD's partitions).
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    var rdd: RDD[T],
-    var func: (TaskContext, Iterator[T]) => U,
-    _partitionId: Int,
+    val rddBinary: Broadcast[Array[Byte]],
+    val func: (TaskContext, Iterator[T]) => U,
+    val partition: Partition,
     @transient locs: Seq[TaskLocation],
-    var outputId: Int)
-  extends Task[U](stageId, _partitionId) with Externalizable {
-
-  def this() = this(0, null, null, 0, null, 0)
-
-  var split = if (rdd == null) null else rdd.partitions(partitionId)
+    val outputId: Int)
+  extends Task[U](stageId, partition.index) with Serializable {
+
+  // TODO: Should we also broadcast func? For that we would need a place to
+  // keep a reference to it (perhaps in DAGScheduler's job object).
+
+  def this(
+      stageId: Int,
+      rdd: RDD[T],
+      func: (TaskContext, Iterator[T]) => U,
+      partitionId: Int,
+      locs: Seq[TaskLocation],
+      outputId: Int) = {
+    this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
+  }
 
-  @transient private val preferredLocs: Seq[TaskLocation] = {
+  @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
   override def runTask(context: TaskContext): U = {
+    // Deserialize the RDD using the broadcast variable.
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
+      Thread.currentThread.getContextClassLoader)
     metrics = Some(context.taskMetrics)
     try {
-      func(context, rdd.iterator(split, context))
+      func(context, rdd.iterator(partition, context))
     } finally {
       context.executeOnCompleteCallbacks()
     }
   }
 
+  // This is only callable on the driver side.
   override def preferredLocations: Seq[TaskLocation] = preferredLocs
 
   override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
-
-  override def writeExternal(out: ObjectOutput) {
-    RDDCheckpointData.synchronized {
-      split = rdd.partitions(partitionId)
-      out.writeInt(stageId)
-      val bytes = ResultTask.serializeInfo(
-        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
-      out.writeInt(bytes.length)
-      out.write(bytes)
-      out.writeInt(partitionId)
-      out.writeInt(outputId)
-      out.writeLong(epoch)
-      out.writeObject(split)
-    }
-  }
-
-  override def readExternal(in: ObjectInput) {
-    val stageId = in.readInt()
-    val numBytes = in.readInt()
-    val bytes = new Array[Byte](numBytes)
-    in.readFully(bytes)
-    val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
-    rdd = rdd_.asInstanceOf[RDD[T]]
-    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
-    partitionId = in.readInt()
-    outputId = in.readInt()
-    epoch = in.readLong()
-    split = in.readObject().asInstanceOf[Partition]
-  }
 }
```
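The key change in `runTask` is that the task no longer carries the RDD itself, only `rddBinary`, and deserializes it with the closure serializer. A hedged, driver-free sketch of that round trip, using `JavaSerializer` (the default closure serializer) and a plain `Seq[Int]` as a stand-in for the RDD, since constructing a real RDD requires a SparkContext:

```scala
import java.nio.ByteBuffer

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object ClosureSerializerRoundTrip {
  def main(args: Array[String]): Unit = {
    val ser = new JavaSerializer(new SparkConf()).newInstance()

    // Stand-in for the serialized RDD bytes that rdd.broadcasted would hold.
    val payload: Seq[Int] = 1 to 10
    val bytes: Array[Byte] = ser.serialize(payload).array()

    // The executor wraps the broadcast bytes in a ByteBuffer and deserializes
    // them with the task thread's context class loader, as runTask does above.
    val restored = ser.deserialize[Seq[Int]](
      ByteBuffer.wrap(bytes), Thread.currentThread.getContextClassLoader)

    assert(restored == payload)
  }
}
```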
