@@ -19,98 +19,20 @@ package org.apache.spark.sql.execution.metric

import java.io.File

import scala.collection.mutable.HashMap
import scala.util.Random

import org.apache.spark.SparkFunSuite
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.{AccumulatorContext, JsonProtocol}

class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext {
  import testImplicits._


  /**
   * Call `df.collect()` and collect the necessary metrics from the execution data.
   *
   * @param df `DataFrame` to run
   * @param expectedNumOfJobs number of jobs that will run
   * @param expectedNodeIds the node ids of the metrics to collect from execution data.
   */
  private def getSparkPlanMetrics(
      df: DataFrame,
      expectedNumOfJobs: Int,
      expectedNodeIds: Set[Long],
      enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = {
    val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet
    withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) {
      df.collect()
    }
    sparkContext.listenerBus.waitUntilEmpty(10000)
    val executionIds =
      spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds)
    assert(executionIds.size === 1)
    val executionId = executionIds.head
    val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs
    // Use "<=" because a race condition may cause us to miss some jobs.
    // TODO Change it to "=" once we fix the race condition that drops the JobStarted event.
    assert(jobs.size <= expectedNumOfJobs)
    if (jobs.size == expectedNumOfJobs) {
      // If we can track all jobs, check the metric values.
      val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
      val metrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan(
        df.queryExecution.executedPlan)).allNodes.filter { node =>
        expectedNodeIds.contains(node.id)
      }.map { node =>
        val nodeMetrics = node.metrics.map { metric =>
          val metricValue = metricValues(metric.accumulatorId)
          (metric.name, metricValue)
        }.toMap
        (node.id, node.name -> nodeMetrics)
      }.toMap
      Some(metrics)
    } else {
      // TODO Remove this "else" once we fix the race condition that drops the JobStarted event.
      // Since we cannot track all jobs, the metric values could be wrong and we should not
      // check them.
      logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values")
      None
    }
  }
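A hypothetical caller of the helper above might look like this (not part of this diff; the node id 0L is an assumption for illustration, since ids depend on the generated physical plan):

  // Illustrative sketch only, not from this PR:
  test("getSparkPlanMetrics usage sketch") {
    val df = testData2.groupBy().count()
    val metrics = getSparkPlanMetrics(df, expectedNumOfJobs = 1, expectedNodeIds = Set(0L))
    // If all jobs were tracked, we get an (operatorName, metric name -> value) pair per node id.
    metrics.foreach { m => assert(m.contains(0L)) }
  }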

  /**
   * Call `df.collect()` and verify that the collected metrics are the same as `expectedMetrics`.
   *
   * @param df `DataFrame` to run
   * @param expectedNumOfJobs number of jobs that will run
   * @param expectedMetrics the expected metrics. The format is
   *                        `nodeId -> (operatorName, metric name -> metric value)`.
   */
  private def testSparkPlanMetrics(
      df: DataFrame,
      expectedNumOfJobs: Int,
      expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
    val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet)
    optActualMetrics.foreach { actualMetrics =>
      assert(expectedMetrics.keySet === actualMetrics.keySet)
      for (nodeId <- expectedMetrics.keySet) {
        val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId)
        val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
        assert(expectedNodeName === actualNodeName)
        for (metricName <- expectedMetricsMap.keySet) {
          assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName))
        }
      }
    }
  }
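A typical test built on this helper, in the style of the suite's existing aggregate tests (the node ids 0L and 2L are plan-dependent and assumed here for illustration):

  // Illustrative sketch only, not from this PR:
  test("testSparkPlanMetrics usage sketch") {
    val df = testData2.groupBy().count() // 2 partitions
    testSparkPlanMetrics(df, 1, Map(
      2L -> ("HashAggregate", Map("number of output rows" -> 2L)),
      0L -> ("HashAggregate", Map("number of output rows" -> 1L))))
  }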

  /**
   * Generates a `DataFrame` filled with randomly generated bytes, to provoke hash collisions.
*/
@@ -570,75 +492,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
      assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
    }
  }
}

object InputOutputMetricsHelper {
  private class InputOutputMetricsListener extends SparkListener {
    private case class MetricsResult(
      var recordsRead: Long = 0L,
      var shuffleRecordsRead: Long = 0L,
      var sumMaxOutputRows: Long = 0L)

    private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]

    def reset(): Unit = {
      stageIdToMetricsResult.clear()
    }

    /**
     * Return a list of recorded metrics aggregated per stage.
     *
     * The list is sorted in ascending order of stageId.
     * For each recorded stage, the following tuple is returned:
     *  - sum of inputMetrics.recordsRead for all the tasks in the stage
     *  - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
     *  - sum of the highest values of the "number of output rows" metric for all the tasks
     *    in the stage
     */
    def getResults(): List[(Long, Long, Long)] = {
      stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
        val res = stageIdToMetricsResult(stageId)
        (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
      }
    }
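    // Illustrative example only (not from this PR): for a scan stage followed by a
    // shuffle stage, getResults() could return
    //   List((100L, 0L, 100L),  // stage 0: 100 input records read, no shuffle reads
    //        (0L, 100L, 10L))   // stage 1: 100 shuffle records read; the summed
    //                           // per-task max "number of output rows" is 10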

    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
      val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())

      res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
      res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead

      var maxOutputRows = 0L
      for (accum <- taskEnd.taskMetrics.externalAccums) {
        val info = accum.toInfo(Some(accum.value), None)
        if (info.name.toString.contains("number of output rows")) {
          info.update match {
            case Some(n: Number) =>
              if (n.longValue() > maxOutputRows) {
                maxOutputRows = n.longValue()
              }
            case _ => // Ignore.
          }
        }
      }
      res.sumMaxOutputRows += maxOutputRows
    }
test("writing data out metrics: parquet") {
testMetricsNonDynamicPartition("parquet", "t1")
}
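  // Note: testMetricsNonDynamicPartition is defined in the new SQLMetricsTestUtils
  // trait, which this diff does not include. Presumably (an assumption, not shown
  // here) it writes a small DataFrame out in the given format and table and then
  // verifies the write path's SQL metrics, along these rough lines:
  //   withTable("t1") {
  //     sql("CREATE TABLE t1(id BIGINT) USING parquet")
  //     spark.range(10).write.insertInto("t1")
  //     // ...assert on metrics such as number of written files / output rows
  //   }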

  // Run df.collect() and return aggregated metrics for each stage.
  def run(df: DataFrame): List[(Long, Long, Long)] = {
    val spark = df.sparkSession
    val sparkContext = spark.sparkContext
    val listener = new InputOutputMetricsListener()
    sparkContext.addSparkListener(listener)

    try {
      sparkContext.listenerBus.waitUntilEmpty(5000)
      listener.reset()
      df.collect()
      sparkContext.listenerBus.waitUntilEmpty(5000)
    } finally {
      sparkContext.removeSparkListener(listener)
    }
    listener.getResults()
test("writing data out metrics with dynamic partition: parquet") {
testMetricsDynamicPartition("parquet", "parquet", "t1")
}
}
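For context, tests consumed the relocated InputOutputMetricsHelper through its `run` method, which returns one `(recordsRead, shuffleRecordsRead, sumMaxOutputRows)` tuple per stage, as in the `res3` assertion shown above. A minimal sketch, assuming a query `df: DataFrame` is in scope:

// Minimal usage sketch, not part of this PR:
val perStage: List[(Long, Long, Long)] = InputOutputMetricsHelper.run(df)
// One tuple per stage, in ascending stageId order, e.g.
// (10L, 0L, 10L) :: (30L, 0L, 30L) :: Nil as in the assertion above.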