Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -829,10 +829,16 @@ object SparkSession extends Logging {

/**
* Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
* they are not set yet.
* they are not set yet or the associated [[SparkConnectClient]] is unusable.
*/
private def setDefaultAndActiveSession(session: SparkSession): Unit = {
defaultSession.compareAndSet(null, session)
val currentDefault = defaultSession.getAcquire
if (currentDefault == null || !currentDefault.client.isSessionValid) {
// Update `defaultSession` if it is null or the contained session is not valid. There is a
// chance that the following `compareAndSet` fails if a new default session has just been set,
// but that does not matter since that event has happened after this method was invoked.
defaultSession.compareAndSet(currentDefault, session)
}
if (getActiveSession.isEmpty) {
setActiveSession(session)
}
Expand Down Expand Up @@ -972,7 +978,7 @@ object SparkSession extends Logging {
def appName(name: String): Builder = this

private def tryCreateSessionFromClient(): Option[SparkSession] = {
if (client != null) {
if (client != null && client.isSessionValid) {
Option(new SparkSession(client, planIdGenerator))
} else {
None
Expand Down Expand Up @@ -1024,19 +1030,30 @@ object SparkSession extends Logging {
*/
def getOrCreate(): SparkSession = {
val session = tryCreateSessionFromClient()
.getOrElse(sessions.get(builder.configuration))
.getOrElse({
var existingSession = sessions.get(builder.configuration)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a lock here for sessions?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, as Cache is backed by a ConcurrentMap, which allows concurrent access to the data. Source: https://guava.dev/releases/17.0/api/docs/com/google/common/cache/CacheBuilder.html.

if (!existingSession.client.isSessionValid) {
// If the cached session has become invalid, e.g., due to a server restart, the cache
// entry is invalidated.
sessions.invalidate(builder.configuration)
existingSession = sessions.get(builder.configuration)
}
existingSession
})
setDefaultAndActiveSession(session)
applyOptions(session)
session
}
}

/**
* Returns the default SparkSession.
* Returns the default SparkSession. If the previously set default SparkSession becomes
* unusable, returns None.
*
* @since 3.5.0
*/
def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())
def getDefaultSession: Option[SparkSession] =
Option(defaultSession.get()).filter(_.client.isSessionValid)

/**
* Sets the default SparkSession.
Expand All @@ -1057,11 +1074,13 @@ object SparkSession extends Logging {
}

/**
* Returns the active SparkSession for the current thread.
* Returns the active SparkSession for the current thread. If the previously set active
* SparkSession becomes unusable, returns None.
*
* @since 3.5.0
*/
def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get())
def getActiveSession: Option[SparkSession] =
Option(activeThreadSession.get()).filter(_.client.isSessionValid)

/**
* Changes the SparkSession that will be returned in this thread and its children when
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,43 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
.create()
}
}

// Checks that once a session has been marked invalid, the default/active session getters stop
// returning it and a subsequent `getOrCreate()` hands out a different, valid session.
test("SPARK-47986: get or create after session changed") {
val remote = s"sc://localhost:$serverPort"

// Start with neither a default nor an active session, so that the first `getOrCreate()` below
// installs `session1` as both.
SparkSession.clearDefaultSession()
SparkSession.clearActiveSession()

val session1 = SparkSession
.builder()
.remote(remote)
.getOrCreate()

assert(session1 eq SparkSession.getActiveSession.get)
assert(session1 eq SparkSession.getDefaultSession.get)
assert(session1.range(3).collect().length == 3)

// Append a suffix to the server-side session ID remembered by the client, so that the ID no
// longer matches and the next request is expected to fail.
session1.client.hijackServerSideSessionIdForTesting("-testing")

val e = intercept[SparkException] {
session1.range(3).analyze
}

// The failure must be reported as a session change and must flip the client's validity flag;
// both getters filter out sessions whose client is no longer valid.
assert(e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]"))
assert(!session1.client.isSessionValid)
assert(SparkSession.getActiveSession.isEmpty)
assert(SparkSession.getDefaultSession.isEmpty)

// Same builder settings as before: the invalid session must not be handed out again.
val session2 = SparkSession
.builder()
.remote(remote)
.getOrCreate()

assert(session1 ne session2)
assert(session2.client.isSessionValid)
assert(session2 eq SparkSession.getActiveSession.get)
assert(session2 eq SparkSession.getDefaultSession.get)
assert(session2.range(3).collect().length == 3)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
*/
package org.apache.spark.sql.connect.client

import java.util.concurrent.atomic.AtomicBoolean

import com.google.protobuf.GeneratedMessageV3
import io.grpc.{Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver

import org.apache.spark.internal.Logging
Expand All @@ -30,6 +33,12 @@ class ResponseValidator extends Logging {
// do not use server-side streaming.
private var serverSideSessionId: Option[String] = None

// Indicates whether the client and the client information on the server correspond to each
// other. This flag being false means that the server has restarted and lost the client
// information, or there is a logic error in the code; in both cases, the user should establish
// a new connection to the server. Access to the value has to be synchronized since it can be
// shared, hence the AtomicBoolean.
private val isSessionActive: AtomicBoolean = new AtomicBoolean(true)

// Returns the server-side session ID, which is sent back to the server in follow-up requests
// so that the server can validate its session ID against the previous requests.
def getServerSideSessionId: Option[String] = serverSideSessionId
Expand All @@ -42,8 +51,25 @@ class ResponseValidator extends Logging {
serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix)
}

/**
 * Returns true if the session is valid on both the client and the server, i.e. as long as
 * this validator has not flagged the session as inactive.
 */
private[sql] def isSessionValid: Boolean = isSessionActive.getAcquire

def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
val response = fn
val response =
try {
fn
} catch {
case e: StatusRuntimeException
if e.getStatus.getCode == Status.Code.INTERNAL &&
e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]") =>
isSessionActive.setRelease(false)
throw e
}
val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
// If the field does not exist, we ignore it. New / Old message might not contain it and this
// behavior allows us to be compatible.
Expand All @@ -54,6 +80,7 @@ class ResponseValidator extends Logging {
serverSideSessionId match {
case Some(id) =>
if (value != id) {
isSessionActive.setRelease(false)
throw new IllegalStateException(
s"Server side session ID changed from $id to $value")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ private[sql] class SparkConnectClient(
stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
}

/**
 * Returns true if the session is valid on both the client and the server. A session becomes
 * invalid if the server side information about the client, e.g., session ID, does not
 * correspond to the actual client state.
 *
 * The last known state of the session is stored in the `responseValidator` of `stubState`,
 * because that is where the client gets responses from the server; this method simply
 * forwards to it.
 */
private[sql] def isSessionValid: Boolean = stubState.responseValidator.isSessionValid

// Built once from this client's `configuration`, `sessionId`, `bstub` and `stub`.
private[sql] val artifactManager: ArtifactManager =
  new ArtifactManager(configuration, sessionId, bstub, stub)
Expand Down