OneForOneStreamManager.java

@@ -20,14 +20,18 @@
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;

+import io.netty.channel.Channel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

 import org.apache.spark.network.buffer.ManagedBuffer;

+import com.google.common.base.Preconditions;

 /**
  * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually
  * fetched as chunks by the client. Each registered buffer is one chunk.
@@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager {
   private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class);

   private final AtomicLong nextStreamId;
-  private final Map<Long, StreamState> streams;
+  private final ConcurrentHashMap<Long, StreamState> streams;

   /** State of a single stream. */
   private static class StreamState {
     final Iterator<ManagedBuffer> buffers;

+    // The channel associated to the stream
+    Channel associatedChannel = null;
+
     // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
     // that the caller only requests each chunk one at a time, in order.
     int curChunk = 0;

     StreamState(Iterator<ManagedBuffer> buffers) {
-      this.buffers = buffers;
+      this.buffers = Preconditions.checkNotNull(buffers);
     }
   }

@@ -58,6 +65,13 @@ public OneForOneStreamManager() {
     streams = new ConcurrentHashMap<Long, StreamState>();
   }

+  @Override
+  public void registerChannel(Channel channel, long streamId) {
Contributor:

I think we can avoid the new field by doing

```java
streams.get(streamId).associatedChannel = channel;
```

here, and

```java
// Close all streams which have been associated with the channel.
Iterator<Map.Entry<Long, StreamState>> streamIterator = streams.entrySet().iterator();
while (streamIterator.hasNext()) {
  StreamState state = streamIterator.next().getValue();
  if (state.associatedChannel == channel) {
    streamIterator.remove();

    // Release all remaining buffers.
    while (state.buffers.hasNext()) {
      state.buffers.next().release();
    }
  }
}
```

This allows the removal of the other connectionTerminated().

+    if (streams.containsKey(streamId)) {
+      streams.get(streamId).associatedChannel = channel;
+    }
+  }

   @Override
   public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     StreamState state = streams.get(streamId);
@@ -80,12 +94,17 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
   }

   @Override
-  public void connectionTerminated(long streamId) {
-    // Release all remaining buffers.
-    StreamState state = streams.remove(streamId);
-    if (state != null && state.buffers != null) {
-      while (state.buffers.hasNext()) {
-        state.buffers.next().release();
+  public void connectionTerminated(Channel channel) {
+    // Close all streams which have been associated with the channel.
+    for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
+      StreamState state = entry.getValue();
+      if (state.associatedChannel == channel) {
+        streams.remove(entry.getKey());
Contributor:

Ah, good point, this is safe because our Map is a ConcurrentHashMap (or else you would need to use an iterator to remove entries safely). Would you mind making the left-hand type of the declaration of streams a ConcurrentHashMap? This is not the first place where we rely on the semantics of a ConcurrentHashMap over a general Map, so we should use the proper type.

Member (Author):

Makes sense. Updated.


+
+        // Release all remaining buffers.
+        while (state.buffers.hasNext()) {
+          state.buffers.next().release();
+        }
+      }
+    }
   }
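
The exchange above relies on ConcurrentHashMap's weakly consistent iterators: entries can be removed through the map itself while iterating over entrySet(), with no ConcurrentModificationException. A minimal standalone sketch of the difference (hypothetical example, not part of this patch):

```java
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class RemoveWhileIterating {
  public static void main(String[] args) {
    // ConcurrentHashMap: removing through the map during iteration is safe.
    Map<Long, String> concurrent = new ConcurrentHashMap<Long, String>();
    concurrent.put(1L, "a");
    concurrent.put(2L, "b");
    for (Map.Entry<Long, String> entry : concurrent.entrySet()) {
      concurrent.remove(entry.getKey());  // no ConcurrentModificationException
    }
    System.out.println("after concurrent removal: " + concurrent);  // {}

    // HashMap: the same pattern is fail-fast and typically throws
    // ConcurrentModificationException; removal must go through the iterator.
    Map<Long, String> plain = new HashMap<Long, String>();
    plain.put(1L, "a");
    plain.put(2L, "b");
    Iterator<Map.Entry<Long, String>> it = plain.entrySet().iterator();
    while (it.hasNext()) {
      it.next();
      it.remove();  // the safe removal path for a non-concurrent map
    }
    System.out.println("after iterator removal: " + plain);  // {}
  }
}
```

This is why the new connectionTerminated() can call streams.remove() inside a for-each over streams.entrySet() without breaking the iteration; with a plain HashMap the removal would have to go through the iterator, as the reviewer notes.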
StreamManager.java

@@ -17,6 +17,8 @@

 package org.apache.spark.network.server;

+import io.netty.channel.Channel;
+
 import org.apache.spark.network.buffer.ManagedBuffer;

 /**
@@ -44,9 +46,18 @@ public abstract class StreamManager {
   public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);

   /**
-   * Indicates that the TCP connection that was tied to the given stream has been terminated. After
-   * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned
-   * up.
+   * Associates a stream with a single client connection, which is guaranteed to be the only reader
+   * of the stream. The getChunk() method will be called serially on this connection and once the
+   * connection is closed, the stream will never be used again, enabling cleanup.
+   *
+   * This must be called before the first getChunk() on the stream, but it may be invoked multiple
+   * times with the same channel and stream id.
+   */
+  public void registerChannel(Channel channel, long streamId) { }
+
+  /**
+   * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not
+   * to read from the associated streams again, so any state can be cleaned up.
    */
-  public void connectionTerminated(long streamId) { }
+  public void connectionTerminated(Channel channel) { }
 }
TransportRequestHandler.java

@@ -17,10 +17,7 @@

 package org.apache.spark.network.server;

-import java.util.Set;
-
 import com.google.common.base.Throwables;
-import com.google.common.collect.Sets;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelFutureListener;
@@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
   /** Returns each chunk part of a stream. */
   private final StreamManager streamManager;

-  /** List of all stream ids that have been read on this handler, used for cleanup. */
-  private final Set<Long> streamIds;
-
   public TransportRequestHandler(
       Channel channel,
       TransportClient reverseClient,
@@ -73,7 +67,6 @@ public TransportRequestHandler(
     this.reverseClient = reverseClient;
     this.rpcHandler = rpcHandler;
     this.streamManager = rpcHandler.getStreamManager();
-    this.streamIds = Sets.newHashSet();
   }

   @Override
@@ -82,10 +75,7 @@ public void exceptionCaught(Throwable cause) {

   @Override
   public void channelUnregistered() {
-    // Inform the StreamManager that these streams will no longer be read from.
-    for (long streamId : streamIds) {
-      streamManager.connectionTerminated(streamId);
-    }
+    streamManager.connectionTerminated(channel);
     rpcHandler.connectionTerminated(reverseClient);
   }

@@ -102,12 +92,12 @@ public void handle(RequestMessage request) {

   private void processFetchRequest(final ChunkFetchRequest req) {
     final String client = NettyUtils.getRemoteAddress(channel);
-    streamIds.add(req.streamChunkId.streamId);

     logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);

     ManagedBuffer buf;
     try {
+      streamManager.registerChannel(channel, req.streamChunkId.streamId);
       buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
     } catch (Exception e) {
       logger.error(String.format(
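
With the handler now delegating cleanup entirely to the channel, the whole lifecycle can be exercised end to end. Below is a test-style sketch under stated assumptions: OneForOneStreamManager exposes its existing registerStream(Iterator<ManagedBuffer>) helper, Netty's EmbeddedChannel stands in for a real connection, and StubBuffer is an invented ManagedBuffer; none of this is part of the patch.

```java
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;

import io.netty.channel.Channel;
import io.netty.channel.embedded.EmbeddedChannel;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.server.OneForOneStreamManager;

public class StreamLifecycleSketch {

  // Minimal ManagedBuffer stub that only tracks release() calls.
  static class StubBuffer extends ManagedBuffer {
    boolean released = false;

    @Override public long size() { return 0; }
    @Override public ByteBuffer nioByteBuffer() { return ByteBuffer.allocate(0); }
    @Override public InputStream createInputStream() { throw new UnsupportedOperationException(); }
    @Override public ManagedBuffer retain() { return this; }
    @Override public ManagedBuffer release() { released = true; return this; }
    @Override public Object convertToNetty() { throw new UnsupportedOperationException(); }
  }

  public static void main(String[] args) {
    OneForOneStreamManager manager = new OneForOneStreamManager();
    StubBuffer chunk0 = new StubBuffer();
    StubBuffer chunk1 = new StubBuffer();

    // Server side: a stream is registered, then bound to the requesting
    // channel on the first fetch (mirrors processFetchRequest above).
    long streamId =
        manager.registerStream(Arrays.<ManagedBuffer>asList(chunk0, chunk1).iterator());
    Channel channel = new EmbeddedChannel();
    manager.registerChannel(channel, streamId);

    // The client fetches only the first chunk, then its connection drops.
    manager.getChunk(streamId, 0);  // chunk0 is handed off to the caller
    manager.connectionTerminated(channel);

    // connectionTerminated releases every chunk the client never fetched.
    System.out.println("chunk1 released: " + chunk1.released);  // true
  }
}
```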