@@ -229,6 +229,16 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
assert(interrupted.length == 2, s"Interrupted operations: $interrupted.")
}

test("progress is available for the spark result") {
val result = spark
.range(10000)
.repartition(1000)
.collectResult()
assert(result.length == 10000)
assert(result.progress.stages.map(_.numTasks).sum > 100)
assert(result.progress.stages.map(_.completedTasks).sum > 100)
}

test("interrupt operation") {
val session = spark
import session.implicits._
@@ -333,7 +333,7 @@ message ExecutePlanRequest {

// The response of a query; there can be one or more responses for each request. Responses
// belonging to the same input query carry the same `session_id`.
// Next ID: 16
// Next ID: 17
message ExecutePlanResponse {
string session_id = 1;
// Server-side generated idempotency key that the client can use to assert that the server side
@@ -378,6 +378,9 @@ message ExecutePlanResponse {
// Response for command that creates ResourceProfile.
CreateResourceProfileCommandResult create_resource_profile_command_result = 17;

// (Optional) Intermediate query progress reports.
ExecutionProgress execution_progress = 18;

// Support arbitrary result objects.
google.protobuf.Any extension = 999;
}
@@ -438,6 +441,23 @@ message ExecutePlanResponse {
// the execution is complete. If the server sends onComplete without sending a ResultComplete,
// it means that there is more, and the client should use ReattachExecute RPC to continue.
}

// This message is used to communicate the progress of the query during execution.
message ExecutionProgress {
// Captures the progress of each individual stage.
repeated StageInfo stages = 1;

// The number of tasks that are currently in flight.
int64 num_inflight_tasks = 2;

message StageInfo {
int64 stage_id = 1;
int64 num_tasks = 2;
int64 num_completed_tasks = 3;
int64 input_bytes_read = 4;
bool done = 5;
}
}
}

// The key-value pair for the config request and response.
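Since ExecutionProgress is just a repeated set of per-stage counters, a client can fold it into a one-line status. A minimal sketch against the generated protobuf bindings (the renderProgress helper is a hypothetical name, not part of this change; the getters match the ones used in SparkResult below):

import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto.ExecutePlanResponse

// Hypothetical helper: summarize an ExecutionProgress message as a status line.
def renderProgress(p: ExecutePlanResponse.ExecutionProgress): String = {
  val stages = p.getStagesList.asScala
  val total = stages.map(_.getNumTasks).sum
  val done = stages.map(_.getNumCompletedTasks).sum
  val bytesRead = stages.map(_.getInputBytesRead).sum
  s"$done/$total tasks completed, ${p.getNumInflightTasks} in flight, " +
    s"$bytesRead input bytes read"
}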
@@ -20,6 +20,7 @@ import java.lang.ref.Cleaner
import java.util.Objects

import scala.collection.mutable
import scala.jdk.CollectionConverters._

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
@@ -40,6 +41,38 @@ private[sql] class SparkResult[T](
timeZoneId: String)
extends AutoCloseable { self =>

case class StageInfo(
stageId: Long,
numTasks: Long,
completedTasks: Long = 0,
inputBytesRead: Long = 0,
completed: Boolean = false)

object StageInfo {
def apply(stageInfo: proto.ExecutePlanResponse.ExecutionProgress.StageInfo): StageInfo = {
StageInfo(
stageInfo.getStageId,
stageInfo.getNumTasks,
stageInfo.getNumCompletedTasks,
stageInfo.getInputBytesRead,
stageInfo.getDone)
}
}

object Progress {
def apply(progress: proto.ExecutePlanResponse.ExecutionProgress): Progress = {
Progress(
progress.getStagesList.asScala.map(StageInfo(_)).toSeq,
progress.getNumInflightTasks)
}
}

/**
* Progress of the query execution. This information can be accessed from the iterator.
*/
case class Progress(stages: Seq[StageInfo], inflight: Long)

var progress: Progress = Progress(Seq.empty, 0)
private[this] var opId: String = _
private[this] var numRecords: Int = 0
private[this] var structType: StructType = _
@@ -97,6 +130,12 @@
}
}
stop |= stopOnOperationId

// Update the execution status. This information can now be accessed directly from
// the iterator.
if (response.hasExecutionProgress) {
progress = Progress(response.getExecutionProgress)
}

if (response.hasSchema) {
// The original schema should arrive before ArrowBatches.
structType =
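Because progress is overwritten on every response that carries execution_progress, callers can observe it while the result is still streaming. A rough usage sketch, assuming a Spark Connect session named spark (mirroring the new E2E test above):

// Sketch: poll SparkResult.progress while the result is being consumed.
val result = spark.range(1000000).repartition(100).collectResult()
result.iterator.foreach { _ =>
  val p = result.progress
  val total = p.stages.map(_.numTasks).sum
  if (total > 0) {
    val done = p.stages.map(_.completedTasks).sum
    println(s"progress: $done/$total tasks completed, ${p.inflight} tasks in flight")
  }
}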
@@ -265,4 +265,12 @@ object Connect {
.version("4.0.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(1024)

val CONNECT_PROGRESS_REPORT_INTERVAL =
buildConf("spark.connect.progress.reportInterval")
.doc("The interval at which the progress of a query is reported to the client." +
" If the value is set to a negative value the progress reports will be disabled.")
.version("4.0.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("2s")
}
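The interval is an ordinary time conf, so a deployment can tighten it or switch reporting off. A sketch of both, assuming it can be passed like any other conf when the server-side session is created (the 500ms value is arbitrary):

import org.apache.spark.sql.SparkSession

// Sketch: report progress every 500ms instead of the 2s default.
val session = SparkSession
  .builder()
  .config("spark.connect.progress.reportInterval", "500ms")
  // ...or disable progress reports entirely with a negative value:
  // .config("spark.connect.progress.reportInterval", "-1")
  .getOrCreate()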
@@ -0,0 +1,191 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.execution

import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}

/**
* A listener that tracks the execution of jobs and stages for a given set of tags. This is used
* to track the progress of a job that is being executed through the connect API.
*
* The listener is instantiated once for the SparkConnectService and then used to track all the
* current query executions.
*/
private[connect] class ConnectProgressExecutionListener extends SparkListener with Logging {

/**
* A tracker for a given tag. This is used to track the progress of an operation that is
* being executed through the connect API.
*/
class ExecutionTracker(val tag: String) {

class StageInfo(
val stageId: Int,
var numTasks: Int,
var completedTasks: Int = 0,
var inputBytesRead: Long = 0,
var completed: Boolean = false) {

val lock = new Object
def update(i: StageInfo => Unit): Unit = {
lock.synchronized {
i(this)
}
}

def toProto(): ExecutePlanResponse.ExecutionProgress.StageInfo = {
ExecutePlanResponse.ExecutionProgress.StageInfo
.newBuilder()
.setStageId(stageId)
.setNumTasks(numTasks)
.setNumCompletedTasks(completedTasks)
.setInputBytesRead(inputBytesRead)
.setDone(completed)
.build()
}
}

// The set of jobs that are being tracked by this tracker. We only ever add to this set,
// never remove from it, to avoid concurrency issues.
private[ConnectProgressExecutionListener] var jobs: Set[Int] = Set()
// The set of stages that are being tracked by this tracker. We only ever add to this map,
// never remove from it, to avoid concurrency issues.
private[ConnectProgressExecutionListener] var stages: Map[Int, StageInfo] = Map.empty
// The tracker is marked as dirty if it has new progress to report.
private[ConnectProgressExecutionListener] val dirty = new AtomicBoolean(false)
// Tracks all currently running tasks for a particular tracker.
private[ConnectProgressExecutionListener] val inFlightTasks = new AtomicInteger(0)

/**
* Yield the current state of the tracker if it is dirty. A consumer of the tracker can
* provide a callback that will be called with the current state of the tracker if the tracker
* has new progress to report.
*
* If the tracker was marked as dirty, the dirty flag is reset afterwards.
*/
def yieldWhenDirty(thunk: (Seq[StageInfo], Long) => Unit): Unit = {
if (dirty.get()) {
thunk(stages.values.toSeq, inFlightTasks.get())
dirty.set(false)
}
}

/**
* Add a job to the tracker. This will add the job to the list of jobs that are being tracked
*/
def addJob(job: SparkListenerJobStart): Unit = synchronized {
jobs = jobs + job.jobId
job.stageInfos.foreach { stage =>
stages = stages + (stage.stageId -> new StageInfo(stage.stageId, stage.numTasks))
}
dirty.set(true)
}

def jobCount(): Int = {
jobs.size
}

def stageCount(): Int = {
stages.size
}
}

val trackedTags = collection.concurrent.TrieMap[String, ExecutionTracker]()

override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
val tags = jobStart.properties.getProperty("spark.job.tags")
if (tags != null) {
val thisJobTags = tags.split(",").map(_.trim).toSet
thisJobTags.foreach { tag =>
trackedTags.get(tag).foreach { tracker =>
tracker.addJob(jobStart)
}
}
}
}

override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
// Check if the task belongs to a job that we are tracking.
trackedTags.foreach({ case (_, tracker) =>
if (tracker.stages.contains(taskStart.stageId)) {
tracker.inFlightTasks.incrementAndGet()
tracker.dirty.set(true)
}
})
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
// Check if the task belongs to a job that we are tracking.
trackedTags.foreach({ case (_, tracker) =>
if (tracker.stages.contains(taskEnd.stageId)) {
tracker.stages.get(taskEnd.stageId).foreach { stage =>
stage.update { i =>
i.completedTasks += 1
i.inputBytesRead += taskEnd.taskMetrics.inputMetrics.bytesRead
}
}
// This should never become negative; simply reset it to zero if it does.
tracker.inFlightTasks.decrementAndGet()
if (tracker.inFlightTasks.get() < 0) {
tracker.inFlightTasks.set(0)
}
tracker.dirty.set(true)
}
})
}

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
trackedTags.foreach({ case (_, tracker) =>
if (tracker.stages.contains(stageCompleted.stageInfo.stageId)) {
tracker.stages(stageCompleted.stageInfo.stageId).update { stage =>
stage.completed = true
}
tracker.dirty.set(true)
}
})
}

override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
trackedTags.foreach({ case (_, tracker) =>
if (tracker.jobs.contains(jobEnd.jobId)) {
tracker.dirty.set(true)
[Review thread]
Contributor: why do we set the dirty flag when nothing is updated?
Contributor Author: This is mostly to make sure that all progress is reported and an update is sent to the client. If you're tracking time between progress messages, every message itself is progress.
}
})
}

def tryGetTracker(tag: String): Option[ExecutionTracker] = {
trackedTags.get(tag)
}

def registerJobTag(tag: String): Unit = {
trackedTags += tag -> new ExecutionTracker(tag)
}

def removeJobTag(tag: String): Unit = {
trackedTags -= tag
}

def clearJobTags(): Unit = {
trackedTags.clear()
}

}
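Putting it together, the service side registers the listener once, tags each execution, and periodically drains dirty trackers into ExecutionProgress messages. A sketch of that flow; only the listener methods shown exist in this file, while the tag value, the addJobTag call site, and the send step are assumptions:

import scala.jdk.CollectionConverters._

import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.sql.SparkSession

// Sketch: wire the listener into a periodic progress report for one operation.
val listener = new ConnectProgressExecutionListener
SparkSession.active.sparkContext.addSparkListener(listener)

val tag = "spark-connect-operation-123" // hypothetical operation tag
listener.registerJobTag(tag)
try {
  // ... run the plan with `tag` attached (e.g. via SparkContext.addJobTag) ...
  listener.tryGetTracker(tag).foreach { tracker =>
    tracker.yieldWhenDirty { (stages, inFlightTasks) =>
      val progress = ExecutePlanResponse.ExecutionProgress
        .newBuilder()
        .addAllStages(stages.map(_.toProto()).asJava)
        .setNumInflightTasks(inFlightTasks)
        .build()
      // attach `progress` to the next ExecutePlanResponse streamed to the client
    }
  }
} finally {
  listener.removeJobTag(tag)
}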