Skip to content

Commit 8ad4ddd

Browse files
Lingkai Kong, HyukjinKwon
authored and committed
[SPARK-44776][CONNECT] Add ProducedRowCount to SparkListenerConnectOperationFinished
### What changes were proposed in this pull request? Add a `producedRowCount` field to `SparkListenerConnectOperationFinished`. ### Why are the changes needed? The field is needed to show the number of rows produced by an operation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added a unit test. Closes #42454 from gjxdxh/SPARK-44776. Authored-by: Lingkai Kong <lingkai.kong@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org> (cherry picked from commit 4646991) Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 83556c4 commit 8ad4ddd

File tree

5 files changed

+238
-101
lines changed

5 files changed

+238
-101
lines changed

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
110110
errorOnDuplicatedFieldNames = false)
111111

112112
var numSent = 0
113+
var totalNumRows: Long = 0
113114
def sendBatch(bytes: Array[Byte], count: Long): Unit = {
114115
val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
115116
val batch = proto.ExecutePlanResponse.ArrowBatch
@@ -120,14 +121,15 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
120121
response.setArrowBatch(batch)
121122
responseObserver.onNext(response.build())
122123
numSent += 1
124+
totalNumRows += count
123125
}
124126

125127
dataframe.queryExecution.executedPlan match {
126128
case LocalTableScanExec(_, rows) =>
127-
executePlan.eventsManager.postFinished()
128129
converter(rows.iterator).foreach { case (bytes, count) =>
129130
sendBatch(bytes, count)
130131
}
132+
executePlan.eventsManager.postFinished(Some(totalNumRows))
131133
case _ =>
132134
SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
133135
val rows = dataframe.queryExecution.executedPlan.execute()
@@ -162,8 +164,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
162164
resultFunc = () => ())
163165
// Collect errors and propagate them to the main thread.
164166
.andThen {
165-
case Success(_) =>
166-
executePlan.eventsManager.postFinished()
167+
case Success(_) => // do nothing
167168
case Failure(throwable) =>
168169
signal.synchronized {
169170
error = Some(throwable)
@@ -200,8 +201,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
200201
currentPartitionId += 1
201202
}
202203
ThreadUtils.awaitReady(future, Duration.Inf)
204+
executePlan.eventsManager.postFinished(Some(totalNumRows))
203205
} else {
204-
executePlan.eventsManager.postFinished()
206+
executePlan.eventsManager.postFinished(Some(totalNumRows))
205207
}
206208
}
207209
}

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2513,7 +2513,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
25132513
.putAllArgs(getSqlCommand.getArgsMap)
25142514
.addAllPosArgs(getSqlCommand.getPosArgsList)))
25152515
}
2516-
executeHolder.eventsManager.postFinished()
2516+
executeHolder.eventsManager.postFinished(Some(rows.size))
25172517
// Exactly one SQL Command Result Batch
25182518
responseObserver.onNext(
25192519
ExecutePlanResponse

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteEventsManager.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
7575

7676
private var canceled = Option.empty[Boolean]
7777

78+
private var producedRowCount = Option.empty[Long]
79+
7880
/**
7981
* @return
8082
* Last event posted by the Connect request
@@ -95,6 +97,13 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
9597
*/
9698
private[connect] def hasError: Option[Boolean] = error
9799

100+
/**
101+
* @return
102+
* How many rows the Connect request has produced @link
103+
* org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished
104+
*/
105+
private[connect] def getProducedRowCount: Option[Long] = producedRowCount
106+
98107
/**
99108
* Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationStarted.
100109
*/
@@ -192,13 +201,23 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
192201

193202
/**
194203
* Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished.
204+
* @param producedRowsCountOpt
205+
* Number of rows that are returned to the user. None is expected when the operation does not
206+
* return any rows.
195207
*/
196-
def postFinished(): Unit = {
208+
def postFinished(producedRowsCountOpt: Option[Long] = None): Unit = {
197209
assertStatus(
198210
List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution),
199211
ExecuteStatus.Finished)
212+
producedRowCount = producedRowsCountOpt
213+
200214
listenerBus
201-
.post(SparkListenerConnectOperationFinished(jobTag, operationId, clock.getTimeMillis()))
215+
.post(
216+
SparkListenerConnectOperationFinished(
217+
jobTag,
218+
operationId,
219+
clock.getTimeMillis(),
220+
producedRowCount))
202221
}
203222

204223
/**
@@ -395,13 +414,17 @@ case class SparkListenerConnectOperationFailed(
395414
* 36 characters UUID assigned by Connect during a request.
396415
* @param eventTime:
397416
* The time in ms when the event was generated.
417+
* @param producedRowCount:
418+
* Number of rows that are returned to the user. None is expected when the operation does not
419+
* return any rows.
398420
* @param extraTags:
399421
* Additional metadata during the request.
400422
*/
401423
case class SparkListenerConnectOperationFinished(
402424
jobTag: String,
403425
operationId: String,
404426
eventTime: Long,
427+
producedRowCount: Option[Long] = None,
405428
extraTags: Map[String, String] = Map.empty)
406429
extends SparkListenerEvent
407430

0 commit comments

Comments
 (0)