
Commit b35d238

yifeih authored and bulldozer-bot[bot] committed
[SPARK-25299] shuffle reader API (apache#523)
1 parent 5398f3b commit b35d238

18 files changed, +448 -103 lines changed

core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import org.apache.spark.api.java.Optional;
+
+import java.util.Objects;
+
+/**
+ * :: Experimental ::
+ * An object defining the shuffle block and length metadata associated with the block.
+ * @since 3.0.0
+ */
+public class ShuffleBlockInfo {
+  private final int shuffleId;
+  private final int mapId;
+  private final int reduceId;
+  private final long length;
+  private final Optional<ShuffleLocation> shuffleLocation;
+
+  public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length,
+      Optional<ShuffleLocation> shuffleLocation) {
+    this.shuffleId = shuffleId;
+    this.mapId = mapId;
+    this.reduceId = reduceId;
+    this.length = length;
+    this.shuffleLocation = shuffleLocation;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  public int getMapId() {
+    return mapId;
+  }
+
+  public int getReduceId() {
+    return reduceId;
+  }
+
+  public long getLength() {
+    return length;
+  }
+
+  public Optional<ShuffleLocation> getShuffleLocation() {
+    return shuffleLocation;
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    return other instanceof ShuffleBlockInfo
+        && shuffleId == ((ShuffleBlockInfo) other).shuffleId
+        && mapId == ((ShuffleBlockInfo) other).mapId
+        && reduceId == ((ShuffleBlockInfo) other).reduceId
+        && length == ((ShuffleBlockInfo) other).length
+        && Objects.equals(shuffleLocation, ((ShuffleBlockInfo) other).shuffleLocation);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(shuffleId, mapId, reduceId, length, shuffleLocation);
+  }
+}
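
For illustration only (not part of this commit): two ShuffleBlockInfo values with identical fields compare equal and hash identically, so they can key hash-based collections. The ids and length below are made up.

import org.apache.spark.api.java.Optional
import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleLocation}

// Hypothetical block: shuffle 0, map 3, reduce 7, 1024 bytes, no pluggable location.
val a = new ShuffleBlockInfo(0, 3, 7, 1024L, Optional.empty[ShuffleLocation]())
val b = new ShuffleBlockInfo(0, 3, 7, 1024L, Optional.empty[ShuffleLocation]())
assert(a == b)                     // delegates to the field-wise equals above
assert(a.hashCode == b.hashCode)   // consistent with equals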

core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java

Lines changed: 2 additions & 0 deletions
@@ -30,4 +30,6 @@ public interface ShuffleExecutorComponents {
   void initializeExecutor(String appId, String execId);
 
   ShuffleWriteSupport writes();
+
+  ShuffleReadSupport reads();
 }
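
A hypothetical plugin skeleton (the class name is made up) showing what the extended interface now demands: read support alongside write support. The read side here simply serves no data.

import java.io.InputStream
import java.lang.{Iterable => JIterable}
import java.util.Collections

import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport}

class NoopShuffleExecutorComponents extends ShuffleExecutorComponents {
  override def initializeExecutor(appId: String, execId: String): Unit = {
    // acquire per-executor resources here, e.g. clients for a remote shuffle store
  }

  override def writes(): ShuffleWriteSupport =
    throw new UnsupportedOperationException("write side elided in this sketch")

  override def reads(): ShuffleReadSupport = new ShuffleReadSupport {
    override def getPartitionReaders(
        blockMetadata: JIterable[ShuffleBlockInfo]): JIterable[InputStream] =
      Collections.emptyList[InputStream]()  // placeholder: serves no data
  }
}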

core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java

Lines changed: 1 addition & 2 deletions
@@ -21,5 +21,4 @@
  * Marker interface representing a location of a shuffle block. Implementations of shuffle readers
  * and writers are expected to cast this down to an implementation-specific representation.
  */
-public interface ShuffleLocation {
-}
+public interface ShuffleLocation {}

core/src/main/java/org/apache/spark/api/shuffle/ShuffleReadSupport.java

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import org.apache.spark.annotation.Experimental;
+
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * :: Experimental ::
+ * An interface for reading shuffle records.
+ * @since 3.0.0
+ */
+@Experimental
+public interface ShuffleReadSupport {
+  /**
+   * Returns an underlying {@link Iterable<InputStream>} that will iterate
+   * through shuffle data, given an iterable for the shuffle blocks to fetch.
+   */
+  Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
+      throws IOException;
+}
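
As a sketch of the contract, a hypothetical implementation that resolves each requested block to a local file. The shuffle_<shuffleId>_<mapId>_<reduceId>.data naming is an assumption for illustration, not Spark's on-disk layout; a production implementation would also open streams lazily rather than eagerly as the strict map does here.

import java.io.{File, FileInputStream, InputStream}
import java.lang.{Iterable => JIterable}

import scala.collection.JavaConverters._

import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}

class LocalFileShuffleReadSupport(shuffleDir: File) extends ShuffleReadSupport {
  override def getPartitionReaders(
      blockMetadata: JIterable[ShuffleBlockInfo]): JIterable[InputStream] = {
    // One stream per requested block, resolved purely from the block's ids.
    blockMetadata.asScala.map { block =>
      val f = new File(shuffleDir,
        s"shuffle_${block.getShuffleId}_${block.getMapId}_${block.getReduceId}.data")
      new FileInputStream(f): InputStream
    }.asJava
  }
}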

core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java

Lines changed: 23 additions & 2 deletions
@@ -17,18 +17,24 @@
 
 package org.apache.spark.shuffle.sort.io;
 
+import org.apache.spark.MapOutputTracker;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkEnv;
 import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
+import org.apache.spark.api.shuffle.ShuffleReadSupport;
 import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.io.DefaultShuffleReadSupport;
 import org.apache.spark.storage.BlockManager;
 
 public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents {
 
   private final SparkConf sparkConf;
   private BlockManager blockManager;
   private IndexShuffleBlockResolver blockResolver;
+  private MapOutputTracker mapOutputTracker;
+  private SerializerManager serializerManager;
 
   public DefaultShuffleExecutorComponents(SparkConf sparkConf) {
     this.sparkConf = sparkConf;
@@ -37,15 +43,30 @@ public DefaultShuffleExecutorComponents(SparkConf sparkConf) {
   @Override
   public void initializeExecutor(String appId, String execId) {
     blockManager = SparkEnv.get().blockManager();
+    mapOutputTracker = SparkEnv.get().mapOutputTracker();
+    serializerManager = SparkEnv.get().serializerManager();
     blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
   }
 
   @Override
   public ShuffleWriteSupport writes() {
+    checkInitialized();
+    return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId());
+  }
+
+  @Override
+  public ShuffleReadSupport reads() {
+    checkInitialized();
+    return new DefaultShuffleReadSupport(blockManager,
+      mapOutputTracker,
+      serializerManager,
+      sparkConf);
+  }
+
+  private void checkInitialized() {
     if (blockResolver == null) {
       throw new IllegalStateException(
-          "Executor components must be initialized before getting writers.");
+        "Executor components must be initialized before getting writers.");
     }
-    return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId());
   }
 }
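
For context, a sketch of the expected calling sequence (assumes a live SparkEnv on an executor; the app and executor ids are invented). Calling reads() or writes() before initializeExecutor(...) now fails fast through checkInitialized().

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents

val components = new DefaultShuffleExecutorComponents(new SparkConf())
components.initializeExecutor("app-20190501-0001", "exec-1")
val readSupport = components.reads()   // safe only after initializeExecutor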

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 13 additions & 8 deletions
@@ -283,7 +283,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
 
   // For testing
   def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int)
-      : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = {
+      : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
     getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)
   }
 
@@ -297,7 +297,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    * describing the shuffle blocks that are stored at that block manager.
    */
   def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])]
+      : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]
 
   /**
    * Deletes map output status information for the specified shuffle stage.
@@ -647,7 +647,7 @@ private[spark] class MapOutputTrackerMaster(
   // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
   // This method is only called in local-mode.
   def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = {
+      : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     shuffleStatuses.get(shuffleId) match {
       case Some (shuffleStatus) =>
@@ -684,7 +684,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
 
   // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
   override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = {
+      : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     val statuses = getStatuses(shuffleId)
     try {
@@ -873,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus]): Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = {
+      statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = {
     assert (statuses != null)
-    val splitsByAddress = new HashMap[ShuffleLocation, ListBuffer[(BlockId, Long)]]
+    val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]]
     for ((status, mapId) <- statuses.iterator.zipWithIndex) {
       if (status == null) {
         val errorMessage = s"Missing an output location for shuffle $shuffleId"
@@ -885,9 +885,14 @@ private[spark] object MapOutputTracker extends Logging {
       for (part <- startPartition until endPartition) {
         val size = status.getSizeForBlock(part)
         if (size != 0) {
-          val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part)
-          splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) +=
+          if (status.mapShuffleLocations == null) {
+            splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) +=
             ((ShuffleBlockId(shuffleId, mapId, part), size))
+          } else {
+            val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part)
+            splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) +=
+              ((ShuffleBlockId(shuffleId, mapId, part), size))
+          }
         }
       }
     }
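
A standalone sketch of the new grouping behavior, with String standing in for ShuffleLocation and invented block names: blocks from map statuses that carry no pluggable locations are keyed by None rather than by a null location.

import scala.collection.mutable.{HashMap, ListBuffer}

val splitsByAddress = new HashMap[Option[String], ListBuffer[(String, Long)]]
// Block whose MapStatus carries no mapShuffleLocations:
splitsByAddress.getOrElseUpdate(None, ListBuffer()) += (("shuffle_0_1_2", 1024L))
// Block with a concrete location:
splitsByAddress.getOrElseUpdate(Some("host-a:7337"), ListBuffer()) += (("shuffle_0_2_2", 2048L))
assert(splitsByAddress(None).map(_._2).sum == 1024L)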

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 10 additions & 2 deletions
@@ -56,6 +56,8 @@ class TaskMetrics private[spark] () extends Serializable {
   private val _diskBytesSpilled = new LongAccumulator
   private val _peakExecutionMemory = new LongAccumulator
   private val _updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)]
+  private var _decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics =
+    Predef.identity[TempShuffleReadMetrics]
 
   /**
    * Time taken on the executor to deserialize this task.
@@ -187,11 +189,17 @@ class TaskMetrics private[spark] () extends Serializable {
    * be lost.
    */
   private[spark] def createTempShuffleReadMetrics(): TempShuffleReadMetrics = synchronized {
-    val readMetrics = new TempShuffleReadMetrics
-    tempShuffleReadMetrics += readMetrics
+    val tempShuffleMetrics = new TempShuffleReadMetrics
+    val readMetrics = _decorFunc(tempShuffleMetrics)
+    tempShuffleReadMetrics += tempShuffleMetrics
     readMetrics
   }
 
+  private[spark] def decorateTempShuffleReadMetrics(
+      decorFunc: TempShuffleReadMetrics => TempShuffleReadMetrics): Unit = synchronized {
+    _decorFunc = decorFunc
+  }
+
   /**
    * Merge values across all temporary [[ShuffleReadMetrics]] into `_shuffleReadMetrics`.
    * This is expected to be called on executor heartbeat and at the end of a task.
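
A hypothetical decorator built on the new hook (a sketch: TempShuffleReadMetrics and decorateTempShuffleReadMetrics are private[spark], so this would have to compile inside Spark, and a complete wrapper would forward every metric update, not just the one shown).

import org.apache.spark.executor.{TaskMetrics, TempShuffleReadMetrics}

// Forwards one update to the wrapped TempShuffleReadMetrics while counting
// it locally. The delegate stays registered in tempShuffleReadMetrics and is
// what gets merged on heartbeat; shuffle readers only see the wrapper.
class CountingReadMetrics(delegate: TempShuffleReadMetrics) extends TempShuffleReadMetrics {
  var remoteBlocks = 0L
  override def incRemoteBlocksFetched(v: Long): Unit = {
    remoteBlocks += v
    delegate.incRemoteBlocksFetched(v)
  }
}

object CountingReadMetrics {
  def install(taskMetrics: TaskMetrics): Unit =
    taskMetrics.decorateTempShuffleReadMetrics(m => new CountingReadMetrics(m))
}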

core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala

Lines changed: 62 additions & 27 deletions
@@ -17,11 +17,18 @@
 
 package org.apache.spark.shuffle
 
+import java.io.InputStream
+
+import scala.collection.JavaConverters._
+
 import org.apache.spark._
+import org.apache.spark.api.java.Optional
+import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
 import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
+import org.apache.spark.storage.{ShuffleBlockFetcherIterator, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
@@ -35,40 +42,68 @@ private[spark] class BlockStoreShuffleReader[K, C](
     endPartition: Int,
     context: TaskContext,
     readMetrics: ShuffleReadMetricsReporter,
+    shuffleReadSupport: ShuffleReadSupport,
     serializerManager: SerializerManager = SparkEnv.get.serializerManager,
-    blockManager: BlockManager = SparkEnv.get.blockManager,
-    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+    sparkConf: SparkConf = SparkEnv.get.conf)
   extends ShuffleReader[K, C] with Logging {
 
   private val dep = handle.dependency
 
+  private val compressionCodec = CompressionCodec.createCodec(sparkConf)
+
+  private val compressShuffle = sparkConf.get(config.SHUFFLE_COMPRESS)
+
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val wrappedStreams = new ShuffleBlockFetcherIterator(
-      context,
-      blockManager.shuffleClient,
-      blockManager,
-      mapOutputTracker.getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition)
-        .map {
-          case (loc: DefaultMapShuffleLocations, blocks: Seq[(BlockId, Long)]) =>
-            (loc.getBlockManagerId, blocks)
-          case _ =>
-            throw new UnsupportedOperationException("Not allowed to using non-default map shuffle" +
-              " locations yet.")
-        },
-      serializerManager.wrapStream,
-      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
-      SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
-      SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
-      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
-      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
-      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
-      readMetrics).toCompletionIterator
+    val streamsIterator =
+      shuffleReadSupport.getPartitionReaders(new Iterable[ShuffleBlockInfo] {
+        override def iterator: Iterator[ShuffleBlockInfo] = {
+          mapOutputTracker
+            .getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition)
+            .flatMap { shuffleLocationInfo =>
+              shuffleLocationInfo._2.map { blockInfo =>
+                val block = blockInfo._1.asInstanceOf[ShuffleBlockId]
+                new ShuffleBlockInfo(
+                  block.shuffleId,
+                  block.mapId,
+                  block.reduceId,
+                  blockInfo._2,
+                  Optional.ofNullable(shuffleLocationInfo._1.orNull))
+              }
+            }
+        }
+      }.asJava).iterator()
 
-    val serializerInstance = dep.serializer.newInstance()
+    val retryingWrappedStreams = new Iterator[InputStream] {
+      override def hasNext: Boolean = streamsIterator.hasNext
 
-    // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
+      override def next(): InputStream = {
+        var returnStream: InputStream = null
+        while (streamsIterator.hasNext && returnStream == null) {
+          if (shuffleReadSupport.isInstanceOf[DefaultShuffleReadSupport]) {
+            // The default implementation checks for corrupt streams, so it will already have
+            // decompressed/decrypted the bytes
+            returnStream = streamsIterator.next()
+          } else {
+            val nextStream = streamsIterator.next()
+            returnStream = if (compressShuffle) {
+              compressionCodec.compressedInputStream(
+                serializerManager.wrapForEncryption(nextStream))
+            } else {
+              serializerManager.wrapForEncryption(nextStream)
+            }
+          }
+        }
+        if (returnStream == null) {
+          throw new IllegalStateException("Expected shuffle reader iterator to return a stream")
+        }
+        returnStream
+      }
+    }
+
+    val serializerInstance = dep.serializer.newInstance()
+    val recordIter = retryingWrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
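
Restated in isolation, the wrapping rule the reader now applies to streams from a non-default plugin (a sketch; wrapForEncryption and config.SHUFFLE_COMPRESS are private[spark], so this code would have to live in Spark's core module).

import java.io.InputStream

import org.apache.spark.SparkConf
import org.apache.spark.internal.config
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.SerializerManager

object PluginStreamWrapping {
  // Plugin streams arrive as raw bytes: decrypt first, then decompress when
  // spark.shuffle.compress is enabled. Streams from DefaultShuffleReadSupport
  // skip this because its corruption detection already unwrapped them.
  def wrapPluginStream(
      raw: InputStream,
      conf: SparkConf,
      serializerManager: SerializerManager): InputStream = {
    val decrypted = serializerManager.wrapForEncryption(raw)
    if (conf.get(config.SHUFFLE_COMPRESS)) {
      CompressionCodec.createCodec(conf).compressedInputStream(decrypted)
    } else {
      decrypted
    }
  }
}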
