@@ -37,12 +37,11 @@
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
@@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final int numPartitions;
private final BlockManager blockManager;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
private final Serializer serializer;
@@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf conf) {
SparkConf conf,
ShuffleWriteMetricsReporter writeMetrics) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.shuffleId = dep.shuffleId();
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
this.shuffleBlockResolver = shuffleBlockResolver;
}
@@ -38,6 +38,7 @@
import org.apache.spark.memory.TooLargePageException;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.FileSegment;
@@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
private final ShuffleWriteMetricsReporter writeMetrics;

/**
* Force this sorter to spill when there are this many elements in memory.
@@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
int initialSize,
int numPartitions,
SparkConf conf,
ShuffleWriteMetrics writeMetrics) {
ShuffleWriteMetricsReporter writeMetrics) {
super(memoryManager,
(int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
memoryManager.getTungstenMemoryMode());
@@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
*/
private void writeSortedFile(boolean isLastFile) {

final ShuffleWriteMetrics writeMetricsToUse;
final ShuffleWriteMetricsReporter writeMetricsToUse;

if (isLastFile) {
// We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
@@ -241,9 +242,14 @@ private void writeSortedFile(boolean isLastFile) {
//
// Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
// Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
// This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
// SPARK-3577 tracks the spill time separately.

// This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning
// of this method.
writeMetrics.incRecordsWritten(
((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten());
taskContext.taskMetrics().incDiskBytesSpilled(
((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten());
}
}
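The hunk above is the crux of the metrics change in ShuffleExternalSorter: spill files report into a temporary ShuffleWriteMetrics, and only the final output file reports straight into the caller-provided reporter, which is why the downcast is taken only on the spill branch. A rough, self-contained Scala illustration of that pattern (the names below are invented stand-ins, not Spark's classes):

// Simplified stand-ins: the reporter interface only exposes increment methods,
// while the concrete class also exposes the accumulated values.
trait ReporterSketch {
  def incRecordsWritten(v: Long): Unit
  def incBytesWritten(v: Long): Unit
}

class ConcreteMetricsSketch extends ReporterSketch {
  var records = 0L
  var bytes = 0L
  override def incRecordsWritten(v: Long): Unit = records += v
  override def incBytesWritten(v: Long): Unit = bytes += v
}

// Mirrors writeSortedFile: the last (non-spill) file reports directly into the task's
// reporter; spill files go through a temporary object whose counts are folded in later.
def writeFileSketch(isLastFile: Boolean, taskReporter: ReporterSketch): Unit = {
  val writeMetricsToUse: ReporterSketch =
    if (isLastFile) taskReporter else new ConcreteMetricsSketch

  // ... the actual writes would happen here, reporting into writeMetricsToUse ...
  writeMetricsToUse.incRecordsWritten(1L)
  writeMetricsToUse.incBytesWritten(128L)

  if (!isLastFile) {
    // Safe cast: this branch only ever sees the ConcreteMetricsSketch created above.
    val temp = writeMetricsToUse.asInstanceOf[ConcreteMetricsSketch]
    taskReporter.incRecordsWritten(temp.records)
    // temp.bytes would be attributed to disk bytes spilled, not shuffle bytes written.
  }
}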

@@ -37,7 +37,6 @@

import org.apache.spark.*;
import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
@@ -47,6 +46,7 @@
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
@@ -73,7 +73,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final TaskMemoryManager memoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
private final TaskContext taskContext;
@@ -122,7 +122,8 @@ public UnsafeShuffleWriter(
SerializedShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
SparkConf sparkConf) throws IOException {
SparkConf sparkConf,
ShuffleWriteMetricsReporter writeMetrics) throws IOException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -138,7 +139,7 @@ public UnsafeShuffleWriter(
this.shuffleId = dep.shuffleId();
this.serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
this.writeMetrics = writeMetrics;
this.taskContext = taskContext;
this.sparkConf = sparkConf;
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
@@ -21,7 +21,7 @@
import java.io.OutputStream;

import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;

/**
* Intercepts write calls and tracks total time spent writing in order to update shuffle write
@@ -30,10 +30,11 @@
@Private
public final class TimeTrackingOutputStream extends OutputStream {

private final ShuffleWriteMetrics writeMetrics;
private final ShuffleWriteMetricsReporter writeMetrics;
private final OutputStream outputStream;

public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
public TimeTrackingOutputStream(
ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) {
this.writeMetrics = writeMetrics;
this.outputStream = outputStream;
}
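TimeTrackingOutputStream itself only changes the type it reports into; for readers unfamiliar with the class, the wrapper pattern it implements is: time each write call and forward the elapsed nanoseconds to the metrics object. A self-contained Scala sketch of that idea (illustrative names only, not Spark's API):

import java.io.{ByteArrayOutputStream, OutputStream}

// Stand-in reporter; only the write-time counter matters for this sketch.
class WriteTimeReporterSketch {
  var writeTimeNanos = 0L
  def incWriteTime(v: Long): Unit = writeTimeNanos += v
}

// Wraps another OutputStream and attributes the time spent in each write call.
class TimedOutputStreamSketch(reporter: WriteTimeReporterSketch, out: OutputStream)
    extends OutputStream {
  override def write(b: Int): Unit = {
    val start = System.nanoTime()
    out.write(b)
    reporter.incWriteTime(System.nanoTime() - start)
  }
}

// Usage: wrap a stream, write through it, then read back the accumulated time.
// val reporter = new WriteTimeReporterSketch
// val out = new TimedOutputStreamSketch(reporter, new ByteArrayOutputStream)
// out.write(42); println(reporter.writeTimeNanos)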
@@ -18,6 +18,7 @@
package org.apache.spark.executor

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.LongAccumulator


@@ -27,7 +28,7 @@ import org.apache.spark.util.LongAccumulator
* Operations are not thread-safe.
*/
@DeveloperApi
class ShuffleWriteMetrics private[spark] () extends Serializable {
class ShuffleWriteMetrics private[spark] () extends ShuffleWriteMetricsReporter with Serializable {
private[executor] val _bytesWritten = new LongAccumulator
private[executor] val _recordsWritten = new LongAccumulator
private[executor] val _writeTime = new LongAccumulator
@@ -47,13 +48,13 @@ class ShuffleWriteMetrics private[spark] () extends Serializable {
*/
def writeTime: Long = _writeTime.sum

private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)
private[spark] def decBytesWritten(v: Long): Unit = {
private[spark] override def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
private[spark] override def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
private[spark] override def incWriteTime(v: Long): Unit = _writeTime.add(v)
private[spark] override def decBytesWritten(v: Long): Unit = {
_bytesWritten.setValue(bytesWritten - v)
}
private[spark] def decRecordsWritten(v: Long): Unit = {
private[spark] override def decRecordsWritten(v: Long): Unit = {
_recordsWritten.setValue(recordsWritten - v)
}
}
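The new ShuffleWriteMetricsReporter interface itself is not part of this excerpt, but the overrides above imply it declares at least the following members. This is an inferred Scala sketch, not necessarily the exact definition added by the PR:

package org.apache.spark.shuffle

// Inferred from the overrides in ShuffleWriteMetrics above; visibility and any extra
// members or documentation in the actual trait may differ.
private[spark] trait ShuffleWriteMetricsReporter {
  private[spark] def incBytesWritten(v: Long): Unit
  private[spark] def incRecordsWritten(v: Long): Unit
  private[spark] def incWriteTime(v: Long): Unit
  private[spark] def decBytesWritten(v: Long): Unit
  private[spark] def decRecordsWritten(v: Long): Unit
}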
@@ -95,7 +95,8 @@ private[spark] class ShuffleMapTask(
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer = manager.getWriter[Any, Any](
dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {
@@ -38,7 +38,11 @@ private[spark] trait ShuffleManager {
dependency: ShuffleDependency[K, V, C]): ShuffleHandle

/** Get a writer for a given partition. Called on executors by map tasks. */
def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
def getWriter[K, V](
handle: ShuffleHandle,
mapId: Int,
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]

/**
* Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
@@ -125,7 +125,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Int,
context: TaskContext): ShuffleWriter[K, V] = {
context: TaskContext,
metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
numMapsForShuffle.putIfAbsent(
handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
val env = SparkEnv.get
@@ -138,15 +139,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
unsafeShuffleHandle,
mapId,
context,
env.conf)
env.conf,
metrics)
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
bypassMergeSortHandle,
mapId,
context,
env.conf)
env.conf,
metrics)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
}
@@ -33,10 +33,9 @@ import scala.util.Random
import scala.util.control.NonFatal

import com.codahale.metrics.{MetricRegistry, MetricSet}
import com.google.common.io.CountingOutputStream

import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.metrics.source.Source
@@ -50,7 +49,7 @@ import org.apache.spark.network.util.TransportConf
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter}
import org.apache.spark.storage.memory._
import org.apache.spark.unsafe.Platform
import org.apache.spark.util._
@@ -932,7 +931,7 @@ private[spark] class BlockManager(
file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
writeMetrics: ShuffleWriteMetricsReporter): DiskBlockObjectWriter = {
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
syncWrites, writeMetrics, blockId)
@@ -20,9 +20,9 @@ package org.apache.spark.storage
import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
import java.nio.channels.FileChannel

import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.Utils

/**
Expand All @@ -43,7 +43,7 @@ private[spark] class DiskBlockObjectWriter(
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
writeMetrics: ShuffleWriteMetrics,
writeMetrics: ShuffleWriteMetricsReporter,
val blockId: BlockId = null)
extends OutputStream
with Logging {
@@ -793,8 +793,8 @@ private[spark] class ExternalSorter[K, V, C](

def nextPartition(): Int = cur._1._1
}
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
logInfo(s"Task ${TaskContext.get().taskAttemptId} force spilling in-memory map to disk " +
s"and it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
forceSpillFiles += spillFile
val spillReader = new SpillReader(spillFile)
@@ -162,7 +162,8 @@ private UnsafeShuffleWriter<Object, Object> createWriter(
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf
conf,
taskContext.taskMetrics().shuffleWriteMetrics()
);
}

@@ -521,7 +522,8 @@ public void testPeakMemoryUsed() throws Exception {
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf);
conf,
taskContext.taskMetrics().shuffleWriteMetrics());

// Peak memory should be monotonically increasing. More specifically, every time
// we allocate a new page it should increase by exactly the size of the page.
12 changes: 8 additions & 4 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -362,15 +362,19 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
mapTrackerMaster.registerShuffle(0, 1)

// first attempt -- its successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
val context1 =
new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)
val writer1 = manager.getWriter[Int, Int](
shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics)
val data1 = (1 to 10).map { x => x -> x}

// second attempt -- also successful. We'll write out different data,
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
val context2 =
new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)
val writer2 = manager.getWriter[Int, Int](
shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics)
val data2 = (11 to 20).map { x => x -> x}

// interleave writes of both attempts -- we want to test that both attempts can occur