3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -799,7 +799,8 @@ private[spark] class Executor(
       if (taskRunner.task != null) {
         taskRunner.task.metrics.mergeShuffleReadMetrics()
         taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
-        accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulators()))
+        accumUpdates +=
+          ((taskRunner.taskId, taskRunner.task.metrics.accumulators().filterNot(_.isZero)))
       }
     }

Review comment (Member): Could you add a flag for this behavior change?

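A minimal sketch of the flag the reviewer asks for, gating the filter behind a SparkConf entry; the key spark.executor.heartbeat.dropZeroAccumulatorUpdates and its default of true are illustrative assumptions, not the PR's final API:

    // Hypothetical config key (an assumption, not part of this PR): when true,
    // heartbeats omit accumulators whose value is still zero.
    private val dropZeroMetrics =
      conf.getBoolean("spark.executor.heartbeat.dropZeroAccumulatorUpdates", defaultValue = true)

    // Inside reportHeartBeat(), the filter then applies only when the flag is on:
    val accumulatorsToReport =
      if (dropZeroMetrics) taskRunner.task.metrics.accumulators().filterNot(_.isZero)
      else taskRunner.task.metrics.accumulators()
    accumUpdates += ((taskRunner.taskId, accumulatorsToReport))

Defaulting the flag to true would keep the PR's new behavior while leaving an escape hatch for any monitoring setup that relies on zero-valued updates.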
100 changes: 94 additions & 6 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -21,9 +21,10 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.lang.Thread.UncaughtExceptionHandler
 import java.nio.ByteBuffer
 import java.util.Properties
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicBoolean

+import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.Map
 import scala.concurrent.duration._
 import scala.language.postfixOps
@@ -39,14 +40,14 @@ import org.scalatest.mockito.MockitoSugar
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.memory.MemoryManager
-import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.metrics.{JVMHeapMemory, JVMOffHeapMemory, MetricsSystem}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcTimeout}
+import org.apache.spark.scheduler.{FakeTask, ResultTask, Task, TaskDescription}
 import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.UninterruptibleThread
+import org.apache.spark.storage.{BlockManager, BlockManagerId}
+import org.apache.spark.util.{LongAccumulator, UninterruptibleThread, Utils}

class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {

@@ -252,18 +253,105 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
    }
  }

  test("Heartbeat should drop zero metrics") {
    withHeartbeatExecutor((executor, heartbeats) => {
      // When no tasks are running, there should be no accumulators sent in heartbeat
      invokeReportHeartbeat(executor)
      assert(heartbeats.length == 1)
      assert(heartbeats(0).accumUpdates.length == 0,
        "No updates should be sent when no tasks are running")

      // When we start a task with a nonzero accumulator, that should end up in the heartbeat
      val metrics = new TaskMetrics()
      val nonZeroAccumulator = new LongAccumulator()
      nonZeroAccumulator.add(1)
      metrics.registerAccumulator(nonZeroAccumulator)

      // Reach into the executor's private runningTasks map via reflection so a
      // mock task runner can be registered as if a task were running.
      val executorClass = classOf[Executor]
      val tasksMap = {
        val field =
          executorClass.getDeclaredField("org$apache$spark$executor$Executor$$runningTasks")
        field.setAccessible(true)
        field.get(executor).asInstanceOf[ConcurrentHashMap[Long, executor.TaskRunner]]
      }
      val mockTaskRunner = mock[executor.TaskRunner]
      val mockTask = mock[Task[Any]]
      when(mockTask.metrics).thenReturn(metrics)
      when(mockTaskRunner.taskId).thenReturn(6)
      when(mockTaskRunner.task).thenReturn(mockTask)
      when(mockTaskRunner.startGCTime).thenReturn(1)
      tasksMap.put(6, mockTaskRunner)

      invokeReportHeartbeat(executor)
      assert(heartbeats.length == 2)
      val updates = heartbeats(1).accumUpdates
      assert(updates.length == 1 && updates(0)._1 == 6,
        "Heartbeat should only send an update for the one task running")
      val accumsSent = updates(0)._2.length
      assert(accumsSent > 0, "The nonzero accumulator we added should be sent")
      assert(accumsSent == metrics.accumulators().count(!_.isZero),
        "The number of accumulators sent should match the number of nonzero accumulators")
    })
  }

  private def withHeartbeatExecutor(f: (Executor, ArrayBuffer[Heartbeat]) => Unit): Unit = {
    val conf = new SparkConf
    val serializer = new JavaSerializer(conf)
    val env = createMockEnv(conf, serializer)
    val executor =
      new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
    val executorClass = classOf[Executor]

    // Set ExecutorMetricType.values to a minimal set to avoid NullPointerExceptions
    val metricClass =
      Utils.classForName(classOf[org.apache.spark.metrics.ExecutorMetricType].getName() + "$")
    val metricTypeValues = metricClass.getDeclaredField("values")
    metricTypeValues.setAccessible(true)
    metricTypeValues.set(
      org.apache.spark.metrics.ExecutorMetricType,
      IndexedSeq(JVMHeapMemory, JVMOffHeapMemory))

    // Save all heartbeats sent into an ArrayBuffer for verification
    val heartbeats = ArrayBuffer[Heartbeat]()
    val mockReceiver = mock[RpcEndpointRef]
    when(mockReceiver.askSync(any[Heartbeat], any[RpcTimeout])(any))
      .thenAnswer(new Answer[HeartbeatResponse] {
        override def answer(invocation: InvocationOnMock): HeartbeatResponse = {
          val args = invocation.getArguments()
          heartbeats += args(0).asInstanceOf[Heartbeat]
          HeartbeatResponse(false)
        }
      })
    val receiverRef = executorClass.getDeclaredField("heartbeatReceiverRef")
    receiverRef.setAccessible(true)
    receiverRef.set(executor, mockReceiver)

    f(executor, heartbeats)
  }

  private def invokeReportHeartbeat(executor: Executor): Unit = {
    val method = classOf[Executor]
      .getDeclaredMethod("org$apache$spark$executor$Executor$$reportHeartBeat")
    method.setAccessible(true)
    method.invoke(executor)
  }

Review comment (Member): You can mix in org.scalatest.PrivateMethodTester to replace this method, such as

    val reportHeartBeat = PrivateMethod[Long]('reportHeartBeat)
    ...
    executor.invokePrivate(reportHeartBeat())
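A minimal sketch of the reviewer's suggestion, assuming the suite mixes in org.scalatest.PrivateMethodTester and that the private method is declared as reportHeartBeat(): Unit (hence PrivateMethod[Unit] rather than the Long in the snippet above):

    import org.scalatest.PrivateMethodTester

    class ExecutorSuite extends SparkFunSuite with PrivateMethodTester {
      private def invokeReportHeartbeat(executor: Executor): Unit = {
        // PrivateMethodTester resolves the compiler-mangled name of the private
        // method, replacing the manual getDeclaredMethod/setAccessible dance.
        val reportHeartBeat = PrivateMethod[Unit]('reportHeartBeat)
        executor.invokePrivate(reportHeartBeat())
      }
    }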

  private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
    val mockEnv = mock[SparkEnv]
    val mockRpcEnv = mock[RpcEnv]
    val mockMetricsSystem = mock[MetricsSystem]
    val mockMemoryManager = mock[MemoryManager]
    val mockBlockManager = mock[BlockManager]
    when(mockEnv.conf).thenReturn(conf)
    when(mockEnv.serializer).thenReturn(serializer)
    when(mockEnv.serializerManager).thenReturn(mock[SerializerManager])
    when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
    when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
    when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
    when(mockEnv.closureSerializer).thenReturn(serializer)
    when(mockBlockManager.blockManagerId).thenReturn(BlockManagerId("1", "hostA", 1234))
    when(mockEnv.blockManager).thenReturn(mockBlockManager)
    SparkEnv.set(mockEnv)
    mockEnv
  }