136 changes: 65 additions & 71 deletions core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -26,6 +26,7 @@ import scala.util.Random

import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.io.ByteArrayChunkOutputStream
@@ -46,14 +47,12 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream
* This prevents the driver from being the bottleneck in sending out multiple copies of the
* broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
*
* When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
*
* @param obj object to broadcast
* @param isLocal whether Spark is running in local mode (single JVM process).
* @param id A unique identifier for the broadcast variable.
*/
private[spark] class TorrentBroadcast[T: ClassTag](
obj : T,
@transient private val isLocal: Boolean,
id: Long)
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {

/**
@@ -62,6 +61,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](
* blocks from the driver and/or other executors.
*/
@transient private var _value: T = obj
/** The compression codec to use, or None if compression is disabled */
@transient private var compressionCodec: Option[CompressionCodec] = _
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@transient private var blockSize: Int = _
Contributor: How about moving these two into the constructor and reading the conf in TorrentBroadcastFactory?

Contributor (Author): I thought about this and agree that it might be cleaner, but it would require more refactoring of other code. One design goal here was to minimize the serialized size of TorrentBroadcast objects, so we can't serialize the SparkConf or CompressionCodec instances (which contain SparkConfs). SparkEnv.conf determines these values anyway.
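
To make the trade-off concrete, here is a condensed sketch of the pattern this class uses: conf-derived fields are @transient and are rebuilt from the application-scoped SparkEnv.get.conf when the object is deserialized, so only the small serializable state ships with each task. This sketch is illustrative only, not code from this diff.

import java.io.ObjectInputStream

import org.apache.spark.{SparkConf, SparkEnv}

// Illustrative sketch, not part of this PR. Only `id` is serialized; the
// conf-derived field is rebuilt from SparkEnv in whichever JVM deserializes
// the object, so no SparkConf (or codec holding one) travels with it.
class ConfDerivedExample(val id: Long) extends Serializable {
  @transient private var blockSize: Int = _

  private def setConf(conf: SparkConf) {
    blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
  }
  setConf(SparkEnv.get.conf)

  // Java serialization hook: re-read the application-scoped conf on
  // deserialization instead of shipping one per object.
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject()
    setConf(SparkEnv.get.conf)
  }
}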


private def setConf(conf: SparkConf) {
compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
Some(CompressionCodec.createCodec(conf))
} else {
None
}
blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
}
setConf(SparkEnv.get.conf)
Contributor: Update the javadoc for this class to make it very obvious that, at init time, this class reads configuration from SparkEnv.get.conf.

Contributor (Author): Done.
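
For reference, a short sketch of the two settings that setConf reads; the values shown are the defaults, and the snippet is illustrative rather than code from this diff.

import org.apache.spark.SparkConf

// Illustrative only: the two keys consulted by setConf above.
val conf = new SparkConf()
  .set("spark.broadcast.compress", "true")   // wrap block streams in a CompressionCodec
  .set("spark.broadcast.blockSize", "4096")  // block size in KB (4096 KB = 4 MB)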


private val broadcastId = BroadcastBlockId(id)

@@ -76,23 +89,20 @@
* @return number of blocks this broadcast variable is divided into
*/
private def writeBlocks(): Int = {
// For local mode, just put the object in the BlockManager so we can find it later.
SparkEnv.get.blockManager.putSingle(
broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)

if (!isLocal) {
val blocks = TorrentBroadcast.blockifyObject(_value)
blocks.zipWithIndex.foreach { case (block, i) =>
SparkEnv.get.blockManager.putBytes(
BroadcastBlockId(id, "piece" + i),
block,
StorageLevel.MEMORY_AND_DISK_SER,
tellMaster = true)
}
blocks.length
} else {
0
// Store a copy of the broadcast variable in the driver so that tasks run on the driver
// do not create a duplicate copy of the broadcast variable's value.
SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK,
tellMaster = false)
Contributor: I wonder whether storing a serialized copy in local mode helps anything. If we fail to fetch the original copy of the value from the BlockManager, we will not be able to fetch the serialized copy either.

Contributor (Author): The reason for this store is to avoid creating two copies of _value in the driver. If we serialize and deserialize a broadcast variable on the driver and then attempt to access its value, then without this code we would end up going through the regular de-chunking code path, which would deserialize the serialized copy of _value and waste memory.

I believe this serialization and deserialization can take place when tasks are run in local mode, since we still serialize tasks in order to help users become aware of serialization issues that would affect them if they moved to a cluster. This complexity is another reason why I'm in favor of scrapping all local-mode special-casing and configuring Spark to use a dummy LocalBroadcastFactory for local mode instead of whichever setting the user specified (a rough sketch of the idea follows). That would be a larger, more invasive change, which is why I opted for the simpler fix here.
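
For illustration, the dummy LocalBroadcastFactory mentioned above might look roughly like the sketch below. It is not part of this PR, and it assumes Broadcast's protected getValue/doUnpersist/doDestroy hooks plus the BroadcastFactory methods that appear later in this diff.

import scala.reflect.ClassTag

import org.apache.spark.{SecurityManager, SparkConf}

// Rough sketch of the idea above. In local mode the value never leaves
// the JVM, so the broadcast can simply keep a reference to it.
private class LocalBroadcast[T: ClassTag](obj: T, id: Long) extends Broadcast[T](id) {
  override protected def getValue(): T = obj
  override protected def doUnpersist(blocking: Boolean) { }
  override protected def doDestroy(blocking: Boolean) { }
}

private class LocalBroadcastFactory extends BroadcastFactory {
  override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
  override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
    new LocalBroadcast[T](value_, id)
  override def stop() { }
  override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { }
}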

val blocks =
TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
blocks.zipWithIndex.foreach { case (block, i) =>
SparkEnv.get.blockManager.putBytes(
BroadcastBlockId(id, "piece" + i),
block,
StorageLevel.MEMORY_AND_DISK_SER,
tellMaster = true)
}
blocks.length
}

/** Fetch torrent blocks from the driver and/or other executors. */
@@ -104,29 +114,24 @@

for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)

// First try getLocalBytes because there is a chance that previous attempts to fetch the
logDebug(s"Reading piece $pieceId of $broadcastId")
// First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
var blockOpt = bm.getLocalBytes(pieceId)
if (!blockOpt.isDefined) {
blockOpt = bm.getRemoteBytes(pieceId)
blockOpt match {
case Some(block) =>
// If we found the block from remote executors/driver's BlockManager, put the block
// in this executor's BlockManager.
SparkEnv.get.blockManager.putBytes(
pieceId,
block,
StorageLevel.MEMORY_AND_DISK_SER,
tellMaster = true)

case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
}
def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
// If we found the block from remote executors/driver's BlockManager, put the block
// in this executor's BlockManager.
SparkEnv.get.blockManager.putBytes(
pieceId,
block,
StorageLevel.MEMORY_AND_DISK_SER,
tellMaster = true)
block
}
// If we get here, the option is defined.
blocks(pid) = blockOpt.get
val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
blocks(pid) = block
}
blocks
}
@@ -156,6 +161,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
setConf(SparkEnv.get.conf)
Contributor: This looks weird: how can we make sure that this conf equals the one used when the broadcast was created?

Contributor (Author): The conf is application-scoped. The same conf should be present on this application's executors, where this task will be deserialized. This assumption is used elsewhere, too.

SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
_value = x.asInstanceOf[T]
@@ -167,7 +173,8 @@
val time = (System.nanoTime() - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")

_value = TorrentBroadcast.unBlockifyObject[T](blocks)
_value =
TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
@@ -179,43 +186,29 @@


private object TorrentBroadcast extends Logging {
/** Size of each block. Default value is 4MB. */
private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
private var initialized = false
private var conf: SparkConf = null
private var compress: Boolean = false
private var compressionCodec: CompressionCodec = null

def initialize(_isDriver: Boolean, conf: SparkConf) {
TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
synchronized {
if (!initialized) {
compress = conf.getBoolean("spark.broadcast.compress", true)
compressionCodec = CompressionCodec.createCodec(conf)
initialized = true
}
}
}

def stop() {
initialized = false
}

def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE)
val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
val ser = SparkEnv.get.serializer.newInstance()
def blockifyObject[T: ClassTag](
Contributor: The conf has been moved into the Broadcast class; maybe blockifyObject and unBlockifyObject should be moved as well.

Contributor (Author): These two methods, blockifyObject and unBlockifyObject, now accept all of their dependencies directly, which makes it easier to unit-test them (see the round-trip test added to BroadcastSuite below).

obj: T,
blockSize: Int,
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
val bos = new ByteArrayChunkOutputStream(blockSize)
val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
bos.toArrays.map(ByteBuffer.wrap)
}

def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
def unBlockifyObject[T: ClassTag](
blocks: Array[ByteBuffer],
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): T = {
require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
val is = new SequenceInputStream(
asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is

val ser = SparkEnv.get.serializer.newInstance()
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
serIn.close()
@@ -227,6 +220,7 @@
* If removeFromDriver is true, also remove these persisted blocks on the driver.
*/
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = {
logDebug(s"Unpersisting TorrentBroadcast $id")
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -28,14 +28,13 @@ import org.apache.spark.{SecurityManager, SparkConf}
*/
class TorrentBroadcastFactory extends BroadcastFactory {

override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
TorrentBroadcast.initialize(isDriver, conf)
}
override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }

override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = {
new TorrentBroadcast[T](value_, id)
}

override def stop() { TorrentBroadcast.stop() }
override def stop() { }

/**
* Remove all persisted state associated with the torrent broadcast with the given ID.
42 changes: 27 additions & 15 deletions core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -17,12 +17,15 @@

package org.apache.spark.broadcast

import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
import org.apache.spark.io.SnappyCompressionCodec
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage._


class BroadcastSuite extends FunSuite with LocalSparkContext {

private val httpConf = broadcastConf("HttpBroadcastFactory")
@@ -84,6 +87,24 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}

test("TorrentBroadcast's blockifyObject and unblockifyObject are inverses") {
import org.apache.spark.broadcast.TorrentBroadcast._
val blockSize = 1024
val conf = new SparkConf()
val compressionCodec = Some(new SnappyCompressionCodec(conf))
val serializer = new JavaSerializer(conf)
val seed = 42
val rand = new Random(seed)
for (trial <- 1 to 100) {
val size = 1 + rand.nextInt(1024 * 10)
val data: Array[Byte] = new Array[Byte](size)
rand.nextBytes(data)
val blocks = blockifyObject(data, blockSize, serializer, compressionCodec)
val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec)
assert(unblockified === data)
}
}

test("Unpersisting HttpBroadcast on executors only in local mode") {
testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
}
@@ -193,26 +214,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {

blockId = BroadcastBlockId(broadcastId, "piece0")
statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === (if (distributed) 1 else 0))
assert(statuses.size === 1)
}

// Verify that blocks are persisted in both the executors and the driver
def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
var blockId = BroadcastBlockId(broadcastId)
var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
if (distributed) {
assert(statuses.size === numSlaves + 1)
} else {
assert(statuses.size === 1)
}
val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === numSlaves + 1)

blockId = BroadcastBlockId(broadcastId, "piece0")
statuses = bmm.getBlockStatus(blockId, askSlaves = true)
if (distributed) {
assert(statuses.size === numSlaves + 1)
} else {
assert(statuses.size === 0)
}
assert(statuses.size === numSlaves + 1)
}

// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
Expand All @@ -224,7 +236,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(statuses.size === expectedNumBlocks)

blockId = BroadcastBlockId(broadcastId, "piece0")
expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
expectedNumBlocks = if (removeFromDriver) 0 else 1
statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === expectedNumBlocks)
}