diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala new file mode 100644 index 0000000000000..5e546c694e8d9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.{Timer, TimerTask} +import java.util.concurrent.ConcurrentHashMap +import java.util.function.{Consumer, Function} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} + +/** + * For each barrier stage attempt, at most one barrier() call can be active at any time, thus + * we can use (stageId, stageAttemptId) to identify the stage attempt that a barrier() call + * comes from. + */ +private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { + override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)" +} + +/** + * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync + * request is generated by `BarrierTaskContext.barrier()`, and identified by + * stageId + stageAttemptId + barrierEpoch. Replies to all the blocking global sync requests once + * all the requests for a group of `barrier()` calls have been received. If the coordinator is + * unable to collect enough global sync requests within a configured time, it fails all the + * requests with a timeout exception. + */ +private[spark] class BarrierCoordinator( + timeoutInSecs: Long, + listenerBus: LiveListenerBus, + override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + + // TODO SPARK-25030: Creating a Timer() in the main class submitted to SparkSubmit makes it + // unable to fetch the result; we shall fix the issue. + private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + + // Listen to StageCompleted events and clear the corresponding ContextBarrierState. + private val listener = new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + val stageInfo = stageCompleted.stageInfo + val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber) + // Clear ContextBarrierState from a finished stage attempt. + cleanupBarrierStage(barrierId) + } + } + + // Record all active stage attempts that make barrier() call(s), and the corresponding internal + // state.
+ private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] + + override def onStart(): Unit = { + super.onStart() + listenerBus.addToStatusQueue(listener) + } + + override def onStop(): Unit = { + try { + states.forEachValue(1, clearStateConsumer) + states.clear() + listenerBus.removeListener(listener) + } finally { + super.onStop() + } + } + + /** + * Provides the current state of a barrier() call. A state is created when a new stage attempt + * sends out a barrier() call, and recycled when the stage completes. + * + * @param barrierId Identifier of the barrier stage that makes a barrier() call. + * @param numTasks Number of tasks of the barrier stage; all barrier() calls from the stage shall + * collect `numTasks` requests to succeed. + */ + private class ContextBarrierState( + val barrierId: ContextBarrierId, + val numTasks: Int) { + + // There may be multiple barrier() calls from a barrier stage attempt; `barrierEpoch` is used + // to identify each barrier() call. It is increased when a barrier() call succeeds, or + // reset when a barrier() call fails due to timeout. + private var barrierEpoch: Int = 0 + + // An array of RpcCallContexts for barrier tasks that are waiting for the reply of a barrier() + // call. + private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + + // A timer task that ensures the current barrier() call can time out. + private var timerTask: TimerTask = null + + // Init a TimerTask for a barrier() call. + private def initTimerTask(): Unit = { + timerTask = new TimerTask { + override def run(): Unit = synchronized { + // Time out the current barrier() call and fail all the sync requests. + requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " + + s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " + + s"$timeoutInSecs second(s)."))) + cleanupBarrierStage(barrierId) + } + } + } + + // Cancel the current active TimerTask and release resources. + private def cancelTimerTask(): Unit = { + if (timerTask != null) { + timerTask.cancel() + timerTask = null + } + } + + // Process a global sync request. The barrier() call succeeds if enough requests are collected + // within the configured time; otherwise all the pending requests fail. + def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + val taskId = request.taskAttemptId + val epoch = request.barrierEpoch + + // Require that the number of tasks is correctly set by the BarrierTaskContext. + require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + + s"${request.numTasks} from Task $taskId, previously it was $numTasks.") + + // Check whether the epoch from the barrier tasks matches the current barrierEpoch. + logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") + if (epoch != barrierEpoch) { + requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " + + s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " + + "properly killed.")) + } else { + // If this is the first sync message received for a barrier() call, start the timer to + // ensure the sync can time out. + if (requesters.isEmpty) { + initTimerTask() + timer.schedule(timerTask, timeoutInSecs * 1000) + } + // Add the requester to the array of RpcCallContexts pending reply.
+ requesters += requester + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + + s"$taskId, current progress: ${requesters.size}/$numTasks.") + if (maybeFinishAllRequesters(requesters, numTasks)) { + // The current barrier() call finished successfully; clean up the ContextBarrierState and + // increase the barrier epoch. + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + + s"tasks, finished successfully.") + barrierEpoch += 1 + requesters.clear() + cancelTimerTask() + } + } + } + + // Finish all the blocking barrier sync requests from a stage attempt successfully if we + // have received all the sync requests. + private def maybeFinishAllRequesters( + requesters: ArrayBuffer[RpcCallContext], + numTasks: Int): Boolean = { + if (requesters.size == numTasks) { + requesters.foreach(_.reply(())) + true + } else { + false + } + } + + // Clean up the internal state of a barrier stage attempt. + def clear(): Unit = synchronized { + // The global sync fails, so the stage is expected to retry with another attempt; all sync + // messages coming from the current stage attempt shall fail. + barrierEpoch = -1 + requesters.clear() + cancelTimerTask() + } + } + + // Clean up the [[ContextBarrierState]] that corresponds to a specific stage attempt. + private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = { + val barrierState = states.remove(barrierId) + if (barrierState != null) { + barrierState.clear() + } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => + // Get or init the ContextBarrierState corresponding to the stage attempt. + val barrierId = ContextBarrierId(stageId, stageAttemptId) + states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] { + override def apply(key: ContextBarrierId): ContextBarrierState = + new ContextBarrierState(key, numTasks) + }) + val barrierState = states.get(barrierId) + + barrierState.handleRequest(context, request) + } + + private val clearStateConsumer = new Consumer[ContextBarrierState] { + override def accept(state: ContextBarrierState) = state.clear() + } +} + +private[spark] sealed trait BarrierCoordinatorMessage extends Serializable + +/** + * A global sync request message from BarrierTaskContext, sent by a `barrier()` call. Each request + * is identified by stageId + stageAttemptId + barrierEpoch. + * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call; a task may make multiple `barrier()` calls.
+ */ +private[spark] case class RequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int) extends BarrierCoordinatorMessage diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index ba303680d1a0f..8e2b15599b674 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,12 +17,17 @@ package org.apache.spark -import java.util.Properties +import java.util.{Properties, Timer, TimerTask} + +import scala.concurrent.duration._ +import scala.language.postfixOps import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} +import org.apache.spark.util.{RpcUtils, Utils} /** A [[TaskContext]] with extra info and tooling for a barrier stage. */ class BarrierTaskContext( @@ -39,6 +44,22 @@ class BarrierTaskContext( extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, taskMemoryManager, localProperties, metricsSystem, taskMetrics) { + // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls. + private val barrierCoordinator: RpcEndpointRef = { + val env = SparkEnv.get + RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) + } + + private val timer = new Timer("Barrier task timer for barrier() calls.") + + // Local barrierEpoch that identify a barrier() call from current task, it shall be identical + // with the driver side epoch. + private var barrierEpoch = 0 + + // Number of tasks of the current barrier stage, a barrier() call must collect enough requests + // from different tasks within the same barrier stage attempt to succeed. + private lazy val numTasks = getTaskInfos().size + /** * :: Experimental :: * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to @@ -80,7 +101,44 @@ class BarrierTaskContext( @Experimental @Since("2.4.0") def barrier(): Unit = { - // TODO SPARK-24817 implement global barrier. + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + + s"the global sync, current barrier epoch is $barrierEpoch.") + logTrace("Current callSite: " + Utils.getCallSite()) + + val startTime = System.currentTimeMillis() + val timerTask = new TimerTask { + override def run(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " + + s"under the global sync since $startTime, has been waiting for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + } + } + // Log the update of global sync every 60 seconds. + timer.schedule(timerTask, 60000, 60000) + + try { + barrierCoordinator.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch), + // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by + // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. 
+ timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout")) + barrierEpoch += 1 + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + + "global sync successfully, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " + + s"$barrierEpoch.") + } catch { + case e: SparkException => + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " + + "to perform global sync, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + throw e + } finally { + timerTask.cancel() + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 78ba0b31fc6bb..ba13567459e1d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1930,6 +1930,12 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _executorAllocationManager.foreach(_.stop()) } + if (_dagScheduler != null) { + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } + _dagScheduler = null + } if (_listenerBusStarted) { Utils.tryLogNonFatalError { listenerBus.stop() @@ -1939,12 +1945,6 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } - if (_dagScheduler != null) { - Utils.tryLogNonFatalError { - _dagScheduler.stop() - } - _dagScheduler = null - } if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8fef2aa6863c5..eb08628ce1112 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -567,4 +567,14 @@ package object config { .intConf .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + + private[spark] val BARRIER_SYNC_TIMEOUT = + ConfigBuilder("spark.barrier.sync.timeout") + .doc("The timeout in seconds for each barrier() call from a barrier task. If the " + + "coordinator didn't receive all the sync messages from barrier tasks within the " + + "configed time, throw a SparkException to fail all the tasks. 
The default value is set " + + "to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.") + .timeConf(TimeUnit.SECONDS) + .checkValue(v => v > 0, "The value should be a positive time value.") + .createWithDefaultString("365d") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 72691389d271c..8992d7e2284a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -30,6 +30,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging import org.apache.spark.internal.config +import org.apache.spark.rpc.RpcEndpoint import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.BlockManagerId @@ -138,6 +139,19 @@ private[spark] class TaskSchedulerImpl( // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) + private lazy val barrierSyncTimeout = conf.get(config.BARRIER_SYNC_TIMEOUT) + + private[scheduler] var barrierCoordinator: RpcEndpoint = null + + private def maybeInitBarrierCoordinator(): Unit = { + if (barrierCoordinator == null) { + barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus, + sc.env.rpcEnv) + sc.env.rpcEnv.setupEndpoint("barrierSync", barrierCoordinator) + logInfo("Registered BarrierCoordinator endpoint") + } + } + override def setDAGScheduler(dagScheduler: DAGScheduler) { this.dagScheduler = dagScheduler } @@ -413,6 +427,9 @@ private[spark] class TaskSchedulerImpl( s"${taskSet.numTasks} tasks got resource offers. The resource offers may have " + "been blacklisted or cannot fulfill task locality requirements.") + // materialize the barrier coordinator. + maybeInitBarrierCoordinator() + // Update the taskInfos into all the barrier task properties. val addressesStr = addressesWithDescs // Addresses ordered by partitionId @@ -566,6 +583,9 @@ private[spark] class TaskSchedulerImpl( if (taskResultGetter != null) { taskResultGetter.stop() } + if (barrierCoordinator != null) { + barrierCoordinator.stop() + } starvationTimer.cancel() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala new file mode 100644 index 0000000000000..5f96d6fb0cdb6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import scala.util.Random + +import org.apache.spark._ + +class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { + + test("global sync by barrier() call") { + val conf = new SparkConf() + // Init local cluster here so each barrier task runs in a separated process, thus `barrier()` + // call is actually useful. + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + Seq(System.currentTimeMillis()).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish global sync within a short time slot. + assert(times.max - times.min <= 1000) + } + + test("support multiple barrier() call within a single task") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time between two global syncs. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. 
+ val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + + test("throw exception on barrier() call timeout") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Task 3 shall sleep 2000ms to ensure barrier() call timeout + if (context.taskAttemptId == 3) { + Thread.sleep(2000) + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } + + test("throw exception if barrier() call doesn't happen on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + if (context.taskAttemptId != 0) { + context.barrier() + } + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } + + test("throw exception if the number of barrier() calls are not the same on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + try { + if (context.taskAttemptId == 0) { + // Due to some non-obvious reason, the code can trigger an Exception and skip the + // following statements within the try ... catch block, including the first barrier() + // call. + throw new SparkException("test") + } + context.barrier() + } catch { + case e: Exception => // Do nothing + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1 second(s)")) + } +}
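For reference, a minimal end-to-end usage sketch of the barrier() sync introduced by this patch, mirroring BarrierTaskContextSuite above. This is illustrative only and not part of the patch: the local-cluster master requires an environment where local-cluster mode works (as in the tests), and the 10-second spark.barrier.sync.timeout, the object name, and the app name are assumptions chosen for the example.

import org.apache.spark.{SparkConf, SparkContext}

object BarrierSyncExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      // Fail a barrier() call that does not complete within 10 seconds (the default is 365d).
      .set("spark.barrier.sync.timeout", "10")
      // Run each barrier task in a separate process, like the tests above.
      .setMaster("local-cluster[4, 1, 1024]")
      .setAppName("barrier-sync-example")
    val sc = new SparkContext(conf)
    try {
      val rdd = sc.makeRDD(1 to 10, 4)
      // Run the stage in barrier mode; every task must reach barrier() or the stage fails.
      val times = rdd.barrier().mapPartitions { (it, context) =>
        // Per-partition work would go here, then all tasks wait for each other.
        context.barrier()
        Seq(System.currentTimeMillis()).iterator
      }.collect()
      // All tasks leave the barrier at roughly the same time.
      println(s"Sync spread across tasks: ${times.max - times.min} ms")
    } finally {
      sc.stop()
    }
  }
}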