[SPARK-29562][SQL] Speed up and slim down metric aggregation in SQL listener.

First, a bit of background on the code being changed. The current code tracks
metric updates for each task, recording which metrics the task is monitoring
and the last update value.

Once a SQL execution finishes, the metrics for all its stages are aggregated
by building a list of all (metric ID, value) pairs collected for every task in
those stages, grouping that list by metric ID, and then calculating the values
shown in the UI.
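
For reference, the tracking state being replaced looks roughly like this
(condensed from the LiveStageMetrics / LiveTaskMetrics classes this change
removes; a sketch, not the exact code):

```scala
import java.util.concurrent.ConcurrentHashMap

// Old shape: one record per task, so the metric IDs are duplicated in every
// record, and all lookups are keyed by task ID rather than partition index.
class LiveTaskMetrics(
    val ids: Array[Long],     // accumulator IDs this task reported
    val values: Array[Long],  // last reported update for each ID
    val succeeded: Boolean)

// One of these maps exists per stage.
val taskMetrics = new ConcurrentHashMap[Long, LiveTaskMetrics]()
```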

That is full of inefficiencies:

- in normal operation, all tasks track and update the same metrics, so
  recording the metric IDs per task is wasteful.
- tracking by task means values may be double-counted when there are
  speculative tasks (as a comment in the code mentions).
- building the list of (metric ID, value) pairs is extremely inefficient,
  because it materializes a huge in-memory collection of boxed metric IDs and
  values (see the sketch after this list).
- the same goes for the aggregation step, where a Seq is built with the
  values for each metric ID.
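
To make the boxing cost concrete, here is a minimal, self-contained sketch of
the allocation pattern described above (illustrative names and values, not the
listener's actual code):

```scala
// Each element of `pairs` is a Tuple2 wrapping two boxed java.lang.Longs, so
// 3 stages x 100k tasks x 50 metrics means ~15M tuples and ~30M boxed values.
val perTaskIds = Seq(Array(1L, 2L), Array(1L, 2L))
val perTaskValues = Seq(Array(10L, 20L), Array(30L, 40L))

val pairs: Seq[(Long, Long)] = perTaskIds.zip(perTaskValues)
  .flatMap { case (ids, values) => ids.zip(values) }

// groupBy then allocates yet another map of boxed keys to Seqs of pairs.
val grouped: Map[Long, Seq[Long]] =
  pairs.groupBy(_._1).map { case (id, vs) => id -> vs.map(_._2) }
// grouped: Map(1 -> Seq(10, 30), 2 -> Seq(20, 40))
```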

The end result is that, for large queries, this code can become both really
slow (delaying the processing of events) and memory hungry.

The updated code changes the approach to the following:

- stages track metrics by their ID; this means the stage tracking code
  naturally groups values, making aggregation later simpler.
- each tracked metric ID uses a long array sized to the number of partitions
  of the stage; this makes it cheap to update the metric's value once a task
  ends (sketched after this list).
- when aggregating, custom code just concatenates the arrays corresponding to
  the matching metric IDs; this is cheaper than the previous, boxing-heavy
  approach.
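
A condensed sketch of the new shape (simplified from the LiveStageMetrics
class added by this change; metric IDs and indices are illustrative):

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

val numTasks = 4  // number of partitions in the stage

// New shape: per stage, one primitive long array per tracked metric ID,
// indexed by partition, so recording a task's value is a single array store.
val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()

def update(metricId: Long, partIdx: Int, value: Long): Unit = {
  val values = taskMetrics.computeIfAbsent(metricId, _ => new Array[Long](numTasks))
  values(partIdx) = value
}

update(1L, 0, 10L)
update(1L, 2, 30L)

// Aggregation concatenates these primitive arrays across stages, no boxing.
val metricValues: Seq[(Long, Array[Long])] = taskMetrics.asScala.toSeq
```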

The end result is that the listener uses about half as much memory as before
for tracking metrics, since it doesn't need to track metric IDs per task.

I captured heap dumps with the old and the new code during metric aggregation
in the listener, for an execution with 3 stages, 100k tasks per stage, 50
metrics updated per task. The dumps contained just reachable memory - so data
kept by the listener plus the variables in the aggregateMetrics() method.

With the old code, the thread doing aggregation references >1G of memory - and
that does not include temporary data created by the "groupBy" transformation
(for which the intermediate state is not referenced in the aggregation method).
The same thread with the new code references ~250M of memory. The old code uses
about ~250M to track all the metric values for that execution, while the new
code uses about ~130M. (Note the per-thread numbers include the amount used to
track the metrics - so, e.g., in the old case, aggregation was referencing
about ~750M of temporary data.)

I'm also including a small benchmark (based on the Benchmark class) so that we
can measure how much changes to this code affect performance. The benchmark
contains some extra code to measure things the normal Benchmark class does not,
given that the code under test does not really map that well to the
expectations of that class.
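
The benchmark itself ships with the patch; the basic measurement idea is
roughly the following (a simplified stand-in, not the actual benchmark code;
only the sizes match the description above):

```scala
// Populate per-metric arrays for one synthetic stage (50 metrics,
// 100k tasks), then time an aggregation-style pass over them.
def timeMs[T](body: => T): (T, Long) = {
  val start = System.nanoTime()
  val result = body
  (result, (System.nanoTime() - start) / 1000000)
}

val numTasks = 100000
val metricIds = 0L until 50L

val (perStage, fillMs) = timeMs {
  metricIds.map(id => id -> Array.tabulate(numTasks)(_.toLong)).toMap
}
val (aggregated, aggMs) = timeMs {
  perStage.map { case (id, values) => id -> values.sum }
}

println(s"fill: $fillMs ms, aggregate: $aggMs ms")
```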

Running with the old code (I removed results that don't make much
sense for this benchmark):

```
[info] Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
[info] Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
[info] metrics aggregation (50 metrics, 100k tasks per stage):  Best Time(ms)   Avg Time(ms)
[info] --------------------------------------------------------------------------------------
[info] 1 stage(s)                                                  2113           2118
[info] 2 stage(s)                                                  4172           4392
[info] 3 stage(s)                                                  7755           8460
[info]
[info] Stage Count    Stage Proc. Time    Aggreg. Time
[info]      1              614                1187
[info]      2              620                2480
[info]      3              718                5069
```

With the new code:

```
[info] Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
[info] Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
[info] metrics aggregation (50 metrics, 100k tasks per stage):  Best Time(ms)   Avg Time(ms)
[info] --------------------------------------------------------------------------------------
[info] 1 stage(s)                                                   727            886
[info] 2 stage(s)                                                  1722           1983
[info] 3 stage(s)                                                  2752           3013
[info]
[info] Stage Count    Stage Proc. Time    Aggreg. Time
[info]      1              408                177
[info]      2              389                423
[info]      3              372                660

```

So the new code is faster than the old when processing task events, and about
an order of magnitude faster when aggregating metrics.

Note this still leaves room for improvement; for example, using the above
measurements, 600ms is still a huge amount of time to spend in an event
handler. But I'll leave further enhancements for a separate change.

Tested with benchmarking code + existing unit tests.
Marcelo Vanzin committed Oct 22, 2019
commit 067d63ecca9318cafe2003d2980c9223cf0c6dab

SQLMetrics.scala

```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.metric
 
 import java.text.NumberFormat
-import java.util.Locale
+import java.util.{Arrays, Locale}
 
 import scala.concurrent.duration._
 
@@ -150,7 +150,7 @@ object SQLMetrics {
    * A function that defines how we aggregate the final accumulator results among all tasks,
    * and represent it in string for a SQL physical operator.
    */
-  def stringValue(metricsType: String, values: Seq[Long]): String = {
+  def stringValue(metricsType: String, values: Array[Long]): String = {
     if (metricsType == SUM_METRIC) {
       val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
       numberFormat.format(values.sum)
@@ -162,8 +162,9 @@
       val metric = if (validValues.isEmpty) {
         Seq.fill(3)(0L)
       } else {
-        val sorted = validValues.sorted
-        Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        Arrays.sort(validValues)
+        Seq(validValues(0), validValues(validValues.length / 2),
+          validValues(validValues.length - 1))
       }
       metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
     }
@@ -184,8 +185,9 @@
       val metric = if (validValues.isEmpty) {
         Seq.fill(4)(0L)
       } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        Arrays.sort(validValues)
+        Seq(validValues.sum, validValues(0), validValues(validValues.length / 2),
+          validValues(validValues.length - 1))
       }
       metric.map(strFormat)
     }
```
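
For reference, the new signature is used like this (a sketch; assumes a caller
that can see SQLMetrics and its metric-type constants, e.g. code in the same
package):

```scala
import org.apache.spark.sql.execution.metric.SQLMetrics

// The caller now passes the primitive array directly. For the metric types
// that report min/median/max, the array is sorted in place with
// java.util.Arrays.sort, so the caller effectively hands over the array.
val values = Array(30L, 10L, 20L)
val formatted = SQLMetrics.stringValue(SQLMetrics.SUM_METRIC, values)
// For the sum metric type this yields "60" (US integer formatting).
```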
SQLAppStatusListener.scala

```diff
@@ -16,10 +16,11 @@
  */
 package org.apache.spark.sql.execution.ui
 
-import java.util.{Date, NoSuchElementException}
+import java.util.{Arrays, Date, NoSuchElementException}
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.spark.{JobExecutionStatus, SparkConf}
 import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
+import org.apache.spark.util.collection.OpenHashMap
 
 class SQLAppStatusListener(
     conf: SparkConf,
@@ -103,8 +105,10 @@ class SQLAppStatusListener(
     // Record the accumulator IDs for the stages of this job, so that the code that keeps
     // track of the metrics knows which accumulators to look at.
     val accumIds = exec.metrics.map(_.accumulatorId).toSet
-    event.stageIds.foreach { id =>
-      stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds, new ConcurrentHashMap()))
+    if (accumIds.nonEmpty) {
+      event.stageInfos.foreach { stage =>
+        stageMetrics.put(stage.stageId, new LiveStageMetrics(0, stage.numTasks, accumIds))
+      }
     }
 
     exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING)
@@ -118,9 +122,11 @@
     }
 
     // Reset the metrics tracking object for the new attempt.
-    Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics =>
-      metrics.taskMetrics.clear()
-      metrics.attemptId = event.stageInfo.attemptNumber
+    Option(stageMetrics.get(event.stageInfo.stageId)).foreach { stage =>
+      if (stage.attemptId != event.stageInfo.attemptNumber) {
+        stageMetrics.put(event.stageInfo.stageId,
+          new LiveStageMetrics(event.stageInfo.attemptNumber, stage.numTasks, stage.accumulatorIds))
+      }
     }
   }
 
@@ -140,7 +146,15 @@
 
   override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
     event.accumUpdates.foreach { case (taskId, stageId, attemptId, accumUpdates) =>
-      updateStageMetrics(stageId, attemptId, taskId, accumUpdates, false)
+      updateStageMetrics(stageId, attemptId, taskId, -1, accumUpdates, false)
     }
   }
 
+  override def onTaskStart(event: SparkListenerTaskStart): Unit = {
+    Option(stageMetrics.get(event.stageId)).foreach { stage =>
+      if (stage.attemptId == event.stageAttemptId) {
+        stage.registerTask(event.taskInfo.taskId, event.taskInfo.index)
+      }
+    }
+  }
+
@@ -165,7 +179,7 @@ class SQLAppStatusListener(
     } else {
       info.accumulables
     }
-    updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, accums,
+    updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, info.index, accums,
       info.successful)
   }
 
@@ -181,17 +195,40 @@
 
   private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
     val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
-    val metrics = exec.stages.toSeq
+
+    val taskMetrics = exec.stages.toSeq
       .flatMap { stageId => Option(stageMetrics.get(stageId)) }
-      .flatMap(_.taskMetrics.values().asScala)
-      .flatMap { metrics => metrics.ids.zip(metrics.values) }
-
-    val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
-      .filter { case (id, _) => metricTypes.contains(id) }
-      .groupBy(_._1)
-      .map { case (id, values) =>
-        id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
+      .flatMap(_.metricValues())
+
+    val allMetrics = new mutable.HashMap[Long, Array[Long]]()
+
+    taskMetrics.foreach { case (id, values) =>
+      val prev = allMetrics.getOrElse(id, null)
+      val updated = if (prev != null) {
+        prev ++ values
+      } else {
+        values
       }
+      allMetrics(id) = updated
+    }
+
+    exec.driverAccumUpdates.foreach { case (id, value) =>
+      if (metricTypes.contains(id)) {
+        val prev = allMetrics.getOrElse(id, null)
+        val updated = if (prev != null) {
+          val _copy = Arrays.copyOf(prev, prev.length + 1)
+          _copy(prev.length) = value
+          _copy
+        } else {
+          Array(value)
+        }
+        allMetrics(id) = updated
+      }
+    }
+
+    val aggregatedMetrics = allMetrics.map { case (id, values) =>
+      id -> SQLMetrics.stringValue(metricTypes(id), values)
+    }.toMap
 
     // Check the execution again for whether the aggregated metrics data has been calculated.
     // This can happen if the UI is requesting this data, and the onExecutionEnd handler is
@@ -208,43 +245,13 @@
       stageId: Int,
       attemptId: Int,
       taskId: Long,
+      partIdx: Int, // -1 if unknown from the event data.
       accumUpdates: Seq[AccumulableInfo],
       succeeded: Boolean): Unit = {
     Option(stageMetrics.get(stageId)).foreach { metrics =>
-      if (metrics.attemptId != attemptId || metrics.accumulatorIds.isEmpty) {
-        return
-      }
-
-      val oldTaskMetrics = metrics.taskMetrics.get(taskId)
-      if (oldTaskMetrics != null && oldTaskMetrics.succeeded) {
-        return
+      if (metrics.attemptId == attemptId) {
+        metrics.updateTaskMetrics(taskId, partIdx, succeeded, accumUpdates)
       }
-
-      val updates = accumUpdates
-        .filter { acc => acc.update.isDefined && metrics.accumulatorIds.contains(acc.id) }
-        .sortBy(_.id)
-
-      if (updates.isEmpty) {
-        return
-      }
-
-      val ids = new Array[Long](updates.size)
-      val values = new Array[Long](updates.size)
-      updates.zipWithIndex.foreach { case (acc, idx) =>
-        ids(idx) = acc.id
-        // In a live application, accumulators have Long values, but when reading from event
-        // logs, they have String values. For now, assume all accumulators are Long and covert
-        // accordingly.
-        values(idx) = acc.update.get match {
-          case s: String => s.toLong
-          case l: Long => l
-          case o => throw new IllegalArgumentException(s"Unexpected: $o")
-        }
-      }
-
-      // TODO: storing metrics by task ID can cause metrics for the same task index to be
-      // counted multiple times, for example due to speculation or re-attempts.
-      metrics.taskMetrics.put(taskId, new LiveTaskMetrics(ids, values, succeeded))
     }
   }
 
@@ -425,12 +432,72 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity {
 }
 
 private class LiveStageMetrics(
-    val stageId: Int,
-    var attemptId: Int,
-    val accumulatorIds: Set[Long],
-    val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics])
-
-private class LiveTaskMetrics(
-    val ids: Array[Long],
-    val values: Array[Long],
-    val succeeded: Boolean)
+    val attemptId: Int,
+    val numTasks: Int,
+    val accumulatorIds: Set[Long]) {
+
+  /**
+   * Mapping of task IDs to the partition index they're computing. Note this may contain more
+   * elements than the stage's number of tasks, if speculative execution is on.
+   */
+  private val taskIndices = new OpenHashMap[Long, Int]()
+
+  /** Bit set tracking which partition indices have been successfully computed. */
+  private val completedParts = new mutable.BitSet()
+
+  /**
+   * Task metrics values for the stage. Maps the metric ID to the metric values for each
+   * partition. For each metric ID, there will be the same number of values as the number
+   * of partitions. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
+   * independent of the actual metric type.
+   */
+  private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
+
+  def registerTask(taskId: Long, partIdx: Int): Unit = {
+    taskIndices.update(taskId, partIdx)
+  }
+
+  def updateTaskMetrics(
+      taskId: Long,
+      eventPartIdx: Int,
+      finished: Boolean,
+      accumUpdates: Seq[AccumulableInfo]): Unit = {
+    val partIdx = if (eventPartIdx == -1) {
+      if (!taskIndices.contains(taskId)) {
+        // We probably missed the start event for the task, just ignore it.
+        return
+      }
+      taskIndices(taskId)
+    } else {
+      // Here we can recover from a missing task start event. Just register the task again.
+      registerTask(taskId, eventPartIdx)
+      eventPartIdx
+    }
+
+    if (completedParts.contains(partIdx)) {
+      return
+    }
+
+    accumUpdates
+      .filter { acc => acc.update.isDefined && accumulatorIds.contains(acc.id) }
+      .foreach { acc =>
+        // In a live application, accumulators have Long values, but when reading from event
+        // logs, they have String values. For now, assume all accumulators are Long and convert
+        // accordingly.
+        val value = acc.update.get match {
+          case s: String => s.toLong
+          case l: Long => l
+          case o => throw new IllegalArgumentException(s"Unexpected: $o")
+        }
+
+        val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
+        metricValues(partIdx) = value
+      }
+
+    if (finished) {
+      completedParts += partIdx
+    }
+  }
+
+  def metricValues(): Seq[(Long, Array[Long])] = taskMetrics.asScala.toSeq
+}
```
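
To illustrate how the pieces above interact, here is a sketch of the event
flow (assumes code with spark-package access, e.g. a listener test, since
these classes and AccumulableInfo's constructor are package-private):

```scala
import org.apache.spark.scheduler.AccumulableInfo

// One tracked accumulator (ID 1) in a stage with 4 partitions, attempt 0.
val stage = new LiveStageMetrics(0, 4, Set(1L))

// onTaskStart: remember which partition index task 100 is computing.
stage.registerTask(100L, 0)

// Executor heartbeat: the event carries no partition index (-1), so the
// registered mapping is used; the value is recorded but not finalized.
stage.updateTaskMetrics(100L, -1, false,
  Seq(AccumulableInfo(1L, None, Some(42L), None, false, false)))

// onTaskEnd: the event carries the partition index; the partition is marked
// completed, so later (e.g. speculative) updates for it are ignored.
stage.updateTaskMetrics(100L, 0, true,
  Seq(AccumulableInfo(1L, None, Some(84L), None, false, false)))

stage.metricValues()  // Seq((1L, Array(84L, 0L, 0L, 0L)))
```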
SQLMetricsTestUtils.scala

```diff
@@ -232,7 +232,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
       val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
       assert(expectedNodeName === actualNodeName)
       for ((metricName, metricPredicate) <- expectedMetricsPredicatesMap) {
-        assert(metricPredicate(actualMetricsMap(metricName)))
+        assert(metricPredicate(actualMetricsMap(metricName)),
+          s"$nodeId / '$metricName' (= ${actualMetricsMap(metricName)}) did not match predicate.")
       }
     }
   }
```