
Commit 478b443

Update HeartbeatReceiver to use RpcEndpoint
1 parent 17b13c5 commit 478b443

4 files changed: +101 additions, -30 deletions

core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala

Lines changed: 28 additions & 20 deletions
@@ -17,15 +17,15 @@
 
 package org.apache.spark
 
-import scala.concurrent.duration._
-import scala.collection.mutable
+import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors}
 
-import akka.actor.{Actor, Cancellable}
+import scala.collection.mutable
 
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
-import org.apache.spark.util.ActorLogReceive
+import org.apache.spark.util.Utils
 
 /**
  * A heartbeat from executors to the driver. This is a shared message used by several internal
@@ -45,7 +45,9 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
  * Lives in the driver to receive heartbeats from executors..
  */
 private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskScheduler)
-  extends Actor with ActorLogReceive with Logging {
+  extends RpcEndpoint with Logging {
+
+  override val rpcEnv: RpcEnv = sc.env.rpcEnv
 
   // executor ID -> timestamp of when the last heartbeat from this executor was received
   private val executorLastSeen = new mutable.HashMap[String, Long]
@@ -61,24 +63,31 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule
     sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000).
       getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000))
 
-  private var timeoutCheckingTask: Cancellable = null
-
-  override def preStart(): Unit = {
-    import context.dispatcher
-    timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
-      checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts)
-    super.preStart()
+  private var timeoutCheckingTask: ScheduledFuture[_] = null
+
+  private val messageScheduler = Executors.newSingleThreadScheduledExecutor(
+    Utils.namedThreadFactory("heart-beat-receiver-thread"))
+
+  override def onStart(): Unit = {
+    timeoutCheckingTask = messageScheduler.scheduleAtFixedRate(new Runnable {
+      override def run(): Unit = {
+        self.send(ExpireDeadHosts)
+      }
+    }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
   }
-
-  override def receiveWithLogging: PartialFunction[Any, Unit] = {
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case ExpireDeadHosts =>
+      expireDeadHosts()
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case Heartbeat(executorId, taskMetrics, blockManagerId) =>
       val unknownExecutor = !scheduler.executorHeartbeatReceived(
         executorId, taskMetrics, blockManagerId)
       val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
       executorLastSeen(executorId) = System.currentTimeMillis()
-      sender ! response
-    case ExpireDeadHosts =>
-      expireDeadHosts()
+      context.reply(response)
   }
 
   private def expireDeadHosts(): Unit = {
@@ -98,10 +107,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskSchedule
     }
   }
 
-  override def postStop(): Unit = {
+  override def onStop(): Unit = {
     if (timeoutCheckingTask != null) {
-      timeoutCheckingTask.cancel()
+      timeoutCheckingTask.cancel(true)
     }
-    super.postStop()
   }
 }
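A note on the pattern this diff introduces: an RpcEndpoint splits message handling in two. receive handles one-way messages delivered with send (here, the self-scheduled ExpireDeadHosts tick), while receiveAndReply handles ask-style messages and answers them through the RpcCallContext (here, Heartbeat), replacing Akka's single receive plus sender ! response. A minimal sketch of that split, using only the API surface visible in this diff; the EchoEndpoint name and its messages are invented for illustration:

    import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

    // Hypothetical endpoint illustrating the two handler paths of RpcEndpoint.
    private[spark] class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {

      // One-way messages (RpcEndpointRef.send) arrive here; no reply channel exists.
      override def receive: PartialFunction[Any, Unit] = {
        case "ping" => println("got ping")
      }

      // Ask-style messages (askWithReply) arrive here; answer via the call context.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply("echo: " + msg)
      }
    }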

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 4 additions & 3 deletions
@@ -359,8 +359,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   // Create and start the scheduler
   private[spark] var (schedulerBackend, taskScheduler) =
     SparkContext.createTaskScheduler(this, master)
-  private val heartbeatReceiver = env.actorSystem.actorOf(
-    Props(new HeartbeatReceiver(this, taskScheduler)), "HeartbeatReceiver")
+  private val heartbeatReceiver = env.rpcEnv.setupThreadSafeEndpoint(
+    "HeartbeatReceiver", new HeartbeatReceiver(this, taskScheduler))
+
   @volatile private[spark] var dagScheduler: DAGScheduler = _
   try {
     dagScheduler = new DAGScheduler(this)
@@ -1406,7 +1407,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     dagScheduler = null
     listenerBus.stop()
     eventLogger.foreach(_.stop())
-    env.actorSystem.stop(heartbeatReceiver)
+    env.rpcEnv.stop(heartbeatReceiver)
     progressBar.foreach(_.stop())
     taskScheduler = null
     // TODO: Cache.stop()?
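Driver-side wiring follows the same migration: actorSystem.actorOf/stop becomes rpcEnv.setupThreadSafeEndpoint/stop. setupThreadSafeEndpoint returns an RpcEndpointRef, and the thread-safe variant presumably serializes message delivery, which is what lets HeartbeatReceiver mutate its executorLastSeen map without locking. A hedged sketch of the register/teardown pairing, reusing the illustrative EchoEndpoint from the sketch above:

    // Register under a well-known name and keep the returned ref.
    val ref = env.rpcEnv.setupThreadSafeEndpoint("echo", new EchoEndpoint(env.rpcEnv))
    ref.send("ping")        // dispatched to the endpoint's receive
    // ... on shutdown, mirror the registration:
    env.rpcEnv.stop(ref)    // presumably invokes the endpoint's onStop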

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 2 additions & 7 deletions
@@ -391,11 +391,7 @@ private[spark] class Executor(
     }
   }
 
-  private val timeout = AkkaUtils.lookupTimeout(conf)
-  private val retryAttempts = AkkaUtils.numRetries(conf)
-  private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
-  private val heartbeatReceiverRef =
-    AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
+  private val heartbeatReceiverRef = RpcUtils.makeDriverRef("HeartbeatReceiver", conf, env.rpcEnv)
 
   /** Reports heartbeat and metrics for active tasks to the driver. */
   private def reportHeartBeat(): Unit = {
@@ -426,8 +422,7 @@ private[spark] class Executor(
 
     val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
     try {
-      val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
-        retryAttempts, retryIntervalMs, timeout)
+      val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message)
       if (response.reregisterBlockManager) {
         logWarning("Told to re-register on heartbeat")
         env.blockManager.reregister()
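On the calling side, the retry bookkeeping (retryAttempts, retryIntervalMs, timeout) vanishes from the call site: RpcEndpointRef.askWithReply[T] sends the message and blocks for a typed reply, handling delivery details internally. A sketch of the client half of the round trip, paired with the illustrative EchoEndpoint above (and assuming setupEndpoint returns the endpoint's ref, as setupThreadSafeEndpoint does in the SparkContext change):

    // Client half of the ask/reply pattern shown in receiveAndReply above.
    val echoRef = env.rpcEnv.setupEndpoint("echo", new EchoEndpoint(env.rpcEnv))
    val answer = echoRef.askWithReply[String]("hello")  // blocks for the typed reply
    assert(answer == "echo: hello")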
core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.BlockManagerId
+import org.scalatest.FunSuite
+import org.mockito.Mockito._
+import org.mockito.Matchers
+import org.mockito.Matchers._
+
+import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.util.RpcUtils
+
+class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext {
+
+  test("HeartbeatReceiver") {
+    sc = new SparkContext("local[2]", "test")
+    val scheduler = mock(classOf[TaskScheduler])
+    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
+
+    sc.env.rpcEnv.setupEndpoint("heartbeat", new HeartbeatReceiver(sc, scheduler))
+    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+
+    val metrics = new TaskMetrics
+    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
+    val response = receiverRef.askWithReply[HeartbeatResponse](
+      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+
+    verify(scheduler).executorHeartbeatReceived(
+      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+    assert(false === response.reregisterBlockManager)
+  }
+
+  test("HeartbeatReceiver re-register") {
+    sc = new SparkContext("local[2]", "test")
+    val scheduler = mock(classOf[TaskScheduler])
+    when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false)
+
+    sc.env.rpcEnv.setupEndpoint("heartbeat", new HeartbeatReceiver(sc, scheduler))
+    val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
+
+    val metrics = new TaskMetrics
+    val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
+    val response = receiverRef.askWithReply[HeartbeatResponse](
+      Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
+
+    verify(scheduler).executorHeartbeatReceived(
+      Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
+    assert(true === response.reregisterBlockManager)
+  }
+}
