diff --git a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
index b0aed4d08d38..a7b2c3006bb7 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
@@ -17,8 +17,10 @@
 package org.apache.spark.api.shuffle;
 
 import org.apache.spark.annotation.Experimental;
+import org.apache.spark.api.java.Optional;
 
 import java.io.Serializable;
+import java.util.List;
 
 /**
  * Represents metadata about where shuffle blocks were written in a single map task.
@@ -35,5 +37,30 @@ public interface MapShuffleLocations extends Serializable {
   /**
    * Get the location for a given shuffle block written by this map task.
    */
-  ShuffleLocation getLocationForBlock(int reduceId);
+  List<ShuffleLocation> getLocationsForBlock(int reduceId);
+
+  /**
+   * Mark a location for a block in this map output as unreachable, so that partitions can no
+   * longer be fetched from that location.
+   *
+   * This is called by the scheduler when it detects that a block could not be fetched from the
+   * file server located at this host and port.
+   *
+   * This should return true if removing this shuffle location results in data loss. Otherwise,
+   * if all partitions can still be fetched from alternative locations, this should return false.
+   */
+  boolean invalidateShuffleLocation(String host, Optional<Integer> port);
+
+  /**
+   * Mark all locations within this MapShuffleLocations with this execId as unreachable.
+   *
+   * This is called by the scheduler when it detects that an executor cannot be reached to
+   * fetch file data.
+   *
+   * This should return true if removing the shuffle locations with this execId results in data
+   * loss. Otherwise, if all partitions can still be fetched from alternative locations, this
+   * should return false.
+   */
+  boolean invalidateShuffleLocation(String executorId);
 }
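Aside (not part of this patch): a minimal sketch, assuming the interface above, of how a plugin that keeps several replica locations per reduce partition might satisfy the invalidation contract. The class name and bookkeeping are hypothetical; the Scala test suites at the end of this diff implement the same idea.

// Illustrative only. Tracks a mutable list of replica locations per reduce partition.
package org.apache.spark.examples.shuffle;

import java.util.List;
import java.util.function.Predicate;

import org.apache.spark.api.java.Optional;
import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleLocation;

public class ReplicatedMapShuffleLocations implements MapShuffleLocations {

  // One list of replica locations per reduce partition written by this map task.
  private final List<List<ShuffleLocation>> locationsByReduceId;

  public ReplicatedMapShuffleLocations(List<List<ShuffleLocation>> locationsByReduceId) {
    this.locationsByReduceId = locationsByReduceId;
  }

  @Override
  public List<ShuffleLocation> getLocationsForBlock(int reduceId) {
    return locationsByReduceId.get(reduceId);
  }

  @Override
  public boolean invalidateShuffleLocation(String host, Optional<Integer> port) {
    return removeMatching(loc ->
        loc.host().equals(host) && (!port.isPresent() || loc.port() == port.get()));
  }

  @Override
  public boolean invalidateShuffleLocation(String executorId) {
    return removeMatching(loc ->
        loc.execId().isPresent() && loc.execId().get().equals(executorId));
  }

  // Drops every location matching the predicate; data is lost once any partition
  // has no location left to fetch from.
  private boolean removeMatching(Predicate<ShuffleLocation> matches) {
    boolean dataLost = false;
    for (List<ShuffleLocation> locations : locationsByReduceId) {
      locations.removeIf(matches);
      if (locations.isEmpty()) {
        dataLost = true;
      }
    }
    return dataLost;
  }
}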
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
index a312831cb628..da5c74a689d1 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleBlockInfo.java
@@ -17,8 +17,6 @@
 package org.apache.spark.api.shuffle;
 
-import org.apache.spark.api.java.Optional;
-
 import java.util.Objects;
 
 /**
@@ -31,10 +29,10 @@ public class ShuffleBlockInfo {
   private final int mapId;
   private final int reduceId;
   private final long length;
-  private final Optional<ShuffleLocation> shuffleLocation;
+  private final ShuffleLocation[] shuffleLocation;
 
   public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length,
-      Optional<ShuffleLocation> shuffleLocation) {
+      ShuffleLocation[] shuffleLocation) {
     this.shuffleId = shuffleId;
     this.mapId = mapId;
     this.reduceId = reduceId;
@@ -58,7 +56,7 @@ public long getLength() {
     return length;
   }
 
-  public Optional<ShuffleLocation> getShuffleLocation() {
+  public ShuffleLocation[] getShuffleLocation() {
     return shuffleLocation;
   }
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
index 6a0ec8d44fd4..4f2a45731bef 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleDriverComponents.java
@@ -30,4 +30,12 @@ public interface ShuffleDriverComponents {
   void cleanupApplication() throws IOException;
 
   void removeShuffleData(int shuffleId, boolean blocking) throws IOException;
+
+  /**
+   * Whether to unregister other map statuses on the same hosts or executors
+   * when a shuffle task returns a {@link org.apache.spark.FetchFailed}.
+   */
+  default boolean unregisterOtherMapStatusesOnFetchFailure() {
+    return false;
+  }
 }
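Aside (not part of this patch): a sketch of driver components for a plugin that writes shuffle data to remote storage and therefore keeps the new flag at its default of false, since losing an executor does not imply the shuffle files it wrote are gone. It assumes the methods visible in this diff (including initializeApplication, seen in the Scala test suites below) are the whole contract; the class name and storage calls are hypothetical.

// Illustrative only. Keeps unregisterOtherMapStatusesOnFetchFailure() == false.
package org.apache.spark.examples.shuffle;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;

import org.apache.spark.api.shuffle.ShuffleDriverComponents;

public class RemoteStorageShuffleDriverComponents implements ShuffleDriverComponents {

  @Override
  public Map<String, String> initializeApplication() {
    // Configuration to ship to executors could be returned here.
    return Collections.emptyMap();
  }

  @Override
  public void cleanupApplication() throws IOException {
    // Delete the application's directory in remote storage.
  }

  @Override
  public void removeShuffleData(int shuffleId, boolean blocking) throws IOException {
    // Delete the given shuffle's directory in remote storage.
  }

  // unregisterOtherMapStatusesOnFetchFailure() is inherited and returns false: a fetch
  // failure against one location does not invalidate other map outputs on that host/executor.
}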
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
index d06c11b3c01e..552f2888297d 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
@@ -17,8 +17,33 @@
 package org.apache.spark.api.shuffle;
 
+import org.apache.spark.api.java.Optional;
+
 /**
  * Marker interface representing a location of a shuffle block. Implementations of shuffle readers
  * and writers are expected to cast this down to an implementation-specific representation.
  */
-public interface ShuffleLocation {}
+public abstract class ShuffleLocation {
+  /**
+   * The host and port on which the shuffle block is located.
+   */
+  public abstract String host();
+  public abstract int port();
+
+  /**
+   * The executor on which the ShuffleLocation is located. Returns {@link Optional#empty()} if
+   * the location is not associated with an executor.
+   */
+  public Optional<String> execId() {
+    return Optional.empty();
+  }
+
+  @Override
+  public String toString() {
+    String shuffleLocation = String.format("ShuffleLocation %s:%d", host(), port());
+    if (execId().isPresent()) {
+      return String.format("%s (execId: %s)", shuffleLocation, execId().get());
+    }
+    return shuffleLocation;
+  }
+}
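Aside (not part of this patch): a minimal concrete subclass of the abstract class above for a location that points at an external file server rather than an executor, so execId() keeps its default of Optional.empty(). The class name is hypothetical; the DAGSchedulerFileServerSuite later in this diff defines an analogous Scala class.

// Illustrative only. A host/port location with no associated executor.
package org.apache.spark.examples.shuffle;

import org.apache.spark.api.shuffle.ShuffleLocation;

public class FileServerShuffleLocation extends ShuffleLocation {

  private final String host;
  private final int port;

  public FileServerShuffleLocation(String host, int port) {
    this.host = host;
    this.port = port;
  }

  @Override
  public String host() {
    return host;
  }

  @Override
  public int port() {
    return port;
  }
}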
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
index ffd97c0f2660..ef7e4dab9154 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
@@ -17,17 +17,21 @@
 package org.apache.spark.shuffle.sort;
 
+import com.fasterxml.jackson.annotation.JsonIgnore;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
+import com.google.common.collect.Lists;
 
+import org.apache.spark.api.java.Optional;
 import org.apache.spark.api.shuffle.MapShuffleLocations;
 import org.apache.spark.api.shuffle.ShuffleLocation;
 import org.apache.spark.storage.BlockManagerId;
 
+import java.util.List;
 import java.util.Objects;
 
-public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation {
+public class DefaultMapShuffleLocations extends ShuffleLocation implements MapShuffleLocations {
 
   /**
    * We borrow the cache size from the BlockManagerId's cache - around 1MB, which should be
@@ -45,9 +49,12 @@ public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) {
       });
 
   private final BlockManagerId location;
+  @JsonIgnore
+  private final List<ShuffleLocation> locationsArray;
 
   public DefaultMapShuffleLocations(BlockManagerId blockManagerId) {
     this.location = blockManagerId;
+    this.locationsArray = Lists.newArrayList(this);
   }
 
   public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) {
@@ -55,8 +62,21 @@ public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) {
   }
 
   @Override
-  public ShuffleLocation getLocationForBlock(int reduceId) {
-    return this;
+  public List<ShuffleLocation> getLocationsForBlock(int reduceId) {
+    return locationsArray;
+  }
+
+  @Override
+  public boolean invalidateShuffleLocation(String host, Optional<Integer> port) {
+    if (port.isPresent()) {
+      return this.host().equals(host) && this.port() == port.get();
+    }
+    return this.host().equals(host);
+  }
+
+  @Override
+  public boolean invalidateShuffleLocation(String executorId) {
+    return location.executorId().equals(executorId);
   }
 
   public BlockManagerId getBlockManagerId() {
@@ -73,4 +93,19 @@ public boolean equals(Object other) {
   public int hashCode() {
     return Objects.hashCode(location);
   }
+
+  @Override
+  public String host() {
+    return location.host();
+  }
+
+  @Override
+  public int port() {
+    return location.port();
+  }
+
+  @Override
+  public Optional<String> execId() {
+    return Optional.of(location.executorId());
+  }
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
index a3eddc8ec930..d8d229e4dfd7 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/lifecycle/DefaultShuffleDriverComponents.java
@@ -46,6 +46,11 @@ public void removeShuffleData(int shuffleId, boolean blocking) throws IOExceptio
blockManagerMaster.removeShuffle(shuffleId, blocking); } + @Override + public boolean unregisterOtherMapStatusesOnFetchFailure() { + return true; + } + private void checkInitialized() { if (blockManagerMaster == null) { throw new IllegalStateException("Driver components must be initialized before using"); diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index ebddf5ff6f6e..ecf9d07325a3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal +import org.apache.spark.api.java.Optional import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation} import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging @@ -102,11 +103,23 @@ private class ShuffleStatus(numPartitions: Int) { * This is a no-op if there is no registered map output or if the registered output is from a * different block manager. */ - def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { - if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { - _numAvailableOutputs -= 1 - mapStatuses(mapId) = null - invalidateSerializedMapOutputStatusCache() + def removeMapOutput(mapId: Int, shuffleLocations: Seq[ShuffleLocation]): Unit = synchronized { + if (mapStatuses(mapId) != null) { + var shouldDelete = false + if (shuffleLocations.isEmpty) { + shouldDelete = true + } else { + shuffleLocations.foreach { location => + shouldDelete = mapStatuses(mapId) + .mapShuffleLocations + .invalidateShuffleLocation(location.host(), Optional.of(location.port())) + } + } + if (shouldDelete) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } } } @@ -115,7 +128,14 @@ private class ShuffleStatus(numPartitions: Int) { * outputs which are served by an external shuffle server (if one exists). */ def removeOutputsOnHost(host: String): Unit = { - removeOutputsByFilter(x => x.host == host) + for (mapId <- 0 until mapStatuses.length) { + if (mapStatuses(mapId) != null && + mapStatuses(mapId).mapShuffleLocations.invalidateShuffleLocation(host, Optional.empty())) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } } /** @@ -124,7 +144,14 @@ private class ShuffleStatus(numPartitions: Int) { * still registered with that execId. */ def removeOutputsOnExecutor(execId: String): Unit = synchronized { - removeOutputsByFilter(x => x.executorId == execId) + for (mapId <- 0 until mapStatuses.length) { + if (mapStatuses(mapId) != null && + mapStatuses(mapId).mapShuffleLocations.invalidateShuffleLocation(execId)) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } } /** @@ -283,7 +310,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int) - : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + : Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = { getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1) } @@ -297,7 +324,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. 
*/ def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] + : Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -424,10 +451,10 @@ private[spark] class MapOutputTrackerMaster( } /** Unregister map output information of the given shuffle, mapper and block manager */ - def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + def unregisterMapOutput(shuffleId: Int, mapId: Int, shuffleLocations: Seq[ShuffleLocation]) { shuffleStatuses.get(shuffleId) match { case Some(shuffleStatus) => - shuffleStatus.removeMapOutput(mapId, bmAddress) + shuffleStatus.removeMapOutput(mapId, shuffleLocations) incrementEpoch() case None => throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") @@ -647,7 +674,7 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + : Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -684,7 +711,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + : Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { @@ -873,9 +900,9 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[Option[ShuffleLocation], ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[Seq[ShuffleLocation], ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" @@ -885,12 +912,13 @@ private[spark] object MapOutputTracker extends Logging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - if (status.mapShuffleLocations == null) { - splitsByAddress.getOrElseUpdate(Option.empty, ListBuffer()) += + if (status.mapShuffleLocations == null + || status.mapShuffleLocations.getLocationsForBlock(part).isEmpty) { + splitsByAddress.getOrElseUpdate(Seq.empty, ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) } else { - val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) - splitsByAddress.getOrElseUpdate(Option.apply(shuffleLoc), ListBuffer()) += + val shuffleLocations = status.mapShuffleLocations.getLocationsForBlock(part) + splitsByAddress.getOrElseUpdate(shuffleLocations.asScala, ListBuffer()) 
+= ((ShuffleBlockId(shuffleId, mapId, part), size)) } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 999f180193d8..f9e15fb7f19b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -308,6 +308,9 @@ class SparkContext(config: SparkConf) extends SafeLogging { _dagScheduler = ds } + private[spark] def shuffleDriverComponents: ShuffleDriverComponents = + _shuffleDriverComponents + /** * A unique identifier for the Spark application. * Its format depends on the scheduler implementation. diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 33901bc8380e..f84b36e80454 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.BlockManagerId @@ -81,14 +82,14 @@ case object Resubmitted extends TaskFailedReason { */ @DeveloperApi case class FetchFailed( - bmAddress: BlockManagerId, // Note that bmAddress can be null + shuffleLocation: Seq[ShuffleLocation], // Note that shuffleLocation cannot be null shuffleId: Int, mapId: Int, reduceId: Int, message: String) extends TaskFailedReason { override def toErrorString: String = { - val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString + val bmAddressString = if (shuffleLocation == null) "null" else shuffleLocation.toString s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + s"message=\n$message\n)" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dd1b2595461f..a63a6ed8520b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -43,6 +43,7 @@ import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ @@ -228,6 +229,9 @@ private[spark] class DAGScheduler( private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + private[spark] val unregisterOtherMapStatusesOnFetchFailure = sc.shuffleDriverComponents + .unregisterOtherMapStatusesOnFetchFailure() + /** * Called by the TaskSetManager to report task's starting. 
*/ @@ -1478,7 +1482,7 @@ private[spark] class DAGScheduler( } } - case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => + case FetchFailed(shuffleLocations, shuffleId, mapId, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1511,7 +1515,7 @@ private[spark] class DAGScheduler( mapOutputTracker.unregisterAllMapOutput(shuffleId) } else if (mapId != -1) { // Mark the map whose fetch failed as broken in the map stage - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, shuffleLocations) } if (failedStage.rdd.isBarrier()) { @@ -1626,22 +1630,39 @@ private[spark] class DAGScheduler( } // TODO: mark the executor as failed only if there were lots of fetch failures on it - if (bmAddress != null) { - val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled && - unRegisterOutputOnHostOnFetchFailure) { - // We had a fetch failure with the external shuffle service, so we - // assume all shuffle data on the node is bad. - Some(bmAddress.host) - } else { - // Unregister shuffle data just for one executor (we don't have any - // reason to believe shuffle data has been lost for the entire host). - None - } - removeExecutorAndUnregisterOutputs( - execId = bmAddress.executorId, - fileLost = true, - hostToUnregisterOutputs = hostToUnregisterOutputs, - maybeEpoch = Some(task.epoch)) + if (unregisterOtherMapStatusesOnFetchFailure && shuffleLocations.nonEmpty) { + val toRemoveHost = + if (env.conf.get(config.SHUFFLE_IO_PLUGIN_CLASS) == + classOf[DefaultShuffleDataIO].getName) { + env.blockManager.externalShuffleServiceEnabled && + unRegisterOutputOnHostOnFetchFailure + } else { + true + } + + shuffleLocations.foreach(location => { + var epochAllowsRemoval = false + // If the location belonged to an executor, remove all outputs on the executor + val maybeExecId = location.execId() + val currentEpoch = Some(task.epoch).getOrElse(mapOutputTracker.getEpoch) + if (maybeExecId.isPresent) { + val execId = maybeExecId.get() + if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { + failedEpoch(execId) = currentEpoch + epochAllowsRemoval = true + blockManagerMaster.removeExecutor(execId) + mapOutputTracker.removeOutputsOnExecutor(execId) + } + } else { + // If the location doesn't belong to an executor, the epoch doesn't matter + epochAllowsRemoval = true + } + + if (toRemoveHost && epochAllowsRemoval) { + mapOutputTracker.removeOutputsOnHost(location.host()) + } + }) + clearCacheLocs() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index ea31fe80ef56..3ae8064c1ed4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -30,6 +30,7 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ import org.apache.spark.scheduler.SchedulingMode._ +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.{AccumulatorV2, Clock, LongAccumulator, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap @@ -848,9 +849,17 @@ private[spark] class TaskSetManager( } isZombie = true - if (fetchFailed.bmAddress != null) { - blacklistTracker.foreach(_.updateBlacklistForFetchFailure( - 
fetchFailed.bmAddress.host, fetchFailed.bmAddress.executorId)) + // Fetches from remote locations shouldn't affect executor scheduling since these remote + // locations shouldn't be running executors, so only fetches using the default Spark + // implementation (DefaultMapShuffleLocations) of fetching from executor disk should result + // in blacklistable executors. + if (fetchFailed.shuffleLocation.nonEmpty) { + fetchFailed.shuffleLocation.foreach(loc => { + if (loc.execId().isPresent) { + blacklistTracker.foreach(_.updateBlacklistForFetchFailure( + loc.host(), loc.execId().get())) + } + }) } None diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 530c3694ad1e..316e810c9bd6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -69,7 +69,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( block.mapId, block.reduceId, blockInfo._2, - Optional.ofNullable(shuffleLocationInfo._1.orNull)) + shuffleLocationInfo._1.toArray) } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 265a8acfa8d6..100e011700b3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason} -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.util.Utils /** @@ -33,7 +33,7 @@ import org.apache.spark.util.Utils * (or risk triggering any other exceptions). See SPARK-19276. */ private[spark] class FetchFailedException( - bmAddress: BlockManagerId, + shuffleLocations: Seq[ShuffleLocation], shuffleId: Int, mapId: Int, reduceId: Int, @@ -42,12 +42,12 @@ private[spark] class FetchFailedException( extends Exception(message, cause) { def this( - bmAddress: BlockManagerId, + shuffleLocations: Seq[ShuffleLocation], shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) { - this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) + this(shuffleLocations, shuffleId, mapId, reduceId, cause.getMessage, cause) } // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code @@ -56,7 +56,11 @@ private[spark] class FetchFailedException( // because the TaskContext is not defined in some test cases. 
Option(TaskContext.get()).map(_.setFetchFailed(this)) - def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, + def toTaskFailedReason: TaskFailedReason = FetchFailed( + shuffleLocations, + shuffleId, + mapId, + reduceId, Utils.exceptionString(this)) } @@ -67,4 +71,4 @@ private[spark] class MetadataFetchFailedException( shuffleId: Int, reduceId: Int, message: String) - extends FetchFailedException(null, shuffleId, -1, reduceId, message) + extends FetchFailedException(Seq.empty, shuffleId, -1, reduceId, message) diff --git a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala index 9b9b8508e88a..7a79469497d8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala @@ -95,8 +95,9 @@ private class ShuffleBlockFetcherIterable( blockManager, mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, minReduceId, maxReduceId + 1) .map { shuffleLocationInfo => - val defaultShuffleLocation = shuffleLocationInfo._1 - .get.asInstanceOf[DefaultMapShuffleLocations] + // there should be only one copy of the shuffle data in the default implementation + val defaultShuffleLocation = shuffleLocationInfo._1(0) + .asInstanceOf[DefaultMapShuffleLocations] (defaultShuffleLocation.getBlockManagerId, shuffleLocationInfo._2) }, serializerManager.wrapStream, diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 287ffdd6e10e..5ef25e438158 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} @@ -31,6 +32,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -579,7 +581,9 @@ final class ShuffleBlockFetcherIterator( private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + throw new FetchFailedException( + DefaultMapShuffleLocations.get(address).getLocationsForBlock(reduceId).asScala, + shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index d84dd5800ebb..268f7d284823 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ 
b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -30,11 +30,13 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleLocation import org.apache.spark.executor._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage._ /** @@ -407,9 +409,8 @@ private[spark] object JsonProtocol { val reason = Utils.getFormattedClassName(taskEndReason) val json: JObject = taskEndReason match { case fetchFailed: FetchFailed => - val blockManagerAddress = Option(fetchFailed.bmAddress). - map(blockManagerIdToJson).getOrElse(JNothing) - ("Block Manager Address" -> blockManagerAddress) ~ + val blockManagerAddress = shuffleLocationsToJson(fetchFailed.shuffleLocation) + ("Shuffle Locations" -> blockManagerAddress) ~ ("Shuffle ID" -> fetchFailed.shuffleId) ~ ("Map ID" -> fetchFailed.mapId) ~ ("Reduce ID" -> fetchFailed.reduceId) ~ @@ -439,6 +440,24 @@ private[spark] object JsonProtocol { ("Reason" -> reason) ~ json } + def shuffleLocationsToJson(shuffleLocations: Seq[ShuffleLocation]): JValue = { + if (shuffleLocations != null && shuffleLocations.nonEmpty) { + if (shuffleLocations.head.isInstanceOf[DefaultMapShuffleLocations]) { + val array = JArray(shuffleLocations.map(location => { + val blockManagerId = location.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId + blockManagerIdToJson(blockManagerId) + }).toList) + ("type" -> "Default") ~ + ("data" -> array) + } else { + ("type" -> "Custom") ~ + ("data" -> mapper.writeValueAsString(shuffleLocations)) + } + } else { + "type" -> "None" + } + } + def blockManagerIdToJson(blockManagerId: BlockManagerId): JValue = { ("Executor ID" -> blockManagerId.executorId) ~ ("Host" -> blockManagerId.host) ~ @@ -948,12 +967,13 @@ private[spark] object JsonProtocol { case `success` => Success case `resubmitted` => Resubmitted case `fetchFailed` => - val blockManagerAddress = blockManagerIdFromJson(json \ "Block Manager Address") + val locations = shuffleLocationsFromJson( + (json \ "Shuffle Locations")) val shuffleId = (json \ "Shuffle ID").extract[Int] val mapId = (json \ "Map ID").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] val message = jsonOption(json \ "Message").map(_.extract[String]) - new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, + new FetchFailed(locations, shuffleId, mapId, reduceId, message.getOrElse("Unknown reason")) case `exceptionFailure` => val className = (json \ "Class Name").extract[String] @@ -996,6 +1016,26 @@ private[spark] object JsonProtocol { } } + def shuffleLocationsFromJson(json: JValue): Seq[ShuffleLocation] = { + val shuffleType = (json \ "type").extract[String] + if (shuffleType == "Default") { + (json \ "data").children.map(value => { + val block = blockManagerIdFromJson(value) + DefaultMapShuffleLocations.get(block) + }) + } else { + Seq.empty + } + } + + def shuffleLocationsFromString(string: String): Option[Array[ShuffleLocation]] = { + if (string == "None") { + return None + } + Some(mapper.readValue(string, classOf[Array[ShuffleLocation]])) + } + + def blockManagerIdFromJson(json: JValue): BlockManagerId = { // On metadata fetch fail, block manager ID can be null (SPARK-4471) if (json == JNothing) { diff --git 
a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 28cbeeda7a88..879bbdf893d2 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} @@ -138,11 +139,12 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2L if (isFirstStageAttempt) { throw new FetchFailedException( - SparkEnv.get.blockManager.blockManagerId, - sid, - taskContext.partitionId(), - taskContext.partitionId(), - "simulated fetch failure") + shuffleLocations = + Array(DefaultMapShuffleLocations.get(SparkEnv.get.blockManager.blockManagerId)), + shuffleId = sid, + mapId = taskContext.partitionId(), + reduceId = taskContext.partitionId(), + message = "simulated fetch failure") } else { iter } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 8fcbc845d1a7..5191dacf304a 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -69,13 +69,12 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L))) val statuses = tracker.getMapSizesByShuffleLocation(10, 0) - assert(statuses.toSet === - Seq( - (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), - (Some(DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000))), - ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) - .toSet) + assert(statuses.toSet === Seq( + (Seq(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (Seq(DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000))), + ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) + .toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() rpcEnv.shutdown() @@ -119,8 +118,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + tracker.unregisterMapOutput(10, 0, + Array(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)))) + tracker.unregisterMapOutput(10, 0, + Array(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)))) // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the @@ -155,12 +156,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByShuffleLocation(10, 0).toSeq === Seq( - (Some(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), + 
(Seq(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000))), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch - masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) + masterTracker.unregisterMapOutput(10, 0, + Array(DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)))) assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } @@ -325,7 +327,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByShuffleLocation(10, 0, 4) - .map(x => (x._1.get, x._2)).toSeq === + .map(x => (x._1(0), x._2)).toSeq === Seq( (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))), diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 7a16f7b715e6..98371930710c 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.UI._ import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.{ThreadUtils, Utils} @@ -689,8 +690,10 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu if (context.stageAttemptNumber == 0) { if (context.partitionId == 0) { // Make the first task in the first stage attempt fail. 
- throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 0, 0, 0, - new java.io.IOException("fake")) + throw new FetchFailedException( + shuffleLocations = + Seq(DefaultMapShuffleLocations.get(SparkEnv.get.blockManager.blockManagerId)), + 0, 0, 0, new java.io.IOException("fake")) } else { // Make the second task in the first stage attempt sleep to generate a zombie task Thread.sleep(60000) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 558cd3626ab9..511966a8bded 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -49,6 +49,7 @@ import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcTimeout} import org.apache.spark.scheduler.{FakeTask, ResultTask, Task, TaskDescription} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, BlockManagerId} import org.apache.spark.util.{LongAccumulator, UninterruptibleThread} @@ -438,7 +439,7 @@ class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { override def hasNext: Boolean = true override def next(): Int = { throw new FetchFailedException( - bmAddress = BlockManagerId("1", "hostA", 1234), + Seq(DefaultMapShuffleLocations.get(BlockManagerId("1", "hostA", 1234))), shuffleId = 0, mapId = 0, reduceId = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerAsyncSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerAsyncSuite.scala new file mode 100644 index 000000000000..7a2a4a5205cf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerAsyncSuite.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.util + +import scala.collection.JavaConverters._ +import scala.collection.mutable.Buffer +import scala.collection.mutable.Map + +import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} +import org.apache.spark.api.java.Optional +import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleLocation} +import org.apache.spark.internal.config +import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO +import org.apache.spark.storage.BlockManagerId + +class AsyncShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { + val defaultShuffleDataIO = new DefaultShuffleDataIO(sparkConf) + + override def driver(): ShuffleDriverComponents = + new AsyncShuffleDriverComponents(defaultShuffleDataIO.driver()) + + override def executor(): ShuffleExecutorComponents = defaultShuffleDataIO.executor() +} + +class AsyncShuffleDriverComponents(default: ShuffleDriverComponents) + extends ShuffleDriverComponents { + override def initializeApplication(): util.Map[String, String] = default.initializeApplication() + + override def cleanupApplication(): Unit = default.cleanupApplication() + + override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = + default.removeShuffleData(shuffleId, blocking) + + override def unregisterOtherMapStatusesOnFetchFailure(): Boolean = false +} + +class DAGSchedulerAsyncSuite extends DAGSchedulerSuite { + + class AsyncShuffleLocation(hostname: String, portInt: Int, exec: String) extends ShuffleLocation { + override def host(): String = hostname + override def port(): Int = portInt + override def execId(): Optional[String] = Optional.of(exec) + } + + class AsyncMapShuffleLocations(asyncLocation: AsyncShuffleLocation) + extends MapShuffleLocations { + var locations : Buffer[ShuffleLocation] = Buffer(asyncLocation) + + override def getLocationsForBlock(reduceId: Int): util.List[ShuffleLocation] = + locations.asJava + + override def invalidateShuffleLocation(host: String, port: Optional[Integer]): Boolean = { + removeIfPredicate(loc => + loc.host() != host || (port.isPresent && loc.port() != port.get())) + } + + override def invalidateShuffleLocation(executorId: String): Boolean = { + removeIfPredicate(loc => !loc.execId().isPresent || !loc.execId().get().equals(executorId)) + } + + def removeIfPredicate(predicate: ShuffleLocation => Boolean): Boolean = { + var missingPartition = false + if (locations.isEmpty) { + return missingPartition + } + locations = locations.filter(predicate) + if (locations.isEmpty) { + missingPartition = true + } + missingPartition + } + } + + private def setupTest(): (RDD[_], Int) = { + afterEach() + val conf = new SparkConf() + // unregistering all outputs on a host is disabled for async case + conf.set(config.SHUFFLE_IO_PLUGIN_CLASS, classOf[AsyncShuffleDataIO].getName) + init(conf) + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + (reduceRdd, shuffleId) + } + + test("Test async simple shuffle success") { + val (reduceRdd, shuffleId) = setupTest() + submit(reduceRdd, Array(0, 1)) + + // Perform map task + val mapStatus1 = makeAsyncMapStatus("hostA") + val mapStatus2 = makeAsyncMapStatus("hostB") + complete(taskSets(0), Seq((Success, mapStatus1), 
(Success, mapStatus2))) + assertMapOutputTrackerContains(shuffleId, Seq( + mapStatus1.mapShuffleLocations, mapStatus2.mapShuffleLocations)) + + // perform reduce task + complete(taskSets(1), Seq((Success, 42), (Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty() + } + + test("test async fetch failed - different hosts") { + val (reduceRdd, shuffleId) = setupTest() + submit(reduceRdd, Array(0, 1)) + + // Perform map task + val mapStatus1 = makeAsyncMapStatus("hostA") + val mapStatus2 = makeAsyncMapStatus("hostB") + complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) + + // The 2nd ResultTask reduce task failed. This will remove that shuffle location, + // but the other shuffle block is still available + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(Seq(makeAsyncShuffleLocation("hostA")), + shuffleId, 0, 0, "ignored"), null))) + assert(scheduler.failedStages.size > 0) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 1) + assertMapOutputTrackerContains(shuffleId, Seq(null, mapStatus2.mapShuffleLocations)) + + // submit the mapper once more + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, mapStatus1))) + assertMapOutputTrackerContains(shuffleId, + Seq(mapStatus1.mapShuffleLocations, mapStatus2.mapShuffleLocations)) + + // submit last reduce task + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty() + } + + test("test async fetch failed - same host, same exec") { + val (reduceRdd, shuffleId) = setupTest() + submit(reduceRdd, Array(0, 1)) + + val mapStatus = makeAsyncMapStatus("hostA") + complete(taskSets(0), Seq((Success, mapStatus), (Success, mapStatus))) + + // the 2nd ResultTask failed. This removes the first executor, but the + // other task is still intact because it was uploaded to the remove dfs + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(Seq(makeAsyncShuffleLocation("hostA")), + shuffleId, 0, 0, "ignored"), null))) + assert(scheduler.failedStages.size > 0) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 1) + assertMapOutputTrackerContains(shuffleId, Seq(null, new AsyncMapShuffleLocations(null))) + + // submit both mappers once more + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, mapStatus))) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 2) + assertMapOutputTrackerContains(shuffleId, Seq( + mapStatus.mapShuffleLocations, mapStatus.mapShuffleLocations)) + + // submit last reduce task + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty() + } + + test("test async fetch failed - same host but different execId") { + val (reduceRdd, shuffleId) = setupTest() + submit(reduceRdd, Array(0, 1)) + + val mapStatus1 = makeAsyncMapStatus("hostA") + val mapStatus2 = MapStatus( + BlockManagerId("other-exec", "hostA", 1234), + new AsyncMapShuffleLocations(new AsyncShuffleLocation("hostA", 1234, "other-exec")), + Array.fill[Long](2)(2) + ) + complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) + + // the 2nd ResultTask failed. 
This only removes the first shuffle location because + // the second location was written by a different executor + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(Seq(makeAsyncShuffleLocation("hostA")), + shuffleId, 0, 0, "ignored"), null))) + assert(scheduler.failedStages.size > 0) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 1) + assertMapOutputTrackerContains(shuffleId, Seq(null, mapStatus2.mapShuffleLocations)) + + // submit the one mapper again + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, mapStatus1))) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 2) + assertMapOutputTrackerContains(shuffleId, + Seq(mapStatus1.mapShuffleLocations, mapStatus2.mapShuffleLocations)) + + // submit last reduce task + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty() + } + + def assertMapOutputTrackerContains( + shuffleId: Int, + set: Seq[MapShuffleLocations]): Unit = { + val actualShuffleLocations = mapOutputTracker.shuffleStatuses(shuffleId).mapStatuses + .map(mapStatus => { + if (mapStatus == null) { + return null + } + mapStatus.mapShuffleLocations + }) + assert(set === actualShuffleLocations.toSeq) + } + + def makeAsyncMapStatus( + host: String, + reduces: Int = 2, + execId: Optional[String] = Optional.empty(), + sizes: Byte = 2): MapStatus = { + MapStatus(makeBlockManagerId(host), + new AsyncMapShuffleLocations(makeAsyncShuffleLocation(host)), + Array.fill[Long](reduces)(sizes)) + } + + def makeAsyncShuffleLocation(host: String): AsyncShuffleLocation = { + new AsyncShuffleLocation(host, 12345, "exec-" + host) + } + + def makeBlockManagerId(host: String): BlockManagerId = + BlockManagerId("exec-" + host, host, 12345) + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerFileServerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerFileServerSuite.scala new file mode 100644 index 000000000000..e58c939db873 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerFileServerSuite.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.util + +import scala.collection.JavaConverters._ +import scala.collection.mutable.Buffer +import scala.collection.mutable.Map +import scala.collection.mutable.Seq + +import org.apache.spark.{FetchFailed, HashPartitioner, ShuffleDependency, SparkConf, Success} +import org.apache.spark.api.java.Optional +import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleLocation} +import org.apache.spark.internal.config +import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO +import org.apache.spark.storage.BlockManagerId + +class FileServerShuffleDriverComponents(driverComponents: ShuffleDriverComponents) + extends ShuffleDriverComponents { + + override def initializeApplication(): util.Map[String, String] = + driverComponents.initializeApplication() + + override def cleanupApplication(): Unit = driverComponents.cleanupApplication() + + override def removeShuffleData(shuffleId: Int, blocking: Boolean): Unit = + driverComponents.removeShuffleData(shuffleId, blocking) + + override def unregisterOtherMapStatusesOnFetchFailure(): Boolean = true +} + +class FileServerShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { + val defaultShuffleDataIO = new DefaultShuffleDataIO(sparkConf) + override def driver(): ShuffleDriverComponents = + new FileServerShuffleDriverComponents(defaultShuffleDataIO.driver()) + + override def executor(): ShuffleExecutorComponents = defaultShuffleDataIO.executor() +} + +object FileServerShuffleDataIO { + def apply(sparkConf: SparkConf): Unit = { + new FileServerShuffleDataIO(sparkConf) + } +} + +class DAGSchedulerFileServerSuite extends DAGSchedulerSuite { + + class FileServerShuffleLocation(hostname: String, portInt: Int) extends ShuffleLocation { + override def host(): String = hostname + override def port(): Int = portInt + } + + class FileServerMapShuffleLocations(mapShuffleLocationsInput: Buffer[Buffer[ShuffleLocation]]) + extends MapShuffleLocations { + val mapShuffleLocations = mapShuffleLocationsInput + override def getLocationsForBlock(reduceId: Int): util.List[ShuffleLocation] = + mapShuffleLocations(reduceId).asJava + + override def invalidateShuffleLocation(host: String, port: Optional[Integer]): Boolean = { + var missingPartition = false + for ((locations, i) <- mapShuffleLocations.zipWithIndex) { + mapShuffleLocations(i) = locations.filter(loc => + loc.host() != host || (port.isPresent && loc.port() != port.get())) + if (mapShuffleLocations(i).isEmpty) { + missingPartition = true + } + } + missingPartition + } + + override def invalidateShuffleLocation(executorId: String): Boolean = false + } + + private def setupTest(): (RDD[_], Int) = { + afterEach() + val conf = new SparkConf() + // unregistering all outputs on a host is enabled for the individual file server case + conf.set(config.SHUFFLE_IO_PLUGIN_CLASS, classOf[FileServerShuffleDataIO].getName) + init(conf) + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + (reduceRdd, shuffleId) + } + + test("Test simple shuffle success") { + val (reduceRdd, shuffleId) = setupTest() + submit(reduceRdd, Array(0, 1)) + + // Perform map task + val mapStatus1 = makeFileServerMapStatus("hostA", "hostB", "hostC", "hostD") + val mapStatus2 = 
makeFileServerMapStatus("hostA", "hostB", "hostC", "hostE") + complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) + assertMapOutputTrackerContains(shuffleId, + Seq(mapStatus1.mapShuffleLocations, mapStatus2.mapShuffleLocations)) + + // perform reduce task + complete(taskSets(1), Seq((Success, 42), (Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty() + } + + test("Test one mapper failed but another has blocks elsewhere") { + val (reduceRdd, shuffleId) = setupTest() + submit(reduceRdd, Array(0, 1)) + + // Perform map task + val mapStatus1 = makeFileServerMapStatus("hostA", "hostB", "hostC", "hostD") + val mapStatus2 = makeFileServerMapStatus("hostA", "hostE", "hostB", "hostE") + complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) + assertMapOutputTrackerContains(shuffleId, + Seq(mapStatus1.mapShuffleLocations, mapStatus2.mapShuffleLocations)) + + // perform reduce task + complete(taskSets(1), Seq((Success, 42), (FetchFailed( + Seq( + shuffleLocation("hostA"), + shuffleLocation("hostB")), + shuffleId, 0, 0, "ignored"), null))) + assert(scheduler.failedStages.size > 0) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 1) + assertMapOutputTrackerContains(shuffleId, Seq(null, + new FileServerMapShuffleLocations(Buffer( + shuffleLocationSeq("hostE"), + shuffleLocationSeq("hostE"))))) + + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, mapStatus1))) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 2) + assertMapOutputTrackerContains(shuffleId, Seq( + mapStatus1.mapShuffleLocations, mapStatus2.mapShuffleLocations)) + + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty() + } + + test("Test one failed but other only has one partition replicated") { + val (reduceRdd, shuffleId) = setupTest() + submit(reduceRdd, Array(0, 1)) + + // Perform map task + val mapStatus1 = makeFileServerMapStatus("hostA", "hostB", "hostC", "hostD") + val mapStatus2 = makeFileServerMapStatus("hostA", "hostB", "hostC", "hostE") + complete(taskSets(0), Seq((Success, mapStatus1), (Success, mapStatus2))) + assertMapOutputTrackerContains(shuffleId, + Seq(mapStatus1.mapShuffleLocations, mapStatus2.mapShuffleLocations)) + + // perform reduce task + complete(taskSets(1), Seq((Success, 42), (FetchFailed( + Seq( + shuffleLocation("hostA"), + shuffleLocation("hostB"), + shuffleLocation("hostC"), + shuffleLocation("hostD")), + shuffleId, 0, 0, "ignored"), null))) + assert(scheduler.failedStages.size > 0) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 0) + assertMapOutputTrackerContains(shuffleId, Seq(null, null)) + + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, mapStatus1), (Success, mapStatus2))) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleId) == 2) + assertMapOutputTrackerContains(shuffleId, Seq( + mapStatus1.mapShuffleLocations, mapStatus2.mapShuffleLocations)) + + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty() + } + + + def assertMapOutputTrackerContains( + shuffleId: Int, + set: Seq[MapShuffleLocations]): Unit = { + val actualShuffleLocations = mapOutputTracker.shuffleStatuses(shuffleId).mapStatuses + .map(mapStatus => { + if (mapStatus == null) { + return null + } + mapStatus.mapShuffleLocations + }) + assert(set === actualShuffleLocations.toSeq) + } + + def makeFileServerMapStatus( + partition1Primary: String, + 
partition1Secondary: String, + partition2Primary: String, + partition2Secondary: String): MapStatus = { + val partition1List: Buffer[ShuffleLocation] = Buffer( + new FileServerShuffleLocation(partition1Primary, 1234), + new FileServerShuffleLocation(partition1Secondary, 1234)) + val partition2List: Buffer[ShuffleLocation] = Buffer( + new FileServerShuffleLocation(partition2Primary, 1234), + new FileServerShuffleLocation(partition2Secondary, 1234)) + makeFileServerMapStatus(partition1List, partition2List) + } + + def makeFileServerMapStatus(partition1Loc: Buffer[ShuffleLocation], + partition2Loc: Buffer[ShuffleLocation]): MapStatus = { + MapStatus( + makeBlockManagerId("executor-host"), + new FileServerMapShuffleLocations(Buffer(partition1Loc, partition2Loc)), + Array.fill[Long](2)(2) + ) + } + + def shuffleLocationSeq(hosts: String*): Buffer[ShuffleLocation] = { + hosts.map(host => + shuffleLocation(host) + ).toBuffer + } + + def shuffleLocation(host: String): ShuffleLocation = { + new FileServerShuffleLocation(host, 1234) + } + + def makeBlockManagerId(host: String): BlockManagerId = + BlockManagerId("exec-" + host, host, 12345) +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 21b4e56c9e80..ad1bd8d1b1bd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.api.shuffle.MapShuffleLocations +import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation} import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config @@ -238,7 +238,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi init(new SparkConf()) } - private def init(testConf: SparkConf): Unit = { + def init(testConf: SparkConf): Unit = { sc = new SparkContext("local[2]", "DAGSchedulerSuite", testConf) sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() @@ -308,7 +308,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi it.next.asInstanceOf[Tuple2[_, _]]._1 /** Send the given CompletionEvent messages for the tasks in the TaskSet. */ - private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { + def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) { assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { @@ -334,7 +334,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } /** Submits a job to the scheduler and returns the job id. 
*/ - private def submit( + def submit( rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, @@ -430,8 +430,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // reset the test context with the right shuffle service config afterEach() val conf = new SparkConf() - conf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") conf.set("spark.files.fetchFailure.unRegisterOutputOnHost", "true") + conf.set(config.SHUFFLE_SERVICE_ENABLED.key, "true") init(conf) runEvent(ExecutorAdded("exec-hostA1", "hostA")) runEvent(ExecutorAdded("exec-hostA2", "hostA")) @@ -475,8 +475,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // reduce stage fails with a fetch failure from one host complete(taskSets(2), Seq( - (FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345), firstShuffleId, 0, 0, "ignored"), - null) + (FetchFailed(makeShuffleLocationSeq("exec-hostA2", "hostA", 12345), + firstShuffleId, 0, 0, "ignored"), null) )) // Here is the main assertion -- make sure that we de-register @@ -703,7 +703,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) + HashSet(makeShuffleLocationSeq("hostA"), makeShuffleLocationSeq("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -721,7 +721,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored"), null))) // this will get called // blockManagerMaster.removeExecutor("exec-hostA") // ask the scheduler to try it again @@ -731,7 +731,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // we can see both result blocks now assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1(0).host) .toSet === HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) @@ -774,7 +774,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } else { assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) + HashSet(makeShuffleLocationSeq("hostA"), makeShuffleLocationSeq("hostB"))) } } } @@ -840,7 +840,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val stageAttempt = taskSets.last checkStageId(stageId, attemptIdx, stageAttempt) complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) => - (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, idx, "ignored"), null) + (FetchFailed(makeShuffleLocationSeq("hostA"), + shuffleDep.shuffleId, 0, idx, "ignored"), null) }.toSeq) } @@ -1069,13 +1070,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The MapOutputTracker should know about both map output locations. 
assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1(0).asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored"), null)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) @@ -1083,7 +1084,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The second ResultTask fails, with a fetch failure for the output from the second mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 1, 1, "ignored"), null)) // The SparkListener should not receive redundant failure events. sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -1104,7 +1105,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored"), null)) assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) @@ -1201,17 +1202,17 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 0) - .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1(0).asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) assert(mapOutputTracker .getMapSizesByShuffleLocation(shuffleId, 1) - .map(_._1.get.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .map(_._1(0).asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored"), null)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(sparkListener.failedStages.contains(1)) @@ -1226,7 +1227,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The second ResultTask fails, with a fetch failure for the output from the second mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(1), - FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostB"), shuffleId, 1, 1, "ignored"), null)) // Another ResubmitFailedStages event should not result in another attempt for the map @@ -1275,7 +1276,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first result task fails, with a fetch failure for the output from the first mapper. 
runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored"), null)) // Trigger resubmission of the failed map stage and finish the re-started map task. @@ -1291,7 +1292,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // A late FetchFailed arrives from the second task in the original reduce stage. runEvent(makeCompletionEvent( taskSets(1).tasks(1), - FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostB"), shuffleId, 1, 1, "ignored"), null)) // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because @@ -1397,7 +1398,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostA"))) + HashSet(makeShuffleLocationSeq("hostB"), makeShuffleLocationSeq("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1491,7 +1492,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), + FetchFailed(Seq.empty, firstShuffleId, 2, 0, "Fetch failed"), null)) // so we resubmit stage 0, which completes happily @@ -1751,7 +1752,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // lets say there is a fetch failure in this task set, which makes us go back and // run stage 0, attempt 1 complete(taskSets(1), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeShuffleLocationSeq("hostA"), + shuffleDep1.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() // stage 0, attempt 1 should have the properties of job2 @@ -1803,7 +1805,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) + HashSet(makeShuffleLocationSeq("hostC"), makeShuffleLocationSeq("hostB"))) // Make sure that the reduce stage was now submitted. 
assert(taskSets.size === 3) @@ -1832,7 +1834,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeShuffleLocationSeq("hostA"), + shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again @@ -1863,7 +1866,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostB", 1)))) // pretend stage 2 failed because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeShuffleLocationSeq("hostA"), + shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. @@ -2066,7 +2070,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostA"))) + HashSet(makeShuffleLocationSeq("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -2112,7 +2116,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostA"))) + HashSet(makeShuffleLocationSeq("hostA"))) // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) @@ -2224,7 +2228,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0, 1)) complete(taskSets(1), Seq( (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored"), null))) // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch // from, then TaskSet 3 will run the reduce stage scheduler.resubmitFailedStages() @@ -2276,14 +2280,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) + HashSet(makeShuffleLocationSeq("hostA"), makeShuffleLocationSeq("hostB"))) assert(listener1.results.size === 1) // When attempting the second stage, show a fetch failure assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( (Success, makeMapStatus("hostA", rdd2.partitions.length)), - (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeShuffleLocationSeq("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() assert(listener2.results.size === 0) // Second stage listener should not have a result yet @@ -2292,7 +2296,7 
@@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) + HashSet(makeShuffleLocationSeq("hostC"), makeShuffleLocationSeq("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -2302,7 +2306,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostB", rdd2.partitions.length)), (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep2.shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostB"), makeMaybeShuffleLocation("hostD"))) + HashSet(makeShuffleLocationSeq("hostB"), makeShuffleLocationSeq("hostD"))) assert(listener2.results.size === 1) // Finally, the reduce job should be running as task set 4; make it see a fetch failure, @@ -2310,7 +2314,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(4).stageId === 2) complete(taskSets(4), Seq( (Success, 52), - (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeShuffleLocationSeq("hostD"), dep2.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 @@ -2341,14 +2345,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeMaybeShuffleLocation("hostA"), makeMaybeShuffleLocation("hostB"))) + HashSet(makeShuffleLocationSeq("hostA"), makeShuffleLocationSeq("hostB"))) assert(listener1.results.size === 1) // When attempting stage1, trigger a fetch failure. assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)), - (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeShuffleLocationSeq("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() // Stage1 listener should not have a result yet assert(listener2.results.size === 0) @@ -2367,7 +2371,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === - Set(makeMaybeShuffleLocation("hostC"), makeMaybeShuffleLocation("hostB"))) + Set(makeShuffleLocationSeq("hostC"), makeShuffleLocationSeq("hostB"))) // After stage0 is finished, stage1 will be submitted and found there is no missing // partitions in it. Then listener got triggered. 
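The hunks above and below repeat one mechanical change: FetchFailed (and FetchFailedException) no longer carries a single BlockManagerId but a Seq[ShuffleLocation] listing the locations the missing block could have been fetched from. A minimal sketch of that construction pattern follows, assuming the patched classes from this change set; the local helper locationsFor is hypothetical and simply mirrors the makeShuffleLocationSeq overloads added to the DAGSchedulerSuite companion object further down in this diff.

import org.apache.spark.FetchFailed
import org.apache.spark.api.shuffle.ShuffleLocation
import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.BlockManagerId

// Hypothetical helper: wrap a single block-manager-backed location for a host,
// the same thing makeShuffleLocationSeq("hostA") does in the test suite.
def locationsFor(host: String): Seq[ShuffleLocation] =
  Seq(DefaultMapShuffleLocations.get(BlockManagerId("exec-" + host, host, 12345)))

// Old form: FetchFailed(makeBlockManagerId("hostA"), shuffleId, mapId, reduceId, "ignored")
// New form: the failure reports candidate locations, so the scheduler can keep map
// outputs whose remaining replicas are still reachable.
val failure = FetchFailed(locationsFor("hostA"), 0, 0, 0, "ignored")

An empty sequence (Seq.empty) takes the place of the old null address when the failing location is unknown, as in the TaskContextSuite and TaskSchedulerImplSuite hunks below.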
@@ -2483,7 +2487,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi rdd1.map { case (x, _) if (x == 1) => throw new FetchFailedException( - BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + makeShuffleLocationSeq("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") case (x, _) => x }.count() } @@ -2496,7 +2500,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi rdd1.map { case (x, _) if (x == 1) && FailThisAttempt._fail.getAndSet(false) => throw new FetchFailedException( - BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + makeShuffleLocationSeq("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") } } @@ -2550,7 +2554,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(1).stageId === 1 && taskSets(1).stageAttemptId === 0) runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleIdA, 0, 0, + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleIdA, 0, 0, "Fetch failure of task: stageId=1, stageAttempt=0, partitionId=0"), result = null)) @@ -2738,7 +2742,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first task of the final stage failed with fetch failure runEvent(makeCompletionEvent( taskSets(2).tasks(0), - FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostC"), shuffleId2, 0, 0, "ignored"), null)) val failedStages = scheduler.failedStages.toSeq @@ -2757,7 +2761,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first task of the `shuffleMapRdd2` failed with fetch failure runEvent(makeCompletionEvent( taskSets(3).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId1, 0, 0, "ignored"), null)) // The job should fail because Spark can't rollback the shuffle map stage. @@ -2782,7 +2786,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // Fail the second task with FetchFailed. runEvent(makeCompletionEvent( taskSets.last.tasks(1), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored"), null)) // The job should fail because Spark can't rollback the result stage. @@ -2825,7 +2829,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // Fail the second task with FetchFailed. 
runEvent(makeCompletionEvent( taskSets.last.tasks(1), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored"), null)) assert(failure == null, "job should not fail") @@ -2871,7 +2875,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } - private def assertDataStructuresEmpty(): Unit = { + def assertDataStructuresEmpty(): Unit = { assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) assert(scheduler.jobIdToActiveJob.isEmpty) @@ -2924,8 +2928,12 @@ object DAGSchedulerSuite { DefaultMapShuffleLocations.get(makeBlockManagerId(host)) } - def makeMaybeShuffleLocation(host: String): Option[MapShuffleLocations] = { - Some(DefaultMapShuffleLocations.get(makeBlockManagerId(host))) + def makeShuffleLocationSeq(execId: String, host: String, port: Int): Seq[ShuffleLocation] = { + Seq(DefaultMapShuffleLocations.get(makeBlockManagerId(host))) + } + + def makeShuffleLocationSeq(host: String): Seq[ShuffleLocation] = { + Seq(DefaultMapShuffleLocations.get(makeBlockManagerId(host))) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index a560013dba96..a5c3b21e3e61 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark._ import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -257,8 +258,10 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { .reduceByKey { case (_, _) => val ctx = TaskContext.get() if (ctx.stageAttemptNumber() == 0) { - throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1, - new Exception("Failure for test.")) + throw new FetchFailedException( + shuffleLocations = + Seq(DefaultMapShuffleLocations.get(SparkEnv.get.blockManager.blockManagerId)), + shuffleId = 1, mapId = 1, reduceId = 1, cause = new Exception("Failure for test.")) } else { ctx.stageId() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 83305a96e679..d60004b9c1f9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -627,7 +627,7 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor backend.taskSuccess(taskDescription, DAGSchedulerSuite.makeMapStatus("hostA", 10)) case (1, 0, 0) => val fetchFailed = FetchFailed( - DAGSchedulerSuite.makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored") + DAGSchedulerSuite.makeShuffleLocationSeq("hostA"), shuffleId, 0, 0, "ignored") backend.taskFailed(taskDescription, fetchFailed) case (1, _, partition) => backend.taskSuccess(taskDescription, 42 + partition) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 
27369759fad5..29dfb8d43715 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -176,7 +176,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark if (stageAttemptNumber < 2) { // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception // will only trigger task resubmission in the same stage. - throw new FetchFailedException(null, 0, 0, 0, "Fake") + throw new FetchFailedException(Seq.empty, 0, 0, 0, "Fake") } Seq(stageAttemptNumber).iterator }.collect() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 1a81f556e061..3a3cf61657c7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1137,7 +1137,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(tsm.runningTasks === 10) // fail attempt tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, - FetchFailed(null, 0, 0, 0, "fetch failed")) + FetchFailed(Seq.empty, 0, 0, 0, "fetch failed")) // the attempt is a zombie, but the tasks are still running (this could be true even if // we actively killed those tasks, as killing is best-effort) assert(tsm.isZombie) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 60acd3ed4cd4..830a6a77ab39 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AccumulatorV2, ManualClock, Utils} @@ -1246,7 +1247,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // now fail those tasks tsmSpy.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, - FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + FetchFailed(Seq(DefaultMapShuffleLocations.get( + BlockManagerId(taskDescs(0).executorId, "host1", 12345))), 0, 0, 0, "ignored")) tsmSpy.handleFailedTask(taskDescs(1).taskId, TaskState.FAILED, ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, @@ -1286,7 +1288,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Fail the task with fetch failure tsm.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, - FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + FetchFailed( + Seq(DefaultMapShuffleLocations.get( + BlockManagerId(taskDescs(0).executorId, "host1", 12345))), 0, 0, 0, "ignored")) assert(blacklistTracker.isNodeBlacklisted("host1")) } @@ -1393,7 +1397,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // fail fetch taskSetManager1.handleFailedTask( taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED, - FetchFailed(null, 0, 
0, 0, "fetch failed")) + FetchFailed(Seq.empty, 0, 0, 0, "fetch failed")) assert(taskSetManager1.isZombie) assert(taskSetManager1.runningTasks === 9) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 6468914bf318..7f4ebaaddd87 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -116,12 +116,13 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext (shuffleBlockId, byteOutputStream.size().toLong) } val blocksToRetrieve = Seq( - (Option.apply(DefaultMapShuffleLocations.get(localBlockManagerId)), shuffleBlockIdsAndSizes)) + (Seq(DefaultMapShuffleLocations.get(localBlockManagerId).asInstanceOf[ShuffleLocation]), + shuffleBlockIdsAndSizes)) val mapOutputTracker = mock(classOf[MapOutputTracker]) when(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)) - .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { + .thenAnswer(new Answer[Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])]] { def answer(invocationOnMock: InvocationOnMock): - Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = { blocksToRetrieve.iterator } }) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala index dbb954945a8b..e8372c045860 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleDriverComponentsSuite.scala @@ -22,8 +22,9 @@ import java.util import com.google.common.collect.ImmutableMap import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleWriteSupport} +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleDriverComponents, ShuffleExecutorComponents, ShuffleReadSupport, ShuffleWriteSupport} import org.apache.spark.internal.config.SHUFFLE_IO_PLUGIN_CLASS +import org.apache.spark.shuffle.io.DefaultShuffleReadSupport import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport class ShuffleDriverComponentsSuite extends SparkFunSuite with LocalSparkContext { @@ -66,6 +67,13 @@ class TestShuffleExecutorComponents(sparkConf: SparkConf) extends ShuffleExecuto override def writes(): ShuffleWriteSupport = { val blockManager = SparkEnv.get.blockManager val blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager) - new DefaultShuffleWriteSupport(sparkConf, blockResolver) + new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId) + } + + override def reads(): ShuffleReadSupport = { + val blockManager = SparkEnv.get.blockManager + val mapOutputTracker = SparkEnv.get.mapOutputTracker + val serializerManager = SparkEnv.get.serializerManager + new DefaultShuffleReadSupport(blockManager, mapOutputTracker, serializerManager, sparkConf) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 4f5bb264170d..32df6633886f 100644 --- 
a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -197,14 +197,14 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { } when(mapOutputTracker.getMapSizesByShuffleLocation(0, 0, 1)) - .thenAnswer(new Answer[Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])]] { + .thenAnswer(new Answer[Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])]] { def answer(invocationOnMock: InvocationOnMock): - Iterator[(Option[ShuffleLocation], Seq[(BlockId, Long)])] = { + Iterator[(Seq[ShuffleLocation], Seq[(BlockId, Long)])] = { val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => val shuffleBlockId = ShuffleBlockId(0, mapId, 0) (shuffleBlockId, dataFileLength) } - Seq((Option.apply(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) + Seq((Seq(DefaultMapShuffleLocations.get(dataBlockId)), shuffleBlockIdsAndSizes)) .toIterator } }) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index b184b74bf3cb..339e7844bd5b 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -43,6 +43,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Status._ import org.apache.spark.internal.config.UI._ import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { @@ -315,7 +316,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val mapId = 0 val reduceId = taskContext.partitionId() val message = "Simulated fetch failure" - throw new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, message) + throw new FetchFailedException( + Seq(DefaultMapShuffleLocations.get(bmAddress)), shuffleId, mapId, reduceId, message) } else { x } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index c3ff379c84ff..1bc27bdf45a4 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage._ class JsonProtocolSuite extends SparkFunSuite { @@ -169,8 +170,9 @@ class JsonProtocolSuite extends SparkFunSuite { testJobResult(jobFailed) // TaskEndReason - val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, - "Some exception") + val fetchFailed = FetchFailed( + Seq(DefaultMapShuffleLocations.get(BlockManagerId("With or", "without you", 15))), + 17, 18, 19, "Some exception") val fetchMetadataFailed = new MetadataFetchFailedException(17, 19, "metadata Fetch failed exception").toTaskFailedReason val exceptionFailure = new ExceptionFailure(exception, Seq.empty[AccumulableInfo]) @@ -286,12 +288,14 @@ class JsonProtocolSuite extends SparkFunSuite { test("FetchFailed backwards 
compatibility") { // FetchFailed in Spark 1.1.0 does not have a "Message" property. - val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, - "ignored") + val fetchFailed = FetchFailed( + Seq(DefaultMapShuffleLocations.get(BlockManagerId("With or", "without you", 15))), + 17, 18, 19, "ignored") val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) .removeField({ _._1 == "Message" }) - val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, - "Unknown reason") + val expectedFetchFailed = FetchFailed( + Seq(DefaultMapShuffleLocations.get(BlockManagerId("With or", "without you", 15))), + 17, 18, 19, "Unknown reason") assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent)) } @@ -713,7 +717,7 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(r1.shuffleId === r2.shuffleId) assert(r1.mapId === r2.mapId) assert(r1.reduceId === r2.reduceId) - assert(r1.bmAddress === r2.bmAddress) + assert(r1.shuffleLocation === r2.shuffleLocation) assert(r1.message === r2.message) case (r1: ExceptionFailure, r2: ExceptionFailure) => assert(r1.className === r2.className)