Skip to content
Prev Previous commit
Next Next commit
replace a single map instead of two in createFetchShuffleBlocksMsgAnd…
…BuildBlockIds
  • Loading branch information
yuhaiyang authored and yuhaiyang committed Mar 2, 2021
commit 575ca5e8d5ce5e96f8661f0ef702d820a3c2e683
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.HashMap;

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
Expand Down Expand Up @@ -116,51 +116,40 @@ private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
int shuffleId = Integer.parseInt(firstBlock[1]);
boolean batchFetchEnabled = firstBlock.length == 5;

LinkedHashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new LinkedHashMap<>();
LinkedHashMap<Long, ArrayList<String>> mapIdToBlockIds = new LinkedHashMap<>();
HashMap<Long, BlocksInfo> mapIdToBlocksInfo = new HashMap<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we use a LinkedHashMap to eliminate the randomness of the map id array? It's not a hard requirement, but it seems better if the block ids are the same with and without the new shuffle protocol.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to be better.

for (String blockId : blockIds) {
String[] blockIdParts = splitBlockId(blockId);
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockId);
}
long mapId = Long.parseLong(blockIdParts[2]);
if (!mapIdToReduceIds.containsKey(mapId)) {
mapIdToReduceIds.put(mapId, new ArrayList<>());
if (!mapIdToBlocksInfo.containsKey(mapId)) {
mapIdToBlocksInfo.put(mapId, new BlocksInfo(new ArrayList<>(), new ArrayList<>()));
}
if (!mapIdToBlockIds.containsKey(mapId)) {
mapIdToBlockIds.put(mapId, new ArrayList<>());
}
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3]));
mapIdToBlockIds.get(mapId).add(blockId);
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
blocksInfoByMapId.blockIds.add(blockId);
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
if (batchFetchEnabled) {
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
// FetchShuffleBlocks to store the start and end reduce id for range
// [startReduceId, endReduceId).
assert(blockIdParts.length == 5);
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4]));
blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
}
}
long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
int[][] reduceIdArr = new int[mapIds.length][];
int blockIdIndex = 0;
for (int i = 0; i < mapIds.length; i++) {
reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
}
BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);

// Fill internal `blockIds` by the read order in using `FetchShuffleBlocks`
long[] blockMapIds = Longs.toArray(mapIdToBlockIds.keySet());
assert(mapIds.length == blockMapIds.length);
int blockIdIndex = 0;
for (int i = 0; i < blockMapIds.length; i++) {
// The order of `blockIds` must be the same as the read order specified in FetchShuffleBlocks
// because the shuffle data's return order should match the `blockIds`'s order to ensure
// blockId and data match.
long blockMapId = blockMapIds[i];
// The keys in `mapIdToReduceIds` and `mapIdToBlockIds` should be same order.
assert(blockMapId == mapIds[i]);
ArrayList<String> blockIdsByMapId = mapIdToBlockIds.get(blockMapId);
for (int j = 0; j < blockIdsByMapId.size(); j++) {
this.blockIds[blockIdIndex++] = blockIdsByMapId.get(i);
for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
}
}
assert(blockIdIndex == this.blockIds.length);
Expand All @@ -181,6 +170,18 @@ private String[] splitBlockId(String blockId) {
return blockIdParts;
}

/** The reduceIds and blocks in a single mapId */
private class BlocksInfo {

ArrayList<Integer> reduceIds;
ArrayList<String> blockIds;

public BlocksInfo(ArrayList<Integer> reduceIds, ArrayList<String> blockIds) {
this.reduceIds = reduceIds;
this.blockIds = blockIds;
}
}

/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
private class ChunkCallback implements ChunkReceivedCallback {
@Override
Expand Down