diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
index 0e2355646465..9363efc58d7c 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java
@@ -30,7 +30,6 @@
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Iterator;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentMap;
@@ -445,9 +444,9 @@ static class PushBlockStreamCallback implements StreamCallbackWithID {
     private final AppShufflePartitionInfo partitionInfo;
     private int length = 0;
     // This indicates that this stream got the opportunity to write the blocks to the merged file.
-    // Once this is set to true and the stream encounters a failure then it will take necessary
-    // action to overwrite any partial written data. This is reset to false when the stream
-    // completes without any failures.
+    // Once this is set to true and the stream encounters a failure then it will unset the
+    // currentMapIndex of the partition so that another stream can start merging the blocks to
+    // the partition. This is reset to false when the stream completes.
     private boolean isWriting = false;
     // Use on-heap instead of direct ByteBuffer since these buffers will be GC'ed very quickly
     private List<ByteBuffer> deferredBufs;
@@ -477,16 +476,11 @@ public String getID() {
      */
     private void writeBuf(ByteBuffer buf) throws IOException {
       while (buf.hasRemaining()) {
-        if (partitionInfo.isEncounteredFailure()) {
-          long updatedPos = partitionInfo.getDataFilePos() + length;
-          logger.debug(
-            "{} shuffleId {} reduceId {} encountered failure current pos {} updated pos {}",
-            partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
-            partitionInfo.reduceId, partitionInfo.getDataFilePos(), updatedPos);
-          length += partitionInfo.dataChannel.write(buf, updatedPos);
-        } else {
-          length += partitionInfo.dataChannel.write(buf);
-        }
+        long updatedPos = partitionInfo.getDataFilePos() + length;
+        logger.debug("{} shuffleId {} reduceId {} current pos {} updated pos {}",
+          partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
+          partitionInfo.reduceId, partitionInfo.getDataFilePos(), updatedPos);
+        length += partitionInfo.dataChannel.write(buf, updatedPos);
       }
     }
 
@@ -581,7 +575,6 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
       }
       // Check whether we can write to disk
       if (allowedToWrite()) {
-        isWriting = true;
         // Identify duplicate block generated by speculative tasks. We respond success to
         // the client in cases of duplicate even though no data is written.
         if (isDuplicateBlock()) {
@@ -598,6 +591,7 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
 
         // If we got here, it's safe to write the block data to the merged shuffle file. We
         // first write any deferred block.
+        isWriting = true;
         try {
           if (deferredBufs != null && !deferredBufs.isEmpty()) {
             writeDeferredBufs();
@@ -609,16 +603,6 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
           // back to the client so the block could be retried.
          throw ioe;
        }
-        // If we got here, it means we successfully write the current chunk of block to merged
-        // shuffle file. If we encountered failure while writing the previous block, we should
-        // reset the file channel position and the status of partitionInfo to indicate that we
-        // have recovered from previous disk write failure. However, we do not update the
-        // position tracked by partitionInfo here. That is only updated while the entire block
-        // is successfully written to merged shuffle file.
-        if (partitionInfo.isEncounteredFailure()) {
-          partitionInfo.dataChannel.position(partitionInfo.getDataFilePos() + length);
-          partitionInfo.setEncounteredFailure(false);
-        }
       } else {
         logger.trace("{} shuffleId {} reduceId {} onData deferred",
           partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
@@ -639,7 +623,7 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
         // written to disk due to this reason. We thus decide to optimize for server
         // throughput and memory usage.
         if (deferredBufs == null) {
-          deferredBufs = new LinkedList<>();
+          deferredBufs = new ArrayList<>();
         }
         // Write the buffer to the in-memory deferred cache. Since buf is a slice of a larger
         // byte buffer, we cache only the relevant bytes not the entire large buffer to save
@@ -670,7 +654,6 @@ public void onComplete(String streamId) throws IOException {
       }
       // Check if we can commit this block
       if (allowedToWrite()) {
-        isWriting = true;
         // Identify duplicate block generated by speculative tasks. We respond success to
         // the client in cases of duplicate even though no data is written.
         if (isDuplicateBlock()) {
@@ -681,6 +664,7 @@ public void onComplete(String streamId) throws IOException {
         try {
           if (deferredBufs != null && !deferredBufs.isEmpty()) {
             abortIfNecessary();
+            isWriting = true;
             writeDeferredBufs();
           }
         } catch (IOException ioe) {
@@ -738,14 +722,14 @@ public void onFailure(String streamId, Throwable throwable) throws IOException {
           Map<Integer, AppShufflePartitionInfo> shufflePartitions =
             mergeManager.partitions.get(partitionInfo.appShuffleId);
           if (shufflePartitions != null && shufflePartitions.containsKey(partitionInfo.reduceId)) {
-            logger.debug("{} shuffleId {} reduceId {} set encountered failure",
+            logger.debug("{} shuffleId {} reduceId {} encountered failure",
               partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
               partitionInfo.reduceId);
             partitionInfo.setCurrentMapIndex(-1);
-            partitionInfo.setEncounteredFailure(true);
           }
         }
       }
+      isWriting = false;
     }
 
     @VisibleForTesting
@@ -802,8 +786,6 @@ public static class AppShufflePartitionInfo {
     public FileChannel dataChannel;
     // Location offset of the last successfully merged block for this shuffle partition
     private long dataFilePos;
-    // Indicating whether failure was encountered when merging the previous block
-    private boolean encounteredFailure;
     // Track the map index whose block is being merged for this shuffle partition
     private int currentMapIndex;
     // Bitmap tracking which mapper's blocks have been merged for this shuffle partition
@@ -836,7 +818,6 @@ public static class AppShufflePartitionInfo {
       // Writing 0 offset so that we can reuse ShuffleIndexInformation.getIndex()
       updateChunkInfo(0L, -1);
       this.dataFilePos = 0;
-      this.encounteredFailure = false;
       this.mapTracker = new RoaringBitmap();
       this.chunkTracker = new RoaringBitmap();
     }
@@ -851,14 +832,6 @@ public void setDataFilePos(long dataFilePos) {
       this.dataFilePos = dataFilePos;
     }
 
-    boolean isEncounteredFailure() {
-      return encounteredFailure;
-    }
-
-    void setEncounteredFailure(boolean encounteredFailure) {
-      this.encounteredFailure = encounteredFailure;
-    }
-
     int getCurrentMapIndex() {
       return currentMapIndex;
     }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
index 8c6f7434748e..565d433ff320 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
@@ -28,6 +28,7 @@
 import java.nio.file.Paths;
 import java.util.Arrays;
 import java.util.concurrent.Semaphore;
+import java.util.concurrent.ThreadLocalRandom;
 
 import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableMap;
@@ -292,18 +293,32 @@ public void testTooLateArrival() throws IOException {
   @Test
   public void testIncompleteStreamsAreOverwritten() throws IOException {
     registerExecutor(TEST_APP, prepareLocalDirs(localDirs));
+    byte[] expectedBytes = new byte[4];
+    ThreadLocalRandom.current().nextBytes(expectedBytes);
+
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0));
-    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4]));
+    byte[] data = new byte[10];
+    ThreadLocalRandom.current().nextBytes(data);
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(data));
     // There is a failure
     stream1.onFailure(stream1.getID(), new RuntimeException("forced error"));
     StreamCallbackWithID stream2 =
       pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0));
-    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
+    ByteBuffer nextBuf = ByteBuffer.wrap(expectedBytes, 0, 2);
+    stream2.onData(stream2.getID(), nextBuf);
     stream2.onComplete(stream2.getID());
+    StreamCallbackWithID stream3 =
+      pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 2, 0, 0));
+    nextBuf = ByteBuffer.wrap(expectedBytes, 2, 2);
+    stream3.onData(stream3.getID(), nextBuf);
+    stream3.onComplete(stream3.getID());
     pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0));
     MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
-    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{5}, new int[][]{{1}});
+    validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{1, 2}});
+    FileSegmentManagedBuffer mb =
+      (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(TEST_APP, 0, 0, 0);
+    assertArrayEquals(expectedBytes, mb.nioByteBuffer().array());
   }
 
   @Test (expected = RuntimeException.class)
@@ -740,6 +755,72 @@ public void testFailureWhileTruncatingFiles() throws IOException {
     validateChunks(TEST_APP, 0, 1, meta, new int[]{5, 3}, new int[][]{{0},{1}});
   }
 
+  @Test
+  public void testOnFailureInvokedMoreThanOncePerBlock() throws IOException {
+    StreamCallbackWithID stream1 =
+      pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    stream1.onFailure(stream1.getID(), new RuntimeException("forced error"));
+    StreamCallbackWithID stream2 =
+      pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
+    // onFailure on stream1 gets invoked again and should cause no interference
+    stream1.onFailure(stream1.getID(), new RuntimeException("2nd forced error"));
+    StreamCallbackWithID stream3 =
+      pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 3, 0, 0));
+    // This should be deferred as stream 2 is still the active stream
+    stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2]));
+    // Stream 2 writes more and completes
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4]));
+    stream2.onComplete(stream2.getID());
+    stream3.onComplete(stream3.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
+    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {9, 2}, new int[][] {{1},{3}});
+    removeApplication(TEST_APP);
+  }
+
+  @Test (expected = RuntimeException.class)
+  public void testFailureAfterDuplicateBlockDoesNotInterfereActiveStream() throws IOException {
+    StreamCallbackWithID stream1 =
+      pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0));
+    StreamCallbackWithID stream1Duplicate =
+      pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0));
+    stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+    stream1.onComplete(stream1.getID());
+    stream1Duplicate.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
+
+    StreamCallbackWithID stream2 =
+      pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0));
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
+    // Should not change the current map id of the reduce partition
+    stream1Duplicate.onFailure(stream2.getID(), new RuntimeException("forced error"));
+
+    StreamCallbackWithID stream3 =
+      pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 2, 0, 0));
+    // This should be deferred as stream 2 is still the active stream
+    stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2]));
+    RuntimeException failedEx = null;
+    try {
+      stream3.onComplete(stream3.getID());
+    } catch (RuntimeException re) {
+      assertEquals(
+        "Couldn't find an opportunity to write block shufflePush_0_2_0 to merged shuffle",
+        re.getMessage());
+      failedEx = re;
+    }
+    // Stream 2 writes more and completes
+    stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4]));
+    stream2.onComplete(stream2.getID());
+    pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0));
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0);
+    validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {11}, new int[][] {{0, 1}});
+    removeApplication(TEST_APP);
+    if (failedEx != null) {
+      throw failedEx;
+    }
+  }
+
   private void useTestFiles(boolean useTestIndexFile, boolean useTestMetaFile) throws IOException {
     pushResolver = new RemoteBlockPushResolver(conf) {
       @Override