Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -2083,6 +2083,11 @@
"Operation not found."
]
},
"SESSION_CHANGED" : {
"message" : [
"The existing Spark server driver instance has restarted. Please reconnect."
]
},
"SESSION_CLOSED" : {
"message" : [
"Session was closed."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ message AnalyzePlanRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 17;

// (Required) User context
UserContext user_context = 2;

Expand Down Expand Up @@ -281,6 +287,12 @@ message ExecutePlanRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 8;

// (Required) User context
//
// user_context.user_id and session_id both identify a unique remote spark session on the
Expand Down Expand Up @@ -443,6 +455,12 @@ message ConfigRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 8;

// (Required) User context
UserContext user_context = 2;

Expand Down Expand Up @@ -536,6 +554,12 @@ message AddArtifactsRequest {
// User context
UserContext user_context = 2;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// Provides optional information about the client sending the request. This field
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
Expand Down Expand Up @@ -630,6 +654,12 @@ message ArtifactStatusesRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 5;

// User context
UserContext user_context = 2;

Expand Down Expand Up @@ -673,6 +703,12 @@ message InterruptRequest {
// The id should be an UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// (Required) User context
UserContext user_context = 2;

Expand Down Expand Up @@ -738,6 +774,12 @@ message ReattachExecuteRequest {
// This must be an id of existing session.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 6;

// (Required) User context
//
// user_context.user_id and session_id both identify a unique remote spark session on the
Expand Down Expand Up @@ -772,6 +814,12 @@ message ReleaseExecuteRequest {
// This must be an id of existing session.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 7;

// (Required) User context
//
// user_context.user_id and session_id both identify a unique remote spark session on the
Expand Down Expand Up @@ -856,6 +904,12 @@ message FetchErrorDetailsRequest {
// The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`.
string session_id = 1;

// (Optional)
//
// Server-side generated idempotency key from the previous responses (if any). Server
// can use this to validate that the server side session has not changed.
optional string client_observed_server_side_session_id = 5;

// User context
UserContext user_context = 2;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,27 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.internal.Logging

// This is common logic to be shared between different stub instances to validate responses as
// seen by the client.
// This is common logic to be shared between different stub instances to keep the server-side
// session id and to validate responses as seen by the client.
class ResponseValidator extends Logging {

// Server side session ID, used to detect if the server side session changed. This is set upon
// receiving the first response from the server. This value is used only for executions that
// do not use server-side streaming.
private var serverSideSessionId: Option[String] = None

// Returns the server side session ID recorded so far (None until one has been stored). Callers
// send it back to the server in follow-up requests so the server can validate its session id
// against the one seen in previous requests.
def getServerSideSessionId: Option[String] = serverSideSessionId

/**
 * Test-only hook that overwrites the stored server side session ID with its current value
 * (or the empty string when none is stored) followed by `suffix`. Tests use it to check that
 * the server validates the session ID.
 */
private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = {
  val current = serverSideSessionId.getOrElse("")
  serverSideSessionId = Some(s"$current$suffix")
}

def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
val response = fn
val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ private[sql] class SparkConnectClient(
// a new client will create a new session ID.
private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString)

/**
 * Hijacks the stored server side session ID with the given suffix. Used for testing to make
 * sure that server is validating the session ID.
 *
 * Delegates to the response validator held by the stub state.
 *
 * @param suffix
 *   text appended to the stored server side session ID
 */
private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = {
  stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
}

private[sql] val artifactManager: ArtifactManager = {
new ArtifactManager(configuration, sessionId, bstub, stub)
}
Expand All @@ -73,6 +81,14 @@ private[sql] class SparkConnectClient(
private[sql] def uploadAllClassFileArtifacts(): Unit =
artifactManager.uploadAllClassFileArtifacts()

/**
 * The server-side session id currently held by the response validator, or `None` when no id
 * has been stored yet (i.e. before the first response was seen).
 */
private def serverSideSessionId: Option[String] =
  stubState.responseValidator.getServerSideSessionId

/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
* @return
Expand All @@ -99,11 +115,11 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.addAllTags(tags.get.toSeq.asJava)
.build()
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
if (configuration.useReattachableExecute) {
bstub.executePlanReattachable(request)
bstub.executePlanReattachable(request.build())
} else {
bstub.executePlan(request)
bstub.executePlan(request.build())
}
}

Expand All @@ -119,8 +135,8 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.setUserContext(userContext)
.build()
bstub.config(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.config(request.build())
}

/**
Expand Down Expand Up @@ -207,8 +223,8 @@ private[sql] class SparkConnectClient(
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
analyze(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
analyze(request.build())
}

private[sql] def interruptAll(): proto.InterruptResponse = {
Expand All @@ -218,8 +234,8 @@ private[sql] class SparkConnectClient(
.setSessionId(sessionId)
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_ALL)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def interruptTag(tag: String): proto.InterruptResponse = {
Expand All @@ -230,8 +246,8 @@ private[sql] class SparkConnectClient(
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG)
.setOperationTag(tag)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def interruptOperation(id: String): proto.InterruptResponse = {
Expand All @@ -242,8 +258,8 @@ private[sql] class SparkConnectClient(
.setClientType(userAgent)
.setInterruptType(proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID)
.setOperationId(id)
.build()
bstub.interrupt(request)
serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
bstub.interrupt(request.build())
}

private[sql] def releaseSession(): proto.ReleaseSessionResponse = {
Expand All @@ -252,8 +268,7 @@ private[sql] class SparkConnectClient(
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
bstub.releaseSession(request)
bstub.releaseSession(request.build())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two lines are equivalent right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is just to make it uniform to everything else.

}

private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,14 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr

override def onNext(req: AddArtifactsRequest): Unit = try {
if (this.holder == null) {
// Server-side session id the client reported from earlier responses, if it sent one.
val previousSessionId =
  if (req.hasClientObservedServerSideSessionId) {
    Some(req.getClientObservedServerSideSessionId)
  } else {
    None
  }
this.holder = SparkConnectService.getOrCreateIsolatedSession(
req.getUserContext.getUserId,
req.getSessionId)
req.getSessionId,
previousSessionId)
}

if (req.hasBeginChunk) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ private[connect] class SparkConnectAnalyzeHandler(
extends Logging {

def handle(request: proto.AnalyzePlanRequest): Unit = {
// Server-side session id the client reported from earlier responses, if it sent one.
val previousSessionId =
  if (request.hasClientObservedServerSideSessionId) {
    Some(request.getClientObservedServerSideSessionId)
  } else {
    None
  }
val sessionHolder = SparkConnectService.getOrCreateIsolatedSession(
request.getUserContext.getUserId,
request.getSessionId)
request.getSessionId,
previousSessionId)
// `withSession` ensures that session-specific artifacts (such as JARs and class files) are
// available during processing (such as deserialization).
sessionHolder.withSession { _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,28 @@ class SparkConnectArtifactStatusesHandler(
val responseObserver: StreamObserver[proto.ArtifactStatusesResponse])
extends Logging {

protected def cacheExists(userId: String, sessionId: String, hash: String): Boolean = {
protected def cacheExists(
userId: String,
sessionId: String,
previouslySeenSessionId: Option[String],
hash: String): Boolean = {
val session = SparkConnectService
.getOrCreateIsolatedSession(userId, sessionId)
.getOrCreateIsolatedSession(userId, sessionId, previouslySeenSessionId)
.session
val blockManager = session.sparkContext.env.blockManager
blockManager.getStatus(CacheId(session.sessionUUID, hash)).isDefined
}

def handle(request: proto.ArtifactStatusesRequest): Unit = {
// Server-side session id the client reported from earlier responses, if it sent one.
val previousSessionId =
  if (request.hasClientObservedServerSideSessionId) {
    Some(request.getClientObservedServerSideSessionId)
  } else {
    None
  }
val holder = SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
.getOrCreateIsolatedSession(
request.getUserContext.getUserId,
request.getSessionId,
previousSessionId)

val builder = proto.ArtifactStatusesResponse.newBuilder()
builder.setSessionId(holder.sessionId)
Expand All @@ -49,6 +60,7 @@ class SparkConnectArtifactStatusesHandler(
cacheExists(
userId = request.getUserContext.getUserId,
sessionId = request.getSessionId,
previouslySeenSessionId = previousSessionId,
hash = name.stripPrefix("cache/"))
} else false
builder.putStatuses(name, status.setExists(exists).build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,16 @@ class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigRes
extends Logging {

def handle(request: proto.ConfigRequest): Unit = {
// Server-side session id the client reported from earlier responses, if it sent one.
val previousSessionId =
  if (request.hasClientObservedServerSideSessionId) {
    Some(request.getClientObservedServerSideSessionId)
  } else {
    None
  }
val holder =
SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
.getOrCreateIsolatedSession(
request.getUserContext.getUserId,
request.getSessionId,
previousSessionId)
val session = holder.session

val builder = request.getOperation.getOpTypeCase match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,15 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
* Create a new ExecuteHolder and register it with this global manager and with its session.
*/
private[connect] def createExecuteHolder(request: proto.ExecutePlanRequest): ExecuteHolder = {
// Server-side session id the client reported from earlier responses, if it sent one.
val previousSessionId =
  if (request.hasClientObservedServerSideSessionId) {
    Some(request.getClientObservedServerSideSessionId)
  } else {
    None
  }
val sessionHolder = SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getSessionId)
.getOrCreateIsolatedSession(
request.getUserContext.getUserId,
request.getSessionId,
previousSessionId)
val executeHolder = new ExecuteHolder(request, sessionHolder)
executionsLock.synchronized {
// Check if the operation already exists, both in active executions, and in the graveyard
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ class SparkConnectFetchErrorDetailsHandler(
responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]) {

def handle(v: proto.FetchErrorDetailsRequest): Unit = {
// Server-side session id the client reported from earlier responses, if it sent one.
val previousSessionId =
  if (v.hasClientObservedServerSideSessionId) {
    Some(v.getClientObservedServerSideSessionId)
  } else {
    None
  }
val sessionHolder =
SparkConnectService
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId, previousSessionId)

val response = Option(sessionHolder.errorIdToError.getIfPresent(v.getErrorId))
.map { error =>
Expand Down
Loading