@@ -21,7 +21,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
@@ -81,7 +81,6 @@ public OneForOneBlockFetcher(
TransportConf transportConf,
DownloadFileManager downloadFileManager) {
this.client = client;
this.blockIds = blockIds;
Member

Do we need to remove this from the constructor, line 78 and line 69?

Member

Or can we not change it because it's a public class?

@seayoun (Author), Mar 2, 2021

The blockIds will be used to create OpenBlocks or FetchShuffleBlocks later in the constructor.
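
(Editorial aside, not part of the review thread: a minimal, self-contained sketch of the block-ID shape the constructor inspects when choosing between the two messages. The class and helper names below are hypothetical; they only mirror the spirit of isShuffleBlocks() and the branch shown in the diff below.)

// Sketch only: IDs that all start with "shuffle_" take the FetchShuffleBlocks path;
// anything else (e.g. an "rdd_" block) falls back to OpenBlocks, which is why the
// constructor still needs the raw blockIds array.
public class BlockIdShapeSketch {
  // Hypothetical stand-in mirroring the spirit of isShuffleBlocks() in the diff.
  static boolean looksLikeShuffleBlocks(String[] blockIds) {
    for (String blockId : blockIds) {
      if (!blockId.startsWith("shuffle_")) {
        return false;
      }
    }
    return true;
  }

  public static void main(String[] args) {
    // Plain shuffle block: shuffle_<shuffleId>_<mapId>_<reduceId>
    System.out.println(looksLikeShuffleBlocks(new String[]{"shuffle_0_2_1"}));            // true
    // Mixed with a non-shuffle block ID: forces the OpenBlocks fallback.
    System.out.println(looksLikeShuffleBlocks(new String[]{"rdd_3_4", "shuffle_0_2_1"})); // false
  }
}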

this.listener = listener;
this.chunkCallback = new ChunkCallback();
this.transportConf = transportConf;
@@ -90,8 +89,10 @@ public OneForOneBlockFetcher(
throw new IllegalArgumentException("Zero-sized blockIds array");
}
if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
this.blockIds = new String[blockIds.length];
this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
} else {
this.blockIds = blockIds;
this.message = new OpenBlocks(appId, execId, blockIds);
}
}
@@ -106,41 +107,53 @@ private boolean isShuffleBlocks(String[] blockIds) {
}

/**
* Analyze the pass in blockIds and create FetchShuffleBlocks message.
* The blockIds has been sorted by mapId and reduceId. It's produced in
* org.apache.spark.MapOutputTracker.convertMapStatuses.
* Create FetchShuffleBlocks message and rebuild internal blockIds by
* analyzing the passed-in blockIds.
*/
private FetchShuffleBlocks createFetchShuffleBlocksMsg(
private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
String appId, String execId, String[] blockIds) {
String[] firstBlock = splitBlockId(blockIds[0]);
int shuffleId = Integer.parseInt(firstBlock[1]);
boolean batchFetchEnabled = firstBlock.length == 5;

HashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
for (String blockId : blockIds) {
String[] blockIdParts = splitBlockId(blockId);
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockId);
}
long mapId = Long.parseLong(blockIdParts[2]);
if (!mapIdToReduceIds.containsKey(mapId)) {
mapIdToReduceIds.put(mapId, new ArrayList<>());
if (!mapIdToBlocksInfo.containsKey(mapId)) {
mapIdToBlocksInfo.put(mapId, new BlocksInfo());
}
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3]));
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
blocksInfoByMapId.blockIds.add(blockId);
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
if (batchFetchEnabled) {
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
// FetchShuffleBlocks to store the start and end reduce id for range
// [startReduceId, endReduceId).
assert(blockIdParts.length == 5);
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4]));
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
}
}
long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
int[][] reduceIdArr = new int[mapIds.length][];
int blockIdIndex = 0;
for (int i = 0; i < mapIds.length; i++) {
reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);

// The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks
// because the shuffle data's return order should match the `blockIds`'s order to ensure
// blockId and data match.
for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
}
}
assert(blockIdIndex == this.blockIds.length);

return new FetchShuffleBlocks(
appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
}
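
(Editorial aside: a self-contained sketch, not part of the patch, of the grouping the new method performs. Block IDs of the form shuffle_<shuffleId>_<mapId>_<reduceId> are grouped by mapId into the arrays carried by FetchShuffleBlocks, and insertion order is preserved so returned chunks can later be matched back to block IDs. The class and variable names here are illustrative only, and plain JDK collections stand in for Guava.)

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.Map;

// Illustrative only: groups already-sorted shuffle block IDs by mapId while keeping
// the original request order, the same property the patch relies on.
public class FetchShuffleBlocksGroupingSketch {
  public static void main(String[] args) {
    // Sorted by mapId and reduceId, as produced upstream.
    String[] blockIds = {"shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_2_0", "shuffle_0_10_3"};

    // LinkedHashMap keeps insertion order, so iteration follows the request order.
    LinkedHashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new LinkedHashMap<>();
    ArrayList<String> orderedBlockIds = new ArrayList<>();
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_"); // ["shuffle", shuffleId, mapId, reduceId]
      long mapId = Long.parseLong(parts[2]);
      mapIdToReduceIds.computeIfAbsent(mapId, k -> new ArrayList<>())
          .add(Integer.parseInt(parts[3]));
      orderedBlockIds.add(blockId);
    }

    // Prints: mapId=0 reduceIds=[0, 1], then mapId=2 reduceIds=[0], then mapId=10 reduceIds=[3]
    for (Map.Entry<Long, ArrayList<Integer>> entry : mapIdToReduceIds.entrySet()) {
      System.out.println("mapId=" + entry.getKey() + " reduceIds=" + entry.getValue());
    }
    // The fetch order that returned chunks must be matched against.
    System.out.println("blockIds order: " + orderedBlockIds);
  }
}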
@@ -157,6 +170,18 @@ private String[] splitBlockId(String blockId) {
return blockIdParts;
}

/** The reduceIds and blocks in a single mapId */
private class BlocksInfo {

final ArrayList<Integer> reduceIds;
final ArrayList<String> blockIds;

BlocksInfo() {
this.reduceIds = new ArrayList<>();
this.blockIds = new ArrayList<>();
}
}

/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
private class ChunkCallback implements ChunkReceivedCallback {
@Override
@@ -201,6 +201,48 @@ public void testEmptyBlockFetch() {
}
}

@Test
public void testFetchShuffleBlocksOrder() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
blocks.put("shuffle_0_2_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[2])));
blocks.put("shuffle_0_10_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[3])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

BlockFetchingListener listener = fetchBlocks(
blocks,
blockIds,
new FetchShuffleBlocks("app-id", "exec-id", 0,
new long[]{0, 2, 10}, new int[][]{{0}, {1}, {2}}, false),
conf);

for (int chunkIndex = 0; chunkIndex < blockIds.length; chunkIndex++) {
String blockId = blockIds[chunkIndex];
verify(listener).onBlockFetchSuccess(blockId, blocks.get(blockId));
}
}

@Test
public void testBatchFetchShuffleBlocksOrder() {
LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
blocks.put("shuffle_0_0_1_2", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
blocks.put("shuffle_0_2_2_3", new NioManagedBuffer(ByteBuffer.wrap(new byte[2])));
blocks.put("shuffle_0_10_3_4", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[3])));
String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);

BlockFetchingListener listener = fetchBlocks(
blocks,
blockIds,
new FetchShuffleBlocks("app-id", "exec-id", 0,
new long[]{0, 2, 10}, new int[][]{{1, 2}, {2, 3}, {3, 4}}, true),
conf);

for (int chunkIndex = 0; chunkIndex < blockIds.length; chunkIndex++) {
String blockId = blockIds[chunkIndex];
verify(listener).onBlockFetchSuccess(blockId, blocks.get(blockId));
}
}

/**
* Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
* simply returns the given (BlockId, Block) pairs.