@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command
 
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
@@ -45,15 +44,7 @@ trait DataWritingCommand extends Command {
   // Output columns of the analyzed input query plan
   def outputColumns: Seq[Attribute]
 
-  lazy val metrics: Map[String, SQLMetric] = {
-    val sparkContext = SparkContext.getActive.get
-    Map(
-      "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"),
-      "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"),
-      "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-      "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part")
-    )
-  }
+  lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
 
   def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = {
     val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
@@ -153,12 +153,29 @@ class BasicWriteJobStatsTracker(
       totalNumOutput += summary.numRows
     }
 
-    metrics("numFiles").add(numFiles)
-    metrics("numOutputBytes").add(totalNumBytes)
-    metrics("numOutputRows").add(totalNumOutput)
-    metrics("numParts").add(numPartitions)
+    metrics(BasicWriteJobStatsTracker.NUM_FILES_KEY).add(numFiles)
+    metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_BYTES_KEY).add(totalNumBytes)
+    metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_ROWS_KEY).add(totalNumOutput)
+    metrics(BasicWriteJobStatsTracker.NUM_PARTS_KEY).add(numPartitions)
 
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList)
   }
 }
+
+object BasicWriteJobStatsTracker {
+  private val NUM_FILES_KEY = "numFiles"
+  private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes"
+  private val NUM_OUTPUT_ROWS_KEY = "numOutputRows"
+  private val NUM_PARTS_KEY = "numParts"
+
+  def metrics: Map[String, SQLMetric] = {
+    val sparkContext = SparkContext.getActive.get
+    Map(
+      NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"),
+      NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"),
+      NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+      NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part")
+    )
+  }
+}
@@ -26,7 +26,8 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter}
+import org.apache.spark.util.SerializableConfiguration
 
 object FileStreamSink extends Logging {
   // The name of the subdirectory that is used to store metadata about which files are valid.
@@ -97,6 +98,11 @@ class FileStreamSink(
     new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString)
   private val hadoopConf = sparkSession.sessionState.newHadoopConf()
 
+  private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = {
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+    new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics)
+  }
+
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) {
       logInfo(s"Skipping already committed batch $batchId")
@@ -131,7 +137,7 @@ class FileStreamSink(
         hadoopConf = hadoopConf,
         partitionColumns = partitionColumns,
         bucketSpec = None,
-        statsTrackers = Nil,
+        statsTrackers = Seq(basicWriteJobStatsTracker),
         options = options)
     }
   }
@@ -21,6 +21,7 @@ import java.util.Locale
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{AnalysisException, DataFrame}
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.datasources._
@@ -405,4 +406,53 @@ class FileStreamSinkSuite extends StreamTest {
       }
     }
   }
+
+  test("SPARK-23288 writing and checking output metrics") {
+    Seq("parquet", "orc", "text", "json").foreach { format =>
+      val inputData = MemoryStream[String]
+      val df = inputData.toDF()
+
+      val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
Review comment (Contributor): we should use withTempDir to clean up the temp directory at the end.
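A minimal sketch of that suggestion, assuming the suite's withTempDir helper (inherited via SQLTestUtils) is in scope; the structure below is illustrative rather than the author's final code:

    // withTempDir deletes the directory when the block finishes,
    // so no manual cleanup is needed in a finally block.
    withTempDir { outputDir =>
      withTempDir { checkpointDir =>
        // start the streaming query with outputDir.getCanonicalPath and
        // checkpointDir.getCanonicalPath, then run the metric assertions here
      }
    }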
+      val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath
+
+      var query: StreamingQuery = null
+
+      var numTasks = 0
+      var recordsWritten: Long = 0L
+      var bytesWritten: Long = 0L
+      try {
+        spark.sparkContext.addSparkListener(new SparkListener() {
+          override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+            val outputMetrics = taskEnd.taskMetrics.outputMetrics
+            recordsWritten += outputMetrics.recordsWritten
+            bytesWritten += outputMetrics.bytesWritten
+            numTasks += 1
+          }
+        })
+
+        query =
+          df.writeStream
+            .option("checkpointLocation", checkpointDir)
+            .format(format)
+            .start(outputDir)
+
+        inputData.addData("1", "2", "3")
+        inputData.addData("4", "5")
+
+        failAfter(streamingTimeout) {
+          query.processAllAvailable()
+        }
+        spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis)
+
Review comment (Member): nit: it's better to add the below statement here to avoid flakiness.

    spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis)

Reply (Contributor, author): Added.
+        assert(numTasks > 0)
+        assert(recordsWritten === 5)
+        // This is heavily file type/version specific but should be filled
+        assert(bytesWritten > 0)
+      } finally {
+        if (query != null) {
+          query.stop()
+        }
+      }
+    }
+  }
 }