
Commit 3ae3e1b

ho3rexqj authored and cloud-fan committed
[SPARK-22986][CORE] Use a cache to avoid instantiating multiple instances of broadcast variable values
When resources happen to be constrained on an executor the first time a broadcast variable is instantiated, it is persisted to disk by the BlockManager. Consequently, every subsequent call to TorrentBroadcast::readBroadcastBlock from other instances of that broadcast variable spawns another instance of the underlying value. That is, broadcast variables are spawned once per executor **unless** memory is constrained, in which case every instance of a broadcast variable is provided with a unique copy of the underlying value.

This patch fixes the above by explicitly caching the underlying values using weak references in a ReferenceMap.

Author: ho3rexqj <[email protected]>

Closes #20183 from ho3rexqj/fix/cache-broadcast-values.

(cherry picked from commit cbe7c6f)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 55695c7 commit 3ae3e1b
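
For context, the cache introduced here is an Apache Commons Collections ReferenceMap built with hard keys and weak values: the map never keeps a cached value alive on its own, so a broadcast value can still be garbage-collected once nothing else on the executor strongly references it, at which point its entry simply disappears. A minimal standalone sketch of that behavior (illustrative only; the object and key names below are made up and not part of the patch):

    import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}

    object WeakValueCacheSketch {
      def main(args: Array[String]): Unit = {
        // Hard keys, weak values: an entry survives only while its value is still
        // strongly reachable from somewhere else.
        val cache = new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)

        var value: AnyRef = new Array[Byte](1 << 20)
        cache.put("broadcast_0", value)
        assert(cache.get("broadcast_0") != null)  // kept alive by the strong reference above

        value = null   // drop the only strong reference
        System.gc()    // a hint; weak references are usually cleared on the next collection
        println(cache.get("broadcast_0"))  // typically prints null once the value is collected
      }
    }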

File tree

3 files changed: +83 -29 lines changed


core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,8 @@ import java.util.concurrent.atomic.AtomicLong
 
 import scala.reflect.ClassTag
 
+import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}
+
 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.internal.Logging
 
@@ -52,6 +54,10 @@ private[spark] class BroadcastManager(
 
   private val nextBroadcastId = new AtomicLong(0)
 
+  private[broadcast] val cachedValues = {
+    new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
+  }
+
   def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
     broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
   }

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 43 additions & 29 deletions
@@ -206,36 +206,50 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
 
   private def readBroadcastBlock(): T = Utils.tryOrIOException {
     TorrentBroadcast.synchronized {
-      setConf(SparkEnv.get.conf)
-      val blockManager = SparkEnv.get.blockManager
-      blockManager.getLocalValues(broadcastId) match {
-        case Some(blockResult) =>
-          if (blockResult.data.hasNext) {
-            val x = blockResult.data.next().asInstanceOf[T]
-            releaseLock(broadcastId)
-            x
-          } else {
-            throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
-          }
-        case None =>
-          logInfo("Started reading broadcast variable " + id)
-          val startTimeMs = System.currentTimeMillis()
-          val blocks = readBlocks()
-          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
-
-          try {
-            val obj = TorrentBroadcast.unBlockifyObject[T](
-              blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
-            // Store the merged copy in BlockManager so other tasks on this executor don't
-            // need to re-fetch it.
-            val storageLevel = StorageLevel.MEMORY_AND_DISK
-            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
-              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+      val broadcastCache = SparkEnv.get.broadcastManager.cachedValues
+
+      Option(broadcastCache.get(broadcastId)).map(_.asInstanceOf[T]).getOrElse {
+        setConf(SparkEnv.get.conf)
+        val blockManager = SparkEnv.get.blockManager
+        blockManager.getLocalValues(broadcastId) match {
+          case Some(blockResult) =>
+            if (blockResult.data.hasNext) {
+              val x = blockResult.data.next().asInstanceOf[T]
+              releaseLock(broadcastId)
+
+              if (x != null) {
+                broadcastCache.put(broadcastId, x)
+              }
+
+              x
+            } else {
+              throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
             }
-            obj
-          } finally {
-            blocks.foreach(_.dispose())
-          }
+          case None =>
+            logInfo("Started reading broadcast variable " + id)
+            val startTimeMs = System.currentTimeMillis()
+            val blocks = readBlocks()
+            logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
+
+            try {
+              val obj = TorrentBroadcast.unBlockifyObject[T](
+                blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
+              // Store the merged copy in BlockManager so other tasks on this executor don't
+              // need to re-fetch it.
+              val storageLevel = StorageLevel.MEMORY_AND_DISK
+              if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
+                throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+              }
+
+              if (obj != null) {
+                broadcastCache.put(broadcastId, obj)
+              }
+
+              obj
+            } finally {
+              blocks.foreach(_.dispose())
+            }
+        }
       }
     }
   }
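
In short, readBroadcastBlock now consults the executor-wide cache before touching the BlockManager and populates it on a miss; the null checks before put guard against ReferenceMap rejecting null values, and all cache access sits inside the pre-existing TorrentBroadcast.synchronized block, which also serializes access to the unsynchronized ReferenceMap. A condensed, standalone sketch of that get-or-load shape (hypothetical class and method names, not Spark code):

    import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}

    // Hypothetical helper that mirrors the cache-then-fallback pattern above.
    class GetOrLoadCache {
      private val cache = new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)

      def getOrLoad[T <: AnyRef](key: String)(load: => T): T = cache.synchronized {
        Option(cache.get(key)).map(_.asInstanceOf[T]).getOrElse {
          val value = load       // e.g. read local blocks or fetch and deserialize remotely
          if (value != null) {   // ReferenceMap does not accept null values
            cache.put(key, value)
          }
          value
        }
      }
    }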

core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -153,6 +153,40 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio
     assert(broadcast.value.sum === 10)
   }
 
+  test("One broadcast value instance per executor") {
+    val conf = new SparkConf()
+      .setMaster("local[4]")
+      .setAppName("test")
+
+    sc = new SparkContext(conf)
+    val list = List[Int](1, 2, 3, 4)
+    val broadcast = sc.broadcast(list)
+    val instances = sc.parallelize(1 to 10)
+      .map(x => System.identityHashCode(broadcast.value))
+      .collect()
+      .toSet
+
+    assert(instances.size === 1)
+  }
+
+  test("One broadcast value instance per executor when memory is constrained") {
+    val conf = new SparkConf()
+      .setMaster("local[4]")
+      .setAppName("test")
+      .set("spark.memory.useLegacyMode", "true")
+      .set("spark.storage.memoryFraction", "0.0")
+
+    sc = new SparkContext(conf)
+    val list = List[Int](1, 2, 3, 4)
+    val broadcast = sc.broadcast(list)
+    val instances = sc.parallelize(1 to 10)
+      .map(x => System.identityHashCode(broadcast.value))
+      .collect()
+      .toSet
+
+    assert(instances.size === 1)
+  }
+
   /**
    * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster.
    *
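
Both new tests compare object identities rather than values: System.identityHashCode returns the same number for two references to the same object, while equal but distinct objects almost always produce different numbers, so a result set of size one means every task observed the same in-memory copy. A small illustration (plain Scala, independent of Spark):

    val a = List(1, 2, 3, 4)
    val b = a                  // second reference to the same instance
    val c = List(1, 2, 3, 4)   // equal contents, but a separate instance

    assert(a == c)                                                    // value equality holds
    assert(System.identityHashCode(a) == System.identityHashCode(b))  // same object, same identity hash
    // Distinct objects are not guaranteed distinct identity hashes, but collisions are rare:
    println(System.identityHashCode(a) != System.identityHashCode(c)) // almost always true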
