Commit 22e6ca1

[SPARK-23288][SS] Fix output metrics with parquet sink

Parent: 3ee3b2a

4 files changed: 79 additions & 16 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala

Lines changed: 1 addition & 10 deletions

@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command
 
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
@@ -45,15 +44,7 @@ trait DataWritingCommand extends Command {
   // Output columns of the analyzed input query plan
   def outputColumns: Seq[Attribute]
 
-  lazy val metrics: Map[String, SQLMetric] = {
-    val sparkContext = SparkContext.getActive.get
-    Map(
-      "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of written files"),
-      "numOutputBytes" -> SQLMetrics.createMetric(sparkContext, "bytes of written output"),
-      "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-      "numParts" -> SQLMetrics.createMetric(sparkContext, "number of dynamic part")
-    )
-  }
+  lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
 
   def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = {
     val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
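
Note (not part of the diff): after this change, every `DataWritingCommand` obtains its metric map from the shared factory introduced in the next file instead of building its own. A minimal sketch of what that factory returns, assuming a local `SparkSession` created purely for illustration (the factory requires an active `SparkContext`, since it calls `SparkContext.getActive.get`):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker

// Illustrative session; any active SparkContext works.
val spark = SparkSession.builder().master("local[*]").appName("metrics-sketch").getOrCreate()

// Same four metrics the removed inline code used to create.
val metrics = BasicWriteJobStatsTracker.metrics
assert(metrics.keySet == Set("numFiles", "numOutputBytes", "numOutputRows", "numParts"))
```

Keeping the factory next to the tracker means batch commands and the streaming sink construct identical metric sets.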

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala

Lines changed: 21 additions & 4 deletions

@@ -153,12 +153,29 @@ class BasicWriteJobStatsTracker(
       totalNumOutput += summary.numRows
     }
 
-    metrics("numFiles").add(numFiles)
-    metrics("numOutputBytes").add(totalNumBytes)
-    metrics("numOutputRows").add(totalNumOutput)
-    metrics("numParts").add(numPartitions)
+    metrics(BasicWriteJobStatsTracker.NUM_FILES_KEY).add(numFiles)
+    metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_BYTES_KEY).add(totalNumBytes)
+    metrics(BasicWriteJobStatsTracker.NUM_OUTPUT_ROWS_KEY).add(totalNumOutput)
+    metrics(BasicWriteJobStatsTracker.NUM_PARTS_KEY).add(numPartitions)
 
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toList)
   }
 }
+
+object BasicWriteJobStatsTracker {
+  private val NUM_FILES_KEY = "numFiles"
+  private val NUM_OUTPUT_BYTES_KEY = "numOutputBytes"
+  private val NUM_OUTPUT_ROWS_KEY = "numOutputRows"
+  private val NUM_PARTS_KEY = "numParts"
+
+  def metrics: Map[String, SQLMetric] = {
+    val sparkContext = SparkContext.getActive.get
+    Map(
+      NUM_FILES_KEY -> SQLMetrics.createMetric(sparkContext, "number of written files"),
+      NUM_OUTPUT_BYTES_KEY -> SQLMetrics.createMetric(sparkContext, "bytes of written output"),
+      NUM_OUTPUT_ROWS_KEY -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+      NUM_PARTS_KEY -> SQLMetrics.createMetric(sparkContext, "number of dynamic part")
+    )
+  }
+}
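
For orientation (not part of the diff): each call to the new `BasicWriteJobStatsTracker.metrics` factory creates a fresh set of `SQLMetric` accumulators, a tracker is built around them, and `processStats` adds the job totals on the driver before posting the metric update. A rough sketch of that wiring, mirroring what `DataWritingCommand.basicWriteJobStatsTracker` and the new `FileStreamSink` helper do; the surrounding driver code here is illustrative, not from the patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.util.SerializableConfiguration

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Fresh SQLMetric accumulators for one write job (needs the active SparkContext).
val writeMetrics = BasicWriteJobStatsTracker.metrics

// Wrap the Hadoop conf so the tracker can be shipped to tasks, as the patch does;
// any Configuration works for this sketch.
val tracker = new BasicWriteJobStatsTracker(
  new SerializableConfiguration(spark.sparkContext.hadoopConfiguration),
  writeMetrics)

// Handed to FileFormatWriter.write(..., statsTrackers = Seq(tracker), ...);
// after the job, the driver-side totals are readable from the same map, e.g.
//   writeMetrics("numOutputRows").value
```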

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala

Lines changed: 8 additions & 2 deletions

@@ -26,7 +26,8 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter}
+import org.apache.spark.util.SerializableConfiguration
 
 object FileStreamSink extends Logging {
   // The name of the subdirectory that is used to store metadata about which files are valid.
@@ -95,6 +96,11 @@ class FileStreamSink(
     new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString)
   private val hadoopConf = sparkSession.sessionState.newHadoopConf()
 
+  private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = {
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+    new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics)
+  }
+
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) {
       logInfo(s"Skipping already committed batch $batchId")
@@ -129,7 +135,7 @@ class FileStreamSink(
         hadoopConf = hadoopConf,
         partitionColumns = partitionColumns,
         bucketSpec = None,
-        statsTrackers = Nil,
+        statsTrackers = Seq(basicWriteJobStatsTracker),
         options = options)
     }
   }
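
With the tracker passed to `FileFormatWriter.write`, every micro-batch written by the file sink records basic write stats, which is what makes the task-level output metrics in the test below non-zero. A minimal user-level sketch of a query that exercises this path; the rate source, paths, and names are placeholders, not from the patch:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: a rate-source query writing through the file sink.
val spark = SparkSession.builder().master("local[*]").appName("file-sink-sketch").getOrCreate()

// Every file-based streaming sink (parquet, orc, json, csv, text) goes through
// FileStreamSink.addBatch, so each micro-batch now runs with the stats tracker.
val query = spark.readStream
  .format("rate")                                          // built-in test source
  .option("rowsPerSecond", "5")
  .load()
  .writeStream
  .format("parquet")
  .option("checkpointLocation", "/tmp/sink-sketch-ckpt")   // placeholder path
  .start("/tmp/sink-sketch-out")                           // placeholder path

// query.awaitTermination() or query.stop() as appropriate for the application.
```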

sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala

Lines changed: 49 additions & 0 deletions

@@ -21,6 +21,7 @@ import java.util.Locale
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{AnalysisException, DataFrame}
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.datasources._
@@ -405,4 +406,52 @@ class FileStreamSinkSuite extends StreamTest {
       }
     }
   }
+
+  test("SPARK-23288 writing and checking output metrics") {
+    Seq("parquet", "orc", "text", "json").foreach { format =>
+      val inputData = MemoryStream[String]
+      val df = inputData.toDF()
+
+      val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
+      val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath
+
+      var query: StreamingQuery = null
+
+      var numTasks = 0
+      var recordsWritten: Long = 0L
+      var bytesWritten: Long = 0L
+      try {
+        spark.sparkContext.addSparkListener(new SparkListener() {
+          override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+            val outputMetrics = taskEnd.taskMetrics.outputMetrics
+            recordsWritten += outputMetrics.recordsWritten
+            bytesWritten += outputMetrics.bytesWritten
+            numTasks += 1
+          }
+        })
+
+        query =
+          df.writeStream
+            .option("checkpointLocation", checkpointDir)
+            .format(format)
+            .start(outputDir)
+
+        inputData.addData("1", "2", "3")
+        inputData.addData("4", "5")
+
+        failAfter(streamingTimeout) {
+          query.processAllAvailable()
+        }
+
+        assert(numTasks === 2)
+        assert(recordsWritten === 5)
+        // This is heavily file type/version specific but should be filled
+        assert(bytesWritten > 0)
+      } finally {
+        if (query != null) {
+          query.stop()
+        }
+      }
+    }
+  }
 }
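
The listener pattern used by the test works just as well in an application: task-level output metrics are reported on task end, and after this commit they are populated for streaming file sinks too. A hedged sketch with illustrative names (only the listener API comes from the test):

```scala
import java.util.concurrent.atomic.LongAdder

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

val recordsWritten = new LongAdder
val bytesWritten = new LongAdder

// Accumulate the per-task output metrics; before SPARK-23288 these stayed at
// zero for batches written by the streaming file sink.
spark.sparkContext.addSparkListener(new SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    val om = taskEnd.taskMetrics.outputMetrics
    recordsWritten.add(om.recordsWritten)
    bytesWritten.add(om.bytesWritten)
  }
})

// ... run a streaming query with a file sink, then inspect:
// println(s"records=${recordsWritten.sum()} bytes=${bytesWritten.sum()}")
```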
