Changes from 1 commit

Initial draft
grundprinzip committed Feb 17, 2024
commit 084d257d215cf3c415a65d5274ce51e76f645cab

@@ -229,6 +229,16 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}

test("progress is available for the spark result") {
val result = spark
.range(10000)
.repartition(1000)
.collectResult()
assert(result.length == 10000)
assert(result.progress.totalTasks > 100)
assert(result.progress.completedTasks > 100)
}

test("interrupt operation") {
val session = spark
import session.implicits._
@@ -321,7 +321,7 @@ message ExecutePlanRequest {

// The response of a query, can be one or more for each request. Responses belonging to the
// same input query, carry the same `session_id`.
// Next ID: 16
// Next ID: 17
message ExecutePlanResponse {
string session_id = 1;
// Server-side generated idempotency key that the client can use to assert that the server side
@@ -360,6 +360,9 @@ message ExecutePlanResponse {
// Response type informing if the stream is complete in reattachable execution.
ResultComplete result_complete = 14;

// (Optional) Intermediate query progress reports.
ExecutionProgress execution_progress = 16;

// Support arbitrary result objects.
google.protobuf.Any extension = 999;
}
@@ -420,6 +423,15 @@ message ExecutePlanResponse {
// the execution is complete. If the server sends onComplete without sending a ResultComplete,
// it means that there is more, and the client should use ReattachExecute RPC to continue.
}

// This message is used to communicate progress about the query execution.
message ExecutionProgress {
int64 num_tasks = 1;
int64 num_completed_tasks = 2;
Contributor: is this for the current running stage or all stages?

Contributor Author: Across all stages. It can always be extended later.

Contributor: I'm wondering how this can be accurate. With AQE we never know the number of partitions for the next stage, as re-optimization can happen.

Contributor Author: The goal of the progress metrics is not to be accurate into the future but only to represent a snapshot of the current state. This means that the number of tasks can be updated when new stages are added or AQE kicks in.

The point is that the number of remaining tasks will converge over time and become stable.

Contributor Author: progres

Contributor: (Just my 2c: I think having any progress bar is much better than none. The standard Spark progress bar has some ups and some downs, definitely having new progress bars appear isn't the most intuitive either. I think it's probably net better than one progress bar that gets longer, but I would much prefer having some progress bar now that we can extend later, perhaps as we get a better sense of how to incorporate AQE and future stages into the UX.)

Contributor: After a second thought, it's better to hide Spark internals (stages) from end users, and eventually we should only have one progress bar for the query. So the current PR is a good starting point.

However, this server-client protocol needs to be stable and we don't want to change the client frequently to improve the progress reporting. Can we define a minimum set of information we need to send to the client side to display the progress bar? I feel it's better to calculate the percentage at the server side.

Contributor Author: So I refactored the code to avoid closing any doors. I did not change the way the progress bar is displayed. However, I extended the progress message to capture the stage-wise information so other clients can decide independently how to present the information to the end user.

Contributor: +1 @cloud-fan what do you think about that? Capture stage-level info in the proto, but keep the display simple for now?

Contributor: yea this is more flexible. The proto message contains all the information and clients can do whatever they want.

int64 num_stages = 3;
int64 num_completed_stages = 4;
int64 input_bytes_read = 5;
}
}

// The key-value pair for the config request and response.
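To make the review discussion above concrete, here is a minimal client-side sketch (not part of this change) that turns an ExecutionProgress message into a single display string. It only uses the generated getters that already appear elsewhere in this diff; the rendering format itself is an illustrative assumption.

import org.apache.spark.connect.proto.ExecutePlanResponse

// Derive a one-line progress summary from a single snapshot. getNumTasks is only a snapshot of
// the current state and may still grow (for example when AQE adds new stages), so the percentage
// can temporarily move backwards.
def renderProgress(p: ExecutePlanResponse.ExecutionProgress): String = {
  val total = math.max(p.getNumTasks, 1L) // avoid division by zero before any stage is known
  val percent = p.getNumCompletedTasks * 100.0 / total
  f"${p.getNumCompletedTasks}/${p.getNumTasks} tasks, " +
    f"${p.getNumCompletedStages}/${p.getNumStages} stages, " +
    f"${p.getInputBytesRead} bytes read (${percent}%.1f%%)"
}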
@@ -40,6 +40,17 @@ private[sql] class SparkResult[T](
timeZoneId: String)
extends AutoCloseable { self =>

/**
* Progress of the query execution. This information can be accessed from the iterator.
*/
case class Progress (
totalTasks: Long = 0,
completedTasks: Long = 0,
totalStages: Long = 0,
completedStages: Long = 0,
inputBytesRead: Long = 0)

var progress: Progress = new Progress()
private[this] var opId: String = _
private[this] var numRecords: Int = 0
private[this] var structType: StructType = _
@@ -97,6 +108,17 @@
}
stop |= stopOnOperationId

// Update the execution progress. This information can now be accessed directly from
// the iterator.
if (response.hasExecutionProgress) {
progress = Progress(
response.getExecutionProgress.getNumTasks,
response.getExecutionProgress.getNumCompletedTasks,
response.getExecutionProgress.getNumStages,
response.getExecutionProgress.getNumCompletedStages,
response.getExecutionProgress.getInputBytesRead)
}

if (response.hasSchema) {
// The original schema should arrive before ArrowBatches.
structType =
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.execution

import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd}

/**
* A listener that tracks the execution of jobs and stages for a given set of tags.
* This is used to track the progress of a job that is being executed through the connect API.
*
* The listener is instantiated once for the SparkConnectService and then used to track all the
* current query executions.
*/
private[connect] class ConnectProgressExecutionListener extends SparkListener with Logging {
/**
* A tracker for a given tag. This is used to track the progress of an operation that is being
* executed through the connect API.
*/
class ExecutionTracker(var tag: String) {
private[ConnectProgressExecutionListener] var jobs: Set[Int] = Set()
private[ConnectProgressExecutionListener] var stages: Set[Int] = Set()
private[ConnectProgressExecutionListener] var totalTasks = 0
private[ConnectProgressExecutionListener] var completedTasks = 0
private[ConnectProgressExecutionListener] var completedStages = 0
private[ConnectProgressExecutionListener] var inputBytesRead = 0L
// The tracker is marked as dirty if it has new progress to report. This variable does not
// need to be protected by a mutex: even if multiple threads read the same dirty state, the
// output is expected to be identical.
@volatile private[ConnectProgressExecutionListener] var dirty = false

/**
* Yield the current state of the tracker if it is dirty. A consumer of the tracker can provide
* a callback that will be called with the current state of the tracker if the tracker has new
* progress to report.
*
* If the tracker was marked as dirty, the dirty flag is reset afterwards.
*/
def yieldWhenDirty(thunk: (Int, Int, Int, Int, Long) => Unit): Unit = {
if (dirty) {
thunk(totalTasks, completedTasks, stages.size, completedStages, inputBytesRead)
dirty = false
}
}

/**
* Add a job to the tracker. This will add the job to the list of jobs that are being tracked,
* together with its stages and their task counts.
*/
def addJob(job: SparkListenerJobStart): Unit = {
jobs = jobs + job.jobId
stages = stages ++ job.stageIds
totalTasks += job.stageInfos.map(_.numTasks).sum
}
}

val trackedTags = collection.mutable.Map[String, ExecutionTracker]()

override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
val tags = jobStart.properties.getProperty("spark.job.tags")
if (tags != null) {
val thisJobTags = tags.split(",").map(_.trim).toSet
thisJobTags.foreach { tag =>
if (trackedTags.contains(tag)) {
trackedTags(tag).addJob(jobStart)
}
}
}
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
// Check if the task belongs to a job that we are tracking.
trackedTags.foreach({ case (tag, tracker) =>
if (tracker.stages.contains(taskEnd.stageId)) {
tracker.completedTasks += 1
tracker.inputBytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
tracker.dirty = true
}
})
}

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
trackedTags.foreach({ case (tag, tracker) =>
if (tracker.stages.contains(stageCompleted.stageInfo.stageId)) {
tracker.completedStages += 1
}
})
}

override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
trackedTags.foreach({ case (tag, tracker) =>
if (tracker.jobs.contains(jobEnd.jobId)) {
tracker.jobs -= jobEnd.jobId
}
})
}

def registerJobTag(tag: String): Unit = {
trackedTags += tag -> new ExecutionTracker(tag)
}

def removeJobTag(tag: String): Unit = {
trackedTags -= tag
}

def clearJobTags(): Unit = {
trackedTags.clear()
}

}
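As a usage sketch of the listener above (assumptions: an existing SparkContext named sc and a made-up tag string; the real wiring is shown in SparkConnectService and ExecuteGrpcResponseSender further down in this diff):

// sc is an existing SparkContext (assumption for this sketch); the query itself must run with
// sc.addJobTag(...) set to the same tag, as ExecuteThreadRunner does below.
val progressListener = new ConnectProgressExecutionListener
sc.addSparkListener(progressListener)
progressListener.registerJobTag("SparkConnect_OperationTag_example") // hypothetical tag value

// Later, for example from the loop that streams responses to the client, poll for new progress.
progressListener.trackedTags.get("SparkConnect_OperationTag_example").foreach { tracker =>
  tracker.yieldWhenDirty { (tasks, completedTasks, stages, completedStages, bytesRead) =>
    println(s"$completedTasks/$tasks tasks, $completedStages/$stages stages, $bytesRead bytes read")
  }
}

// Once the operation finishes, drop the tracker again.
progressListener.removeJobTag("SparkConnect_OperationTag_example")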
@@ -21,6 +21,7 @@ import com.google.protobuf.Message
import io.grpc.stub.{ServerCallStreamObserver, StreamObserver}

import org.apache.spark.{SparkEnv, SparkSQLException}
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.config.Connect.{CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION, CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_SIZE}
@@ -131,6 +132,38 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
}
}

/**
* This method is called repeatedly during the query execution to enqueue a new message to be sent
* to the client about the current query progress. The message is not sent directly to the client,
* but rather enqueued in the response observer.
*/
private def enqueueProgressMessage(): Unit = {
SparkConnectService.executionListener.foreach { listener =>
if (listener.trackedTags.contains(executeHolder.jobTag)) {
val tracker = listener.trackedTags(executeHolder.jobTag)
// Only send progress message if there is something new to report.
tracker.yieldWhenDirty { (tasks, tasksCompleted, stages, stagesCompleted, inputBytesRead) =>
val response = ExecutePlanResponse
.newBuilder()
.setExecutionProgress(
ExecutePlanResponse.ExecutionProgress
.newBuilder()
.setInputBytesRead(inputBytesRead)
.setNumTasks(tasks)
.setNumCompletedTasks(tasksCompleted)
.setNumCompletedStages(stagesCompleted)
.setNumStages(stages)
)
.build()
// There is a special case when the response observer has already determined
// that the final message has been sent (and the stream will be closed) but we might still want
// to send the progress message. In this case we ignore the result of the `tryOnNext` call.
executeHolder.responseObserver.tryOnNext(response)
}
}
}
}

/**
* Attach to the executionObserver, consume responses from it, and send them to grpcObserver.
*
@@ -173,6 +206,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
var sentResponsesSize: Long = 0

while (!finished) {
enqueueProgressMessage()
var response: Option[CachedStreamResponse[T]] = None

// Conditions for exiting the inner loop (and helpers to compute them):
@@ -201,9 +235,11 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
// The state of interrupted, response and lastIndex are changed under executionObserver
// monitor, and will notify upon state change.
if (response.isEmpty) {
val timeout = Math.max(1, deadlineTimeMillis - System.currentTimeMillis())
// Wake up more frequently to send the progress updates.
val timeout = 2000
logTrace(s"Wait for response to become available with timeout=$timeout ms.")
executionObserver.responseLock.wait(timeout)
enqueueProgressMessage()
logTrace(s"Reacquired executionObserver lock after waiting.")
sleepEnd = System.nanoTime()
}
@@ -228,6 +264,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
s"waitingForResults=${consumeSleep}ns waitingForSend=${sendSleep}ns")
throw new SparkSQLException(errorClass = "INVALID_CURSOR.DISCONNECTED", Map.empty)
} else if (gotResponse) {
enqueueProgressMessage()
// There is a response available to be sent.
val sent = sendResponse(response.get, deadlineTimeMillis)
if (sent) {
@@ -240,6 +277,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
assert(deadlineLimitReached || interrupted)
}
} else if (streamFinished) {
enqueueProgressMessage()
// Stream is finished and all responses have been sent
logInfo(
s"Stream finished for opId=${executeHolder.operationId}, " +
@@ -107,9 +107,9 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
0
}

def onNext(r: T): Unit = responseLock.synchronized {
def tryOnNext(r: T): Boolean = responseLock.synchronized {
if (finalProducedIndex.nonEmpty) {
throw new IllegalStateException("Stream onNext can't be called after stream completed")
return false
}
lastProducedIndex += 1
val processedResponse = setCommonResponseFields(r)
@@ -127,6 +127,13 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
s"Execution opId=${executeHolder.operationId} produced response " +
s"responseId=${responseId} idx=$lastProducedIndex")
responseLock.notifyAll()
true
}

def onNext(r: T): Unit = {
if (!tryOnNext(r)) {
throw new IllegalStateException("Stream onNext can't be called after stream completed")
}
}

def onError(t: Throwable): Unit = responseLock.synchronized {
@@ -28,7 +28,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag}
import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService}
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.util.Utils

@@ -123,6 +123,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
}
} finally {
executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag)
SparkConnectService.executionListener.foreach(_.removeJobTag(executeHolder.jobTag))
executeHolder.sparkSessionTags.foreach { tag =>
executeHolder.sessionHolder.session.sparkContext.removeJobTag(
ExecuteSessionTag(
@@ -158,6 +159,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends

// Set tag for query cancellation
session.sparkContext.addJobTag(executeHolder.jobTag)
// Register the job for progress reports.
SparkConnectService.executionListener.foreach(_.registerJobTag(executeHolder.jobTag))
// Also set all user defined tags as Spark Job tags.
executeHolder.sparkSessionTags.foreach { tag =>
session.sparkContext.addJobTag(
@@ -38,6 +38,7 @@ import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.status.ElementTrackingStore
@@ -284,6 +285,7 @@ object SparkConnectService extends Logging {

private[connect] var uiTab: Option[SparkConnectServerTab] = None
private[connect] var listener: SparkConnectServerListener = _
private[connect] var executionListener: Option[ConnectProgressExecutionListener] = None

// For testing purpose, it's package level private.
private[connect] def localPort: Int = {
@@ -325,6 +327,9 @@
} else {
None
}
// Add the execution listener needed for query progress.
executionListener = Some(new ConnectProgressExecutionListener)
sc.addSparkListener(executionListener.get)
}

/**