@@ -27,6 +27,20 @@ private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable {
}
}

private[sql] abstract class WrappedCloseableIterator[E] extends CloseableIterator[E] {

def innerIterator: Iterator[E]

override def next(): E = innerIterator.next()

override def hasNext(): Boolean = innerIterator.hasNext

override def close(): Unit = innerIterator match {
case it: CloseableIterator[E] => it.close()
case _ => // nothing
}
}
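
Review sketch: the wrapper centralizes the delegate-and-close pattern that several iterators in this PR now reuse. A minimal standalone rendition (toy trait mirroring this file; the demo source is invented):

```scala
// Standalone sketch of the delegation pattern above; toy types, not Spark's build.
trait CloseableIterator[E] extends Iterator[E] with AutoCloseable

abstract class WrappedCloseableIterator[E] extends CloseableIterator[E] {
  def innerIterator: Iterator[E]
  override def next(): E = innerIterator.next()
  override def hasNext: Boolean = innerIterator.hasNext
  override def close(): Unit = innerIterator match {
    case c: CloseableIterator[E] => c.close() // propagate close to a closeable inner iterator
    case _ => // a plain iterator holds no resources; nothing to close
  }
}

object WrapDemo extends App {
  // Hold the source in a val: innerIterator is a def, and returning a fresh
  // Iterator(...) on every call would silently restart iteration.
  private val source = Iterator(1, 2, 3)
  val it = new WrappedCloseableIterator[Int] { override def innerIterator: Iterator[Int] = source }
  println(it.toList) // List(1, 2, 3)
  it.close() // no-op here: source is not a CloseableIterator
}
```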

private[sql] object CloseableIterator {

/**
@@ -35,12 +49,8 @@ private[sql] object CloseableIterator {
def apply[T](iterator: Iterator[T]): CloseableIterator[T] = iterator match {
case closeable: CloseableIterator[T] => closeable
case _ =>
new CloseableIterator[T] {
override def next(): T = iterator.next()

override def hasNext(): Boolean = iterator.hasNext

override def close() = { /* empty */ }
new WrappedCloseableIterator[T] {
override def innerIterator = iterator
}
}
}
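
The factory's behavior is unchanged by this refactor: identity on an already-closeable iterator, no-op close otherwise. A usage note reusing the toy types from the sketch above (the factory name here is a stand-in):

```scala
// Reuses the toy CloseableIterator/WrappedCloseableIterator from the sketch above.
object ToyCloseableIterator {
  def apply[T](iterator: Iterator[T]): CloseableIterator[T] = iterator match {
    case closeable: CloseableIterator[T] => closeable // avoid double-wrapping
    case _ =>
      new WrappedCloseableIterator[T] { override def innerIterator: Iterator[T] = iterator }
  }
}

val once = ToyCloseableIterator(Iterator("a", "b")) // wrapped; close() is a no-op
assert(ToyCloseableIterator(once) eq once)          // already closeable: passed through unchanged
```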
@@ -22,7 +22,7 @@ import io.grpc.ManagedChannel

import org.apache.spark.connect.proto._

private[client] class CustomSparkConnectBlockingStub(
private[connect] class CustomSparkConnectBlockingStub(
channel: ManagedChannel,
retryPolicy: GrpcRetryHandler.RetryPolicy) {

@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.client

import java.util.UUID

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import io.grpc.{ManagedChannel, StatusRuntimeException}
@@ -50,7 +51,7 @@ class ExecutePlanResponseReattachableIterator(
request: proto.ExecutePlanRequest,
channel: ManagedChannel,
retryPolicy: GrpcRetryHandler.RetryPolicy)
extends CloseableIterator[proto.ExecutePlanResponse]
extends WrappedCloseableIterator[proto.ExecutePlanResponse]
with Logging {

val operationId = if (request.hasOperationId) {
@@ -86,14 +87,25 @@
// True after ResultComplete message was seen in the stream.
// The server always sends this message at the end of the stream. If the underlying iterator
// finishes without producing one, another iterator needs to be reattached.
private var resultComplete: Boolean = false
// Visible for testing.
private[connect] var resultComplete: Boolean = false

// Initial iterator comes from ExecutePlan request.
// Note: This is not retried, because no error would ever be thrown here; gRPC only
// throws an error on the first iter.hasNext() or iter.next().
private var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] =
// Visible for testing.
private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] =
Some(rawBlockingStub.executePlan(initialRequest))

override def innerIterator: Iterator[proto.ExecutePlanResponse] = iter match {
case Some(it) => it.asScala
case None =>
// The iterator is only unset for short moments while a retry exception is thrown.
// That should only happen in the middle of internal processing. Since this iterator is not
// thread-safe, no one should be accessing it at that moment.
throw new IllegalStateException("innerIterator unset")
}
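
For orientation, a compressed standalone model of the Option-held inner iterator: None only while a failed stream is being swapped out, and innerIterator fails loudly if observed in that window. The real class also reattaches from the last received response and tracks ResultComplete, which this toy omits:

```scala
// Toy model of the transiently-unset inner iterator; reattach logic heavily simplified.
class ReattachingIterator[E](reattach: () => Iterator[E]) extends Iterator[E] {
  // None only for short moments while a broken stream is being replaced.
  private var iter: Option[Iterator[E]] = Some(reattach())

  private def innerIterator: Iterator[E] =
    iter.getOrElse(throw new IllegalStateException("innerIterator unset"))

  override def next(): E = innerIterator.next()

  override def hasNext: Boolean =
    try innerIterator.hasNext
    catch {
      case _: RuntimeException => // the real code retries only specific gRPC errors
        iter = None               // unset while swapping...
        iter = Some(reattach())   // ...then reattach (from the last consumed position, in the real class)
        innerIterator.hasNext
    }
}
```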

override def next(): proto.ExecutePlanResponse = synchronized {
// hasNext will trigger reattach in case the stream completed without resultComplete
if (!hasNext()) {
@@ -43,7 +43,10 @@ private[client] object GrpcExceptionConverter extends JsonUtils {
}

def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = {
new CloseableIterator[T] {
new WrappedCloseableIterator[T] {

override def innerIterator: Iterator[T] = iter

override def hasNext: Boolean = {
convert {
iter.hasNext
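
Sketch of the converting wrapper's shape: every access is routed through convert, so gRPC failures surface as converted exceptions on hasNext/next rather than at construction. The conversion below is a stand-in, not the real GrpcExceptionConverter:

```scala
// Stand-in sketch; the real convert translates StatusRuntimeException into Spark errors.
object ConvertDemo {
  def convert[T](f: => T): T =
    try f
    catch { case e: RuntimeException => throw new RuntimeException(s"converted: ${e.getMessage}", e) }

  def convertIterator[T](iter: Iterator[T]): Iterator[T] = new Iterator[T] {
    override def hasNext: Boolean = convert { iter.hasNext } // wrap both accessors:
    override def next(): T = convert { iter.next() }         // streams fail lazily, on access
  }
}
```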
@@ -48,11 +48,13 @@ private[sql] class GrpcRetryHandler(
* The type of the response.
*/
class RetryIterator[T, U](request: T, call: T => CloseableIterator[U])
extends CloseableIterator[U] {
extends WrappedCloseableIterator[U] {

private var opened = false // we retry only if the first call on the iterator fails
private var iter = call(request)

override def innerIterator: Iterator[U] = iter

private def retryIter[V](f: Iterator[U] => V) = {
if (!opened) {
opened = true
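
The retry wrapper re-issues the underlying call only while the iterator is untouched; once any element has been consumed, retrying would replay the stream from the start. A standalone sketch of that rule, with a fixed attempt count standing in for the PR's RetryPolicy:

```scala
import scala.util.control.NonFatal

// Standalone sketch of first-access-only retry; retry policy reduced to an attempt count.
class RetryIterator[T, U](request: T, call: T => Iterator[U], maxRetries: Int = 3)
    extends Iterator[U] {
  private var opened = false // flips on the first access; no retries afterwards
  private var iter = call(request)

  private def retryIter[V](f: Iterator[U] => V): V = {
    if (!opened) {
      opened = true
      var attempt = 0
      while (true) {
        try {
          return f(iter)
        } catch {
          case NonFatal(_) if attempt < maxRetries =>
            attempt += 1
            iter = call(request) // safe to re-issue: no element was handed out yet
        }
      }
      throw new IllegalStateException("unreachable")
    } else {
      f(iter) // already consuming: a retry would replay the stream from the start
    }
  }

  override def hasNext: Boolean = retryIter(_.hasNext)
  override def next(): U = retryIter(_.next())
}
```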
@@ -47,6 +47,9 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](

private var interrupted = false

// Time at which this sender should finish if the response stream is not finished by then.
private var deadlineTimeMillis = Long.MaxValue

// Signal to wake up when grpcCallObserver.isReady()
private val grpcCallObserverReadySignal = new Object

@@ -65,6 +68,12 @@
executionObserver.notifyAll()
}

// For testing
private[connect] def setDeadline(deadlineMs: Long) = executionObserver.synchronized {
deadlineTimeMillis = deadlineMs
executionObserver.notifyAll()
}

def run(lastConsumedStreamIndex: Long): Unit = {
if (executeHolder.reattachable) {
// In reattachable execution we use setOnReadyHandler and grpcCallObserver.isReady to control
@@ -150,7 +159,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
var finished = false

// Time at which this sender should finish if the response stream is not finished by then.
val deadlineTimeMillis = if (!executeHolder.reattachable) {
deadlineTimeMillis = if (!executeHolder.reattachable) {
Long.MaxValue
} else {
val confSize =
@@ -232,8 +241,8 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
assert(finished == false)
} else {
// If it wasn't sent, time deadline must have been reached before stream became available,
// will exit in the next loop iteration.
assert(deadlineLimitReached)
// or it was interrupted. Will exit in the next loop iteration.
assert(deadlineLimitReached || interrupted)
}
} else if (streamFinished) {
// Stream is finished and all responses have been sent
@@ -301,7 +310,7 @@ private[connect] class ExecuteGrpcResponseSender[T <: Message](
val sleepStart = System.nanoTime()
var sleepEnd = 0L
// Conditions for exiting the inner loop
// 1. was detached
// 1. was interrupted
// 2. grpcCallObserver is ready to send more data
// 3. time deadline is reached
while (!interrupted &&
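
The inner loop now exits on any of three conditions: interruption, observer readiness, or the deadline. A standalone toy of that deadline-bounded wait, including the test-only deadline override this diff adds (monitor and flags are stand-ins for the sender's state):

```scala
// Toy model of the deadline-bounded wait; not the real ExecuteGrpcResponseSender.
object SenderWaitDemo {
  private val lock = new Object
  @volatile private var interrupted = false
  @volatile private var deadlineTimeMillis = Long.MaxValue

  // Mirrors setDeadline in this diff: move the deadline and wake the waiting loop.
  def setDeadline(ms: Long): Unit = lock.synchronized {
    deadlineTimeMillis = ms
    lock.notifyAll()
  }

  def interrupt(): Unit = lock.synchronized {
    interrupted = true
    lock.notifyAll()
  }

  // Returns true only if we became ready; false if interrupted or the deadline passed first.
  def awaitReady(isReady: () => Boolean): Boolean = lock.synchronized {
    // Exit conditions: 1. was interrupted, 2. ready to send more data, 3. deadline reached.
    while (!interrupted && !isReady() && System.currentTimeMillis() < deadlineTimeMillis) {
      val timeoutMs = deadlineTimeMillis - System.currentTimeMillis()
      if (timeoutMs > 0) lock.wait(timeoutMs)
    }
    !interrupted && isReady()
  }
}
```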
@@ -73,11 +73,16 @@ private[connect] class ExecuteResponseObserver[T <: Message](val executeHolder:
/** The index of the last response produced by execution. */
private var lastProducedIndex: Long = 0 // first response will have index 1

// For testing
private[connect] var releasedUntilIndex: Long = 0

/**
* Highest response index that was consumed. Keeps track of it to decide which responses need
* to be cached, and to assert that all responses are consumed.
*
* Visible for testing.
*/
private var highestConsumedIndex: Long = 0
private[connect] var highestConsumedIndex: Long = 0

/**
* Consumer that waits for available responses. There can be only one at a time, @see
@@ -284,6 +289,7 @@
responses.remove(i)
i -= 1
}
releasedUntilIndex = index
}
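
releasedUntilIndex is pure bookkeeping: it records how far the response cache has been trimmed so tests can observe release progress. A toy version of that high-water mark:

```scala
import scala.collection.mutable

// Toy cache trim with a high-water mark mirroring releasedUntilIndex.
class ResponseCache[T] {
  private val responses = mutable.Map.empty[Long, T]
  private var _releasedUntilIndex: Long = 0 // first response has index 1, as in this file
  def releasedUntilIndex: Long = _releasedUntilIndex // observable by tests

  def put(index: Long, value: T): Unit = responses.put(index, value)

  def removeResponsesUntilIndex(index: Long): Unit = {
    ((_releasedUntilIndex + 1) to index).foreach(responses.remove) // drop everything <= index
    _releasedUntilIndex = index
  }
}
```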

/**
@@ -183,6 +183,16 @@ private[connect] class ExecuteHolder(
}
}

// For testing.
private[connect] def setGrpcResponseSendersDeadline(deadlineMs: Long) = synchronized {
grpcResponseSenders.foreach(_.setDeadline(deadlineMs))
}

// For testing
private[connect] def interruptGrpcResponseSenders() = synchronized {
grpcResponseSenders.foreach(_.interrupt())
}

/**
* For a short period in ExecutePlan after creation and until runGrpcResponseSender is called,
* there is no attached response sender, but yet we start with lastAttachedRpcTime = None, so we
@@ -71,15 +71,14 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
// The latter is to prevent double execution when a client retries, thinking the request
// never reached the server, when in fact it did and was already removed as abandoned.
if (executions.get(executeHolder.key).isDefined) {
if (getAbandonedTombstone(executeHolder.key).isDefined) {
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
messageParameters = Map("handle" -> executeHolder.operationId))
} else {
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
messageParameters = Map("handle" -> executeHolder.operationId))
}
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
messageParameters = Map("handle" -> executeHolder.operationId))
}
if (getAbandonedTombstone(executeHolder.key).isDefined) {
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
messageParameters = Map("handle" -> executeHolder.operationId))
}
sessionHolder.addExecuteHolder(executeHolder)
executions.put(executeHolder.key, executeHolder)
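
The restructure decouples the two rejections: a tombstoned key now fails even when no live execution holds it, which is exactly the retried-after-abandon case described in the comment above. Shape of the new flow, with stand-in exception and registries:

```scala
// Stand-in sketch of the reordered checks; exception type and registries are toys.
final case class ToySparkSQLException(errorClass: String, handle: String)
  extends Exception(s"$errorClass: $handle")

def validateNewExecution(
    key: String,
    executions: collection.Set[String],
    abandonedTombstones: collection.Set[String]): Unit = {
  if (executions.contains(key)) {
    // A live execution is rejected first, independent of tombstone state.
    throw ToySparkSQLException("INVALID_HANDLE.OPERATION_ALREADY_EXISTS", key)
  }
  if (abandonedTombstones.contains(key)) {
    // Client retry of an operation that already ran and was removed as abandoned.
    throw ToySparkSQLException("INVALID_HANDLE.OPERATION_ABANDONED", key)
  }
  // otherwise: safe to register the execution
}
```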
@@ -141,12 +140,17 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
abandonedTombstones.asMap.asScala.values.toBuffer.toSeq
}

private[service] def shutdown(): Unit = executionsLock.synchronized {
private[connect] def shutdown(): Unit = executionsLock.synchronized {
scheduledExecutor.foreach { executor =>
executor.shutdown()
executor.awaitTermination(1, TimeUnit.MINUTES)
}
scheduledExecutor = None
executions.clear()
abandonedTombstones.invalidateAll()
if (!lastExecutionTime.isDefined) {
lastExecutionTime = Some(System.currentTimeMillis())
}
}

/**
@@ -188,7 +192,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
executions.values.foreach { executeHolder =>
executeHolder.lastAttachedRpcTime match {
case Some(detached) =>
if (detached + timeout < nowMs) {
if (detached + timeout <= nowMs) {
toRemove += executeHolder
}
case _ => // execution is active
@@ -206,4 +210,18 @@
}
logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.")
}

// For testing.
private[connect] def setAllRPCsDeadline(deadlineMs: Long) = executionsLock.synchronized {
executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
}

// For testing.
private[connect] def interruptAllRPCs() = executionsLock.synchronized {
executions.values.foreach(_.interruptGrpcResponseSenders())
}

private[connect] def listExecuteHolders = executionsLock.synchronized {
executions.values.toBuffer.toSeq
}
}
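
One small design note on listExecuteHolders: toBuffer.toSeq copies under the lock, so callers get a stable snapshot rather than a view tied to the live map. Illustration:

```scala
import scala.collection.mutable

// Snapshot vs. live view over a mutable map's values.
val executions = mutable.Map(1 -> "a", 2 -> "b")
val snapshot = executions.values.toBuffer.toSeq // copied while "holding the lock"
val liveView = executions.values                // still backed by the map
executions.remove(2)
assert(snapshot.size == 2) // snapshot unaffected by later mutation
assert(liveView.size == 1) // live view reflects the removal
```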