
Commit 06ea01a

Address more comments.

1 parent: 2d29404

6 files changed: +57 -99 lines

core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle.api;
 
 import java.io.IOException;
+
 import org.apache.spark.annotation.Private;
 
 /**

core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java

Lines changed: 7 additions & 9 deletions

@@ -18,11 +18,10 @@
 package org.apache.spark.shuffle.api;
 
 import java.io.IOException;
+import java.util.Optional;
 import java.io.OutputStream;
-import java.nio.channels.Channels;
 
 import org.apache.spark.annotation.Private;
-import org.apache.spark.shuffle.sort.io.DefaultWritableByteChannelWrapper;
 
 /**
  * :: Private ::
@@ -67,25 +66,24 @@ public interface ShufflePartitionWriter {
    * Implementations that intend on combining the bytes for all the partitions written by this
    * map task should reuse the same channel instance across all the partition writers provided
    * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
-   * {@link WritableByteChannelWrapper#close()} does not close the resource, since it
+   * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel
    * will be reused across partition writes. The underlying resources should be cleaned up in
    * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
    * {@link ShuffleMapOutputWriter#abort(Throwable)}.
    * <p>
    * This method is primarily for advanced optimizations where bytes can be copied from the input
-   * spill files to the output channel without copying data into memory.
-   * <p>
-   * The default implementation should be sufficient for most situations. Only override this
-   * method if there is a very specific optimization that needs to be built.
+   * spill files to the output channel without copying data into memory. If such optimizations are
+   * not supported, the implementation should return {@link Optional#empty()}. By default, the
+   * implementation returns {@link Optional#empty()}.
    * <p>
    * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the
    * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure
    * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()},
    * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or
    * {@link ShuffleMapOutputWriter#abort(Throwable)}.
    */
-  default WritableByteChannelWrapper openChannelWrapper() throws IOException {
-    return new DefaultWritableByteChannelWrapper(Channels.newChannel(openStream()));
+  default Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
+    return Optional.empty();
  }
 
  /**
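
The channel-based copy path is now strictly opt-in. As a minimal sketch of an implementation that does support it (the class and field names here are hypothetical; only channel() and close() come from the WritableByteChannelWrapper contract described above), a plugin that concatenates every partition into one shared channel can hand out wrappers whose close() is deliberately a no-op:

import java.nio.channels.WritableByteChannel;

// Hypothetical wrapper around a channel shared by all partition writers of one map task.
final class SharedChannelWrapper implements WritableByteChannelWrapper {
  private final WritableByteChannel sharedChannel;

  SharedChannelWrapper(WritableByteChannel sharedChannel) {
    this.sharedChannel = sharedChannel;
  }

  @Override
  public WritableByteChannel channel() {
    return sharedChannel;
  }

  @Override
  public void close() {
    // Deliberately empty: the channel is reused across partition writes, so the parent
    // ShuffleMapOutputWriter closes it in commitAllPartitions() or abort(Throwable).
  }
}

The matching openChannelWrapper() override would return Optional.of(new SharedChannelWrapper(sharedChannel)) for each partition, with the shared channel owned and eventually closed by the parent ShuffleMapOutputWriter.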

core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java

Lines changed: 43 additions & 29 deletions

@@ -22,6 +22,7 @@
 import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.channels.FileChannel;
+import java.util.Optional;
 import javax.annotation.Nullable;
 
 import scala.None$;
@@ -205,45 +206,23 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
         final File file = partitionWriterSegments[i].file();
         ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
         if (file.exists()) {
-          boolean copyThrewException = true;
           if (transferToEnabled) {
-            FileInputStream in = new FileInputStream(file);
             // Using WritableByteChannelWrapper to make resource closing consistent between
             // this implementation and UnsafeShuffleWriter.
-            try {
-              WritableByteChannelWrapper outputChannel = writer.openChannelWrapper();
-              try (FileChannel inputChannel = in.getChannel()) {
-                Utils.copyFileStreamNIO(
-                  inputChannel, outputChannel.channel(), 0L, inputChannel.size());
-                copyThrewException = false;
-              } finally {
-                Closeables.close(outputChannel, copyThrewException);
-              }
-            } finally {
-              Closeables.close(in, copyThrewException);
+            Optional<WritableByteChannelWrapper> maybeOutputChannel = writer.openChannelWrapper();
+            if (maybeOutputChannel.isPresent()) {
+              writePartitionedDataWithChannel(file, maybeOutputChannel.get());
+            } else {
+              writePartitionedDataWithStream(file, writer);
             }
           } else {
-            FileInputStream in = new FileInputStream(file);
-            OutputStream outputStream;
-            try {
-              outputStream = writer.openStream();
-              try {
-                Utils.copyStream(in, outputStream, false, false);
-                copyThrewException = false;
-              } finally {
-                Closeables.close(outputStream, copyThrewException);
-              }
-            } finally {
-              Closeables.close(in, copyThrewException);
-            }
+            writePartitionedDataWithStream(file, writer);
           }
           if (!file.delete()) {
             logger.error("Unable to delete file for partition {}", i);
           }
         }
-        long numBytesWritten = writer.getNumBytesWritten();
-        lengths[i] = numBytesWritten;
-        writeMetrics.incBytesWritten(numBytesWritten);
+        lengths[i] = writer.getNumBytesWritten();
       }
     } finally {
       writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
@@ -252,6 +231,41 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
     return lengths;
   }
 
+  private void writePartitionedDataWithChannel(
+      File file, WritableByteChannelWrapper outputChannel) throws IOException {
+    boolean copyThrewException = true;
+    try {
+      FileInputStream in = new FileInputStream(file);
+      try (FileChannel inputChannel = in.getChannel()) {
+        Utils.copyFileStreamNIO(
+          inputChannel, outputChannel.channel(), 0L, inputChannel.size());
+        copyThrewException = false;
+      } finally {
+        Closeables.close(in, copyThrewException);
+      }
+    } finally {
+      Closeables.close(outputChannel, copyThrewException);
+    }
+  }
+
+  private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer)
+      throws IOException {
+    boolean copyThrewException = true;
+    FileInputStream in = new FileInputStream(file);
+    OutputStream outputStream;
+    try {
+      outputStream = writer.openStream();
+      try {
+        Utils.copyStream(in, outputStream, false, false);
+        copyThrewException = false;
+      } finally {
+        Closeables.close(outputStream, copyThrewException);
+      }
+    } finally {
+      Closeables.close(in, copyThrewException);
+    }
+  }
+
   @Override
   public Option<MapStatus> stop(boolean success) {
     if (stopping) {
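
Both new helpers lean on Guava's Closeables.close(Closeable, boolean swallowIOException) for their exception semantics. A minimal sketch of the idiom in isolation (the class name and the copyAll placeholder are hypothetical stand-ins for the real stream/channel copy calls):

import java.io.Closeable;
import java.io.IOException;

import com.google.common.io.Closeables;

final class CopyIdiomSketch {
  static void copyAndClose(Closeable source) throws IOException {
    boolean copyThrewException = true;
    try {
      copyAll(source);             // may throw IOException
      copyThrewException = false;  // reached only if the copy succeeded
    } finally {
      // If the copy threw, swallow (and log) any IOException from close() so the
      // original failure propagates; otherwise surface close() failures to the caller.
      Closeables.close(source, copyThrewException);
    }
  }

  private static void copyAll(Closeable source) throws IOException {
    // placeholder for Utils.copyStream / Utils.copyFileStreamNIO
  }
}

Because copyThrewException only flips to false after the copy completes, a failure while closing can never mask the copy failure itself, while close-time failures still surface on the success path.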

core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultWritableByteChannelWrapper.java

Lines changed: 0 additions & 49 deletions
This file was deleted.
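
The deleted class is not shown in this diff, but the removed default implementation above (new DefaultWritableByteChannelWrapper(Channels.newChannel(openStream()))) suggests a thin owning wrapper, roughly along these lines (a sketch, not the original source):

import java.io.IOException;
import java.nio.channels.WritableByteChannel;

// Rough reconstruction: unlike the shared-channel sketch earlier, this per-partition
// wrapper owns its channel, so closing the wrapper closes the underlying channel
// (and the stream it was created from) as well.
final class StreamBackedChannelWrapper implements WritableByteChannelWrapper {
  private final WritableByteChannel channel;

  StreamBackedChannelWrapper(WritableByteChannel channel) {
    this.channel = channel;
  }

  @Override
  public WritableByteChannel channel() {
    return channel;
  }

  @Override
  public void close() throws IOException {
    channel.close();
  }
}

With openChannelWrapper() now defaulting to Optional.empty() and callers falling back to the stream path, a stream-wrapping default no longer has a purpose, hence the deletion.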

core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java

Lines changed: 3 additions & 2 deletions

@@ -25,6 +25,7 @@
 import java.nio.channels.FileChannel;
 import java.nio.channels.WritableByteChannel;
 
+import java.util.Optional;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -164,7 +165,7 @@ public OutputStream openStream() throws IOException {
   }
 
   @Override
-  public WritableByteChannelWrapper openChannelWrapper() throws IOException {
+  public Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
     if (partChannel == null) {
       if (partStream != null) {
         throw new IllegalStateException("Requested an output stream for a previous write but" +
@@ -174,7 +175,7 @@ public WritableByteChannelWrapper openChannelWrapper() throws IOException {
       initChannel();
       partChannel = new PartitionWriterChannel(partitionId);
     }
-    return partChannel;
+    return Optional.of(partChannel);
   }
 
   @Override

core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala

Lines changed: 3 additions & 10 deletions

@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.sort.io
 
-import java.io.{File, FileInputStream, FileOutputStream}
+import java.io.{File, FileInputStream}
 import java.nio.channels.FileChannel
 import java.nio.file.Files
 import java.util.Arrays
@@ -30,7 +30,6 @@ import org.mockito.MockitoAnnotations
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.shuffle.IndexShuffleBlockResolver
 import org.apache.spark.util.Utils
 
@@ -39,9 +38,6 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
   @Mock(answer = RETURNS_SMART_NULLS)
   private var blockResolver: IndexShuffleBlockResolver = _
 
-  @Mock(answer = RETURNS_SMART_NULLS)
-  private var shuffleWriteMetrics: ShuffleWriteMetrics = _
-
   private val NUM_PARTITIONS = 4
   private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p =>
     if (p == 3) {
@@ -93,7 +89,6 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
       0,
       0,
       NUM_PARTITIONS,
-      shuffleWriteMetrics,
       blockResolver,
       conf)
   }
@@ -116,13 +111,11 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
     (0 until NUM_PARTITIONS).foreach { p =>
       val writer = mapOutputWriter.getPartitionWriter(p)
       val outputTempFile = File.createTempFile("channelTemp", "", tempDir)
-      val outputTempFileStream = new FileOutputStream(outputTempFile)
-      outputTempFileStream.write(data(p))
-      outputTempFileStream.close()
+      Files.write(outputTempFile.toPath, data(p))
       val tempFileInput = new FileInputStream(outputTempFile)
       val channel = writer.openChannelWrapper()
       Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput =>
-        Utils.tryWithResource(writer.openChannelWrapper()) { channelWrapper =>
+        Utils.tryWithResource(writer.openChannelWrapper().get) { channelWrapper =>
           assert(channelWrapper.channel().isInstanceOf[FileChannel],
             "Underlying channel should be a file channel")
           Utils.copyFileStreamNIO(
