
Commit 7e83298

Author: Davies Liu (committed)

    robust shuffle writer

1 parent 3676d4c, commit 7e83298

9 files changed: +258 -36 lines changed


core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
Lines changed: 6 additions & 4 deletions

@@ -55,6 +55,7 @@
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
 
 @Private
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -217,16 +218,18 @@ void closeAndWriteOutput() throws IOException {
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
     final long[] partitionLengths;
+    final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final File tmp = Utils.tempFileWith(output);
     try {
-      partitionLengths = mergeSpills(spills);
+      partitionLengths = mergeSpills(spills, tmp);
     } finally {
       for (SpillInfo spill : spills) {
         if (spill.file.exists() && ! spill.file.delete()) {
           logger.error("Error while deleting spill file {}", spill.file.getPath());
         }
       }
     }
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
@@ -259,8 +262,7 @@ void forceSorterToSpill() throws IOException {
    *
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
-    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+  private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
     final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
     final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
     final boolean fastMergeEnabled =
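
The pattern introduced above is: merge the spill files into a temporary sibling of the final shuffle data file, then let writeIndexFileAndCommit publish it together with the index file. A minimal Scala sketch of that calling pattern, for illustration only (not part of this commit); `resolver`, `shuffleId`, `mapId` and `writeData` are placeholders:

  import java.io.File
  import org.apache.spark.shuffle.IndexShuffleBlockResolver
  import org.apache.spark.util.Utils

  // Sketch: write map output to a temp sibling of the final file, then commit it.
  def writeAndCommit(
      resolver: IndexShuffleBlockResolver,
      shuffleId: Int,
      mapId: Int)(writeData: File => Array[Long]): Array[Long] = {
    val output = resolver.getDataFile(shuffleId, mapId)  // final data file location
    val tmp = Utils.tempFileWith(output)                 // unique sibling temp file
    val lengths = writeData(tmp)                         // partition lengths we wrote
    // Commits tmp and the index atomically; if another attempt of the same task has
    // already committed, `lengths` is overwritten with the existing values and tmp
    // is deleted instead.
    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, tmp)
    lengths
  }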

core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
Lines changed: 4 additions & 13 deletions

@@ -23,15 +23,15 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConversions._
 
-import org.apache.spark.{Logging, SparkConf, SparkEnv}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup
 import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -124,17 +124,8 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
       Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId =>
         val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
         val blockFile = blockManager.diskBlockManager.getFile(blockId)
-        // Because of previous failures, the shuffle file may already exist on this machine.
-        // If so, remove it.
-        if (blockFile.exists) {
-          if (blockFile.delete()) {
-            logInfo(s"Removed existing shuffle file $blockFile")
-          } else {
-            logWarning(s"Failed to remove existing shuffle file $blockFile")
-          }
-        }
-        blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
-          writeMetrics)
+        val tmp = Utils.tempFileWith(blockFile)
+        blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
       }
     }
     // Creating the file to write to and creating a disk writer both involve interacting with

core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
Lines changed: 96 additions & 7 deletions

@@ -21,13 +21,12 @@ import java.io._
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
-
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 
 /**
  * Create and maintain the shuffle blocks' mapping between logic block and physical file location.
@@ -40,9 +39,13 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
  */
 // Note: Changes to the format in this file should be kept in sync with
 // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver {
+private[spark] class IndexShuffleBlockResolver(
+    conf: SparkConf,
+    _blockManager: BlockManager = null)
+  extends ShuffleBlockResolver
+  with Logging {
 
-  private lazy val blockManager = SparkEnv.get.blockManager
+  private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
 
   private val transportConf = SparkTransportConf.fromSparkConf(conf)
 
@@ -69,14 +72,69 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
     }
   }
 
+  /**
+   * Check whether the given index and data files match each other.
+   * If so, return the partition lengths in the data file. Otherwise return null.
+   */
+  private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
+    // the index file should have `block + 1` longs as offset.
+    if (index.length() != (blocks + 1) * 8) {
+      return null
+    }
+    val lengths = new Array[Long](blocks)
+    // Read the lengths of blocks
+    val in = try {
+      new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+    } catch {
+      case e: IOException =>
+        return null
+    }
+    try {
+      // Convert the offsets into lengths of each block
+      var offset = in.readLong()
+      if (offset != 0L) {
+        return null
+      }
+      var i = 0
+      while (i < blocks) {
+        val off = in.readLong()
+        lengths(i) = off - offset
+        offset = off
+        i += 1
+      }
+    } catch {
+      case e: IOException =>
+        return null
+    } finally {
+      in.close()
+    }
+
+    // the size of data file should match with index file
+    if (data.length() == lengths.sum) {
+      lengths
+    } else {
+      null
+    }
+  }
+
   /**
    * Write an index file with the offsets of each block, plus a final offset at the end for the
    * end of the output file. This will be used by getBlockData to figure out where each block
    * begins and ends.
+   *
+   * It will commit the data and index file as an atomic operation, use the existing ones, or
+   * replace them with new ones.
+   *
+   * Note: the `lengths` will be updated to match the existing index file if use the existing ones.
   * */
-  def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+  def writeIndexFileAndCommit(
+      shuffleId: Int,
+      mapId: Int,
+      lengths: Array[Long],
+      dataTmp: File): Unit = {
     val indexFile = getIndexFile(shuffleId, mapId)
-    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    val indexTmp = Utils.tempFileWith(indexFile)
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
     Utils.tryWithSafeFinally {
       // We take in lengths of each block, need to convert it to offsets.
       var offset = 0L
@@ -88,6 +146,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
     } {
       out.close()
     }
+
+    val dataFile = getDataFile(shuffleId, mapId)
+    // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
+    // the following check and rename are atomic.
+    synchronized {
+      val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
+      if (existingLengths != null) {
+        // Another attempt for the same task has already written our map outputs successfully,
+        // so just use the existing partition lengths and delete our temporary map outputs.
+        System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+        if (dataTmp != null && dataTmp.exists()) {
+          dataTmp.delete()
+        }
+        indexTmp.delete()
+      } else {
+        // This is the first successful attempt in writing the map outputs for this task,
+        // so override any existing index and data files with the ones we wrote.
+        if (indexFile.exists()) {
+          indexFile.delete()
+        }
+        if (dataFile.exists()) {
+          dataFile.delete()
+        }
+        if (!indexTmp.renameTo(indexFile)) {
+          throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
+        }
+        if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
+          throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
+        }
+      }
+    }
   }
 
   override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
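
For reference, the index-file layout that checkIndexAndDataFile validates (and writeIndexFileAndCommit produces) stores, for n blocks, n + 1 longs: a leading 0 followed by the cumulative end offset of each block, so the last offset must equal the data file's length. A tiny illustrative sketch of the offsets-to-lengths conversion it performs (not part of the commit):

  // Sketch: blocks of lengths [10, 15, 15] are stored as offsets [0, 10, 25, 40];
  // the data file must then be exactly 40 bytes long.
  def offsetsToLengths(offsets: Array[Long]): Array[Long] =
    offsets.sliding(2).map { case Array(a, b) => b - a }.toArray

  assert(offsetsToLengths(Array(0L, 10L, 25L, 40L)).sameElements(Array(10L, 15L, 15L)))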

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
Lines changed: 25 additions & 0 deletions

@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle.hash
 
+import java.io.IOException
+
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
@@ -106,6 +108,29 @@ private[spark] class HashShuffleWriter[K, V](
       writer.commitAndClose()
       writer.fileSegment().length
     }
+    // rename all shuffle files to final paths
+    // Note: there is only one ShuffleBlockResolver in executor
+    shuffleBlockResolver.synchronized {
+      shuffle.writers.zipWithIndex.foreach { case (writer, i) =>
+        val output = blockManager.diskBlockManager.getFile(writer.blockId)
+        if (sizes(i) > 0) {
+          if (output.exists()) {
+            // Use length of existing file and delete our own temporary one
+            sizes(i) = output.length()
+            writer.file.delete()
+          } else {
+            // Commit by renaming our temporary file to something the fetcher expects
+            if (!writer.file.renameTo(output)) {
+              throw new IOException(s"fail to rename ${writer.file} to $output")
+            }
+          }
+        } else {
+          if (output.exists()) {
+            output.delete()
+          }
+        }
+      }
+    }
     MapStatus(blockManager.shuffleServerId, sizes)
   }
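
Distilled, the block added above implements "first successful attempt wins" for each per-reducer file: if the final file already exists, an earlier attempt of this task has committed, so its length is adopted and our temporary file is deleted; otherwise our temporary file is renamed into place. An illustrative Scala helper capturing that rule (a sketch only, not the commit's code; tmpFile and finalFile are placeholders):

  import java.io.{File, IOException}

  // Sketch: commit one reducer's output, or adopt one already committed by another attempt.
  def commitOrAdopt(tmpFile: File, finalFile: File, bytesWritten: Long): Long = {
    if (bytesWritten > 0) {
      if (finalFile.exists()) {
        tmpFile.delete()             // another attempt won; report its length
        finalFile.length()
      } else if (tmpFile.renameTo(finalFile)) {
        bytesWritten                 // our rename succeeded; our output is now visible
      } else {
        throw new IOException(s"fail to rename $tmpFile to $finalFile")
      }
    } else {
      if (finalFile.exists()) finalFile.delete()  // nothing written; clear any stale output
      0L
    }
  }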

core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
Lines changed: 6 additions & 5 deletions

@@ -21,8 +21,9 @@ import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
 import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
@@ -75,11 +76,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).
-    val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val tmp = Utils.tempFileWith(output)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
-    shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-
+    val partitionLengths = sorter.writePartitionedFile(blockId, context, tmp)
+    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }

core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ import org.apache.spark.util.Utils
  */
 private[spark] class DiskBlockObjectWriter(
     val blockId: BlockId,
-    file: File,
+    val file: File,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
     compressStream: OutputStream => OutputStream,

core/src/main/scala/org/apache/spark/util/Utils.scala
Lines changed: 8 additions & 3 deletions

@@ -21,16 +21,16 @@ import java.io._
 import java.lang.management.ManagementFactory
 import java.net._
 import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
 import java.util.concurrent._
+import java.util.{Locale, Properties, Random, UUID}
 import javax.net.ssl.HttpsURLConnection
 
 import scala.collection.JavaConversions._
 import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
+import scala.util.Try
 import scala.util.control.{ControlThrowable, NonFatal}
 
 import com.google.common.io.{ByteStreams, Files}
@@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation
 import org.apache.log4j.PropertyConfigurator
 import org.eclipse.jetty.util.MultiException
 import org.json4s._
-
 import tachyon.TachyonURI
 import tachyon.client.{TachyonFS, TachyonFile}
 
@@ -2152,6 +2151,12 @@
     conf.getInt("spark.executor.instances", 0) == 0
   }
 
+  /**
+   * Returns a path of temporary file which is in the same directory with `path`.
+   */
+  def tempFileWith(path: File): File = {
+    new File(path.getAbsolutePath + "." + UUID.randomUUID())
+  }
 }
 
 /**
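
The helper simply appends a random UUID to the target path, so the temporary file is a sibling of the final file in the same directory (and therefore on the same filesystem), which is what makes the later renameTo cheap and effectively atomic. A hedged usage sketch (the path is made up for illustration):

  import java.io.File
  import org.apache.spark.util.Utils

  val output = new File("/tmp/blockmgr-0/0c/shuffle_0_1_0.data")  // hypothetical location
  val tmp = Utils.tempFileWith(output)
  // tmp looks like /tmp/blockmgr-0/0c/shuffle_0_1_0.data.<random-uuid>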

core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
Lines changed: 5 additions & 1 deletion

@@ -174,9 +174,13 @@ public OutputStream answer(InvocationOnMock invocation) throws Throwable {
       @Override
       public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
         partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+        File tmp = (File) invocationOnMock.getArguments()[3];
+        mergedOutputFile.delete();
+        tmp.renameTo(mergedOutputFile);
         return null;
       }
-    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+    }).when(shuffleBlockResolver)
+      .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
       new Answer<Tuple2<TempShuffleBlockId, File>>() {
