@@ -234,7 +234,7 @@ public void onComplete(String streamId) throws IOException {
callback.onSuccess(ByteBuffer.allocate(0));
} catch (Exception ex) {
IOException ioExc = new IOException("Failure post-processing complete stream;" +
-  " failing this rpc and leaving channel active");
+  " failing this rpc and leaving channel active", ex);
callback.onFailure(ioExc);
streamHandler.onFailure(streamId, ioExc);
}
@@ -42,7 +42,7 @@ public abstract class BlockTransferMessage implements Encodable {
/** Preceding every serialized message is its type, which allows us to deserialize it. */
public enum Type {
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
-  HEARTBEAT(5);
+  HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6);

private final byte id;

@@ -67,6 +67,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
case 3: return StreamHandle.decode(buf);
case 4: return RegisterDriver.decode(buf);
case 5: return ShuffleServiceHeartbeat.decode(buf);
+ case 6: return UploadBlockStream.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
@@ -0,0 +1,89 @@ UploadBlockStream.java (new file)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle.protocol;

import java.util.Arrays;

import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;

import org.apache.spark.network.protocol.Encoders;

// Needed by ScalaDoc. See SPARK-7726
import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;

/**
* A request to upload a block, which the destination should receive as a stream.
*
* The actual block data is not contained here. It will be passed to the StreamCallbackWithID
* that is returned from RpcHandler.receiveStream().
*/
public class UploadBlockStream extends BlockTransferMessage {
public final String blockId;
public final byte[] metadata;

public UploadBlockStream(String blockId, byte[] metadata) {
this.blockId = blockId;
this.metadata = metadata;
}

@Override
protected Type type() { return Type.UPLOAD_BLOCK_STREAM; }

@Override
public int hashCode() {
int objectsHashCode = Objects.hashCode(blockId);
return objectsHashCode * 41 + Arrays.hashCode(metadata);
}

@Override
public String toString() {
return Objects.toStringHelper(this)
.add("blockId", blockId)
.add("metadata size", metadata.length)
.toString();
}

@Override
public boolean equals(Object other) {
if (other != null && other instanceof UploadBlockStream) {
UploadBlockStream o = (UploadBlockStream) other;
return Objects.equal(blockId, o.blockId)
&& Arrays.equals(metadata, o.metadata);
}
return false;
}

@Override
public int encodedLength() {
return Encoders.Strings.encodedLength(blockId)
+ Encoders.ByteArrays.encodedLength(metadata);
}

@Override
public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, blockId);
Encoders.ByteArrays.encode(buf, metadata);
}

public static UploadBlockStream decode(ByteBuf buf) {
String blockId = Encoders.Strings.decode(buf);
byte[] metadata = Encoders.ByteArrays.decode(buf);
return new UploadBlockStream(blockId, metadata);
}
}
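Since every BlockTransferMessage is framed by the type byte shown above, the new message round-trips through the shared decoder. A minimal sketch (block id and metadata bytes are arbitrary sample values):

import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, UploadBlockStream}

object UploadBlockStreamRoundTrip {
  def main(args: Array[String]): Unit = {
    val original = new UploadBlockStream("rdd_0_0", Array[Byte](1, 2, 3))
    // toByteBuffer prepends Type.UPLOAD_BLOCK_STREAM's id (6) to the encoded body.
    val wire = original.toByteBuffer
    // fromByteBuffer reads that byte back and dispatches to UploadBlockStream.decode.
    val decoded = BlockTransferMessage.Decoder.fromByteBuffer(wire)
    // equals() compares the blockId and the metadata contents.
    assert(decoded == original)
  }
}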
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -363,14 +363,14 @@ private[spark] class Executor(
threadMXBean.getCurrentThreadCpuTime
} else 0L
var threwException = true
- val value = try {
+ val value = Utils.tryWithSafeFinally {
val res = task.run(
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem)
threwException = false
res
- } finally {
+ } {
val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()

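For context, Utils.tryWithSafeFinally differs from a plain try/finally in how it handles a second failure: if both the task body and the cleanup block throw, the cleanup exception is attached as suppressed instead of masking the task's original error. A rough sketch of that semantics (not the actual Utils implementation):

def tryWithSafeFinallySketch[T](block: => T)(finallyBlock: => Unit): T = {
  var originalThrowable: Throwable = null
  try {
    block
  } catch {
    case t: Throwable =>
      originalThrowable = t  // remember the body's failure
      throw t
  } finally {
    try {
      finallyBlock
    } catch {
      // If the body already failed, keep its exception primary and attach
      // the cleanup failure as suppressed. With no prior failure, the
      // guard fails and the cleanup exception propagates normally.
      case t: Throwable if originalThrowable != null =>
        originalThrowable.addSuppressed(t)
    }
  }
}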
@@ -568,6 +568,13 @@ package object config {
.checkValue(v => v > 0, "The value should be a positive integer.")
.createWithDefault(2000)

private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS =
ConfigBuilder("spark.storage.memoryMapLimitForTests")
.internal()
.doc("For testing only, controls the size of chunks when memory mapping a file")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(Int.MaxValue)

private[spark] val BARRIER_SYNC_TIMEOUT =
ConfigBuilder("spark.barrier.sync.timeout")
.doc("The timeout in seconds for each barrier() call from a barrier task. If the " +
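The new spark.storage.memoryMapLimitForTests entry above feeds ChunkedByteBuffer.map in BlockManager.putBlockDataAsStream (further down in this diff); a test can shrink it to exercise multi-chunk mappings on small files. A sketch of assumed test usage:

import org.apache.spark.SparkConf

// Force memory-mapping in 100-byte chunks so even a tiny temp file is
// split across several MappedByteBuffers (test-only knob).
val conf = new SparkConf().set("spark.storage.memoryMapLimitForTests", "100")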
@@ -20,6 +20,7 @@ package org.apache.spark.network
import scala.reflect.ClassTag

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.StreamCallbackWithID
import org.apache.spark.storage.{BlockId, StorageLevel}

private[spark]
@@ -43,6 +44,17 @@ trait BlockDataManager {
level: StorageLevel,
classTag: ClassTag[_]): Boolean

/**
* Put the given block that will be received as a stream.
*
* When this method is called, the block data itself is not available -- it will be passed to the
* returned StreamCallbackWithID.
*/
def putBlockDataAsStream(
blockId: BlockId,
level: StorageLevel,
classTag: ClassTag[_]): StreamCallbackWithID
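To make the contract concrete, here is a minimal sketch of a callback an implementation could return, buffering chunks in memory (names hypothetical; the real BlockManager implementation below spools to a temp file instead):

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer
import org.apache.spark.network.client.StreamCallbackWithID

class InMemoryStreamCallback(blockId: String, store: Array[Byte] => Unit)
    extends StreamCallbackWithID {
  private val bytes = new ByteArrayOutputStream()

  override def getID: String = blockId

  // Invoked repeatedly on the netty thread as chunks of the stream arrive.
  override def onData(streamId: String, buf: ByteBuffer): Unit = {
    val chunk = new Array[Byte](buf.remaining())
    buf.get(chunk)
    bytes.write(chunk)
  }

  // Invoked once the sender signals end-of-stream.
  override def onComplete(streamId: String): Unit = store(bytes.toByteArray)

  // Invoked on error; discard whatever partial data was buffered.
  override def onFailure(streamId: String, cause: Throwable): Unit = bytes.reset()
}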

/**
* Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
*/
@@ -26,9 +26,9 @@ import scala.reflect.ClassTag
import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.NioManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
-import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
+import org.apache.spark.network.shuffle.protocol._
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, StorageLevel}

@@ -73,10 +73,32 @@ class NettyBlockRpcServer(
}
val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
val blockId = BlockId(uploadBlock.blockId)
logDebug(s"Receiving replicated block $blockId with level ${level} " +
s"from ${client.getSocketAddress}")
blockManager.putBlockData(blockId, data, level, classTag)
responseContext.onSuccess(ByteBuffer.allocate(0))
}
}

override def receiveStream(
client: TransportClient,
messageHeader: ByteBuffer,
responseContext: RpcResponseCallback): StreamCallbackWithID = {
val message =
BlockTransferMessage.Decoder.fromByteBuffer(messageHeader).asInstanceOf[UploadBlockStream]
val (level: StorageLevel, classTag: ClassTag[_]) = {
serializer
.newInstance()
.deserialize(ByteBuffer.wrap(message.metadata))
.asInstanceOf[(StorageLevel, ClassTag[_])]
}
val blockId = BlockId(message.blockId)
logDebug(s"Receiving replicated block $blockId with level ${level} as stream " +
s"from ${client.getSocketAddress}")
// This will return immediately, but will set up a callback on streamData which will still
// do all the processing in the netty thread.
blockManager.putBlockDataAsStream(blockId, level, classTag)
}

override def getStreamManager(): StreamManager = streamManager
}
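The metadata header decoded above is simply the (StorageLevel, ClassTag) pair written by the sender with the same Serializer (see the NettyBlockTransferService change below). A sketch of the round trip, with example values:

import java.nio.ByteBuffer
import scala.reflect.classTag
import org.apache.spark.SparkConf
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.StorageLevel

val ser = new JavaSerializer(new SparkConf()).newInstance()
// Sender side: serialize the pair into the UploadBlockStream metadata bytes.
val metadata: Array[Byte] = JavaUtils.bufferToArray(
  ser.serialize((StorageLevel.MEMORY_AND_DISK, classTag[Array[Byte]])))
// Receiver side: recover the pair exactly as receiveStream() does.
val (level, tag) = ser.deserialize[(StorageLevel, scala.reflect.ClassTag[_])](
  ByteBuffer.wrap(metadata))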
@@ -27,13 +27,14 @@ import scala.reflect.ClassTag
import com.codahale.metrics.{Metric, MetricSet}

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.config
import org.apache.spark.network._
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager}
-import org.apache.spark.network.shuffle.protocol.UploadBlock
+import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -148,20 +149,28 @@ private[spark] class NettyBlockTransferService(
// Everything else is encoded using our binary protocol.
val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))

-    // Convert or copy nio buffer into array in order to serialize it.
-    val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())
-
-    client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer,
-      new RpcResponseCallback {
-        override def onSuccess(response: ByteBuffer): Unit = {
-          logTrace(s"Successfully uploaded block $blockId")
-          result.success((): Unit)
-        }
-        override def onFailure(e: Throwable): Unit = {
-          logError(s"Error while uploading block $blockId", e)
-          result.failure(e)
-        }
-      })
+    val asStream = blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
+    val callback = new RpcResponseCallback {
+      override def onSuccess(response: ByteBuffer): Unit = {
+        logTrace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}")
+        result.success((): Unit)
+      }
+      override def onFailure(e: Throwable): Unit = {
+        logError(s"Error while uploading $blockId${if (asStream) " as stream" else ""}", e)
+        result.failure(e)
+      }
+    }
+    if (asStream) {
+      val streamHeader = new UploadBlockStream(blockId.name, metadata).toByteBuffer
+      client.uploadStream(new NioManagedBuffer(streamHeader), blockData, callback)
+    } else {
+      // Convert or copy nio buffer into array in order to serialize it.
+      val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())
+
+      client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer,
+        callback)
+    }

result.future
}
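Which path a replica upload takes is controlled by the same threshold used for fetch-to-memory decisions rather than a new setting. A hypothetical tuning sketch (the key string is assumed to match config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM at the time of this change):

import org.apache.spark.SparkConf

// Blocks larger than this are replicated with client.uploadStream;
// smaller ones keep the old single-buffer UploadBlock RPC.
val conf = new SparkConf().set("spark.maxRemoteBlockSizeFetchToMem", "200m")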
66 changes: 64 additions & 2 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -41,6 +41,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.metrics.source.Source
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.StreamCallbackWithID
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager}
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
@@ -406,6 +407,63 @@ private[spark] class BlockManager(
putBytes(blockId, new ChunkedByteBuffer(data.nioByteBuffer()), level)(classTag)
}

override def putBlockDataAsStream(
blockId: BlockId,
level: StorageLevel,
classTag: ClassTag[_]): StreamCallbackWithID = {
// TODO if we're going to only put the data in the disk store, we should just write it directly
// to the final location, but that would require a deeper refactor of this code. So instead
// we just write to a temp file, and call putBytes on the data in that file.
val tmpFile = diskBlockManager.createTempLocalBlock()._2
val channel = new CountingWritableChannel(
Channels.newChannel(serializerManager.wrapForEncryption(new FileOutputStream(tmpFile))))
logTrace(s"Streaming block $blockId to tmp file $tmpFile")
new StreamCallbackWithID {

override def getID: String = blockId.name

override def onData(streamId: String, buf: ByteBuffer): Unit = {
while (buf.hasRemaining) {
channel.write(buf)
}
}

override def onComplete(streamId: String): Unit = {
logTrace(s"Done receiving block $blockId, now putting into local blockManager")
// Read the contents of the downloaded file as a buffer to put into the blockManager.
// Note this is all happening inside the netty thread as soon as it reads the end of the
// stream.
channel.close()
// TODO SPARK-25035 Even if we're only going to write the data to disk after this, we end up
// using a lot of memory here. With encryption, we'll read the whole file into a regular
// byte buffer and OOM. Without encryption, we'll memory map the file and won't get a jvm
// OOM, but might get killed by the OS / cluster manager. We could at least read the tmp
// file as a stream in both cases.
val buffer = securityManager.getIOEncryptionKey() match {
case Some(key) =>
// we need to pass in the size of the unencrypted block
val blockSize = channel.getCount
val allocator = level.memoryMode match {
case MemoryMode.ON_HEAP => ByteBuffer.allocate _
case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _
}
new EncryptedBlockData(tmpFile, blockSize, conf, key).toChunkedByteBuffer(allocator)
[Review comment, Contributor]: toChunkedByteBuffer is also pretty memory-hungry,
right? You'll end up needing enough memory to hold the entire file in memory, if I
read the code right. This is probably ok for now, but should probably mention it in
your TODO above.

[Reply, Contributor Author]: Yeah, you store the entire file in memory (after
decrypting). It's not memory-mapped either, so it'll probably be a regular OOM
(depending on memory mode). Updated the comment.

case None =>
ChunkedByteBuffer.map(tmpFile, conf.get(config.MEMORY_MAP_LIMIT_FOR_TESTS).toInt)
}
putBytes(blockId, buffer, level)(classTag)
tmpFile.delete()
}

override def onFailure(streamId: String, cause: Throwable): Unit = {
// The framework handles the connection itself; we just need to do local cleanup.
channel.close()
tmpFile.delete()
}
}
}

/**
* Get the BlockStatus for the block identified by the given ID, if it exists.
* NOTE: This is mainly for testing.
@@ -667,7 +725,7 @@ private[spark] class BlockManager(
// TODO if we change this method to return the ManagedBuffer, then getRemoteValues
// could just use the inputStream on the temp file, rather than memory-mapping the file.
// Until then, replication can cause the process to use too much memory and get killed
-    // by the OS / cluster manager (not a java OOM, since its a memory-mapped file) even though
+    // by the OS / cluster manager (not a java OOM, since it's a memory-mapped file) even though
// we've read the data to disk.
logDebug(s"Getting remote block $blockId")
require(blockId != null, "BlockId is null")
@@ -1358,12 +1416,16 @@ private[spark] class BlockManager(
try {
val onePeerStartTime = System.nanoTime
logTrace(s"Trying to replicate $blockId of ${data.size} bytes to $peer")
// This thread keeps a lock on the block, so we do not want the netty thread to unlock
// block when it finishes sending the message.
val buffer = new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false,
unlockOnDeallocate = false)
blockTransferService.uploadBlockSync(
peer.host,
peer.port,
peer.executorId,
blockId,
-      new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false),
+      buffer,
tLevel,
classTag)
logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" +
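For context, uploadBlockSync (not part of this diff) is a blocking wrapper over the async uploadBlock shown earlier, which is why the replicating thread still holds the block lock when the netty thread runs release(). Roughly, as a sketch assumed from its usage here:

import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.ThreadUtils

// Inside BlockTransferService (sketch): block the replicating thread until
// the async upload finishes, so the read lock it holds stays valid for the
// whole transfer and release() on the netty thread never needs to unlock.
def uploadBlockSync(
    hostname: String,
    port: Int,
    execId: String,
    blockId: BlockId,
    blockData: ManagedBuffer,
    level: StorageLevel,
    classTag: ClassTag[_]): Unit = {
  val future = uploadBlock(hostname, port, execId, blockId, blockData, level, classTag)
  ThreadUtils.awaitResult(future, Duration.Inf)
}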
@@ -38,7 +38,8 @@ private[storage] class BlockManagerManagedBuffer(
blockInfoManager: BlockInfoManager,
blockId: BlockId,
data: BlockData,
-    dispose: Boolean) extends ManagedBuffer {
+    dispose: Boolean,
+    unlockOnDeallocate: Boolean = true) extends ManagedBuffer {

private val refCount = new AtomicInteger(1)

@@ -58,7 +59,9 @@ override def release(): ManagedBuffer = {
}

override def release(): ManagedBuffer = {
-    blockInfoManager.unlock(blockId)
+    if (unlockOnDeallocate) {
+      blockInfoManager.unlock(blockId)
+    }
if (refCount.decrementAndGet() == 0 && dispose) {
data.dispose()
}