diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 19c5a3f14c64..80336fb1eaea 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -829,10 +829,16 @@ object SparkSession extends Logging { /** * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when - * they are not set yet. + * they are not set yet or the associated [[SparkConnectClient]] is unusable. */ private def setDefaultAndActiveSession(session: SparkSession): Unit = { - defaultSession.compareAndSet(null, session) + val currentDefault = defaultSession.getAcquire + if (currentDefault == null || !currentDefault.client.isSessionValid) { + // Update `defaultSession` if it is null or the contained session is not valid. There is a + // chance that the following `compareAndSet` fails if a new default session has just been set, + // but that does not matter since that event has happened after this method was invoked. + defaultSession.compareAndSet(currentDefault, session) + } if (getActiveSession.isEmpty) { setActiveSession(session) } @@ -972,7 +978,7 @@ object SparkSession extends Logging { def appName(name: String): Builder = this private def tryCreateSessionFromClient(): Option[SparkSession] = { - if (client != null) { + if (client != null && client.isSessionValid) { Option(new SparkSession(client, planIdGenerator)) } else { None @@ -1024,7 +1030,16 @@ object SparkSession extends Logging { */ def getOrCreate(): SparkSession = { val session = tryCreateSessionFromClient() - .getOrElse(sessions.get(builder.configuration)) + .getOrElse({ + var existingSession = sessions.get(builder.configuration) + if (!existingSession.client.isSessionValid) { + // If the cached session has become invalid, e.g., due to a server restart, the cache + // entry is invalidated. + sessions.invalidate(builder.configuration) + existingSession = sessions.get(builder.configuration) + } + existingSession + }) setDefaultAndActiveSession(session) applyOptions(session) session @@ -1032,11 +1047,13 @@ object SparkSession extends Logging { } /** - * Returns the default SparkSession. + * Returns the default SparkSession. If the previously set default SparkSession becomes + * unusable, returns None. * * @since 3.5.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get()) + def getDefaultSession: Option[SparkSession] = + Option(defaultSession.get()).filter(_.client.isSessionValid) /** * Sets the default SparkSession. @@ -1057,11 +1074,13 @@ object SparkSession extends Logging { } /** - * Returns the active SparkSession for the current thread. + * Returns the active SparkSession for the current thread. If the previously set active + * SparkSession becomes unusable, returns None. * * @since 3.5.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get()) + def getActiveSession: Option[SparkSession] = + Option(activeThreadSession.get()).filter(_.client.isSessionValid) /** * Changes the SparkSession that will be returned in this thread and its children when diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index 203b1295005a..b28aa905c7a2 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -382,4 +382,43 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession { .create() } } + + test("SPARK-47986: get or create after session changed") { + val remote = s"sc://localhost:$serverPort" + + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + + val session1 = SparkSession + .builder() + .remote(remote) + .getOrCreate() + + assert(session1 eq SparkSession.getActiveSession.get) + assert(session1 eq SparkSession.getDefaultSession.get) + assert(session1.range(3).collect().length == 3) + + session1.client.hijackServerSideSessionIdForTesting("-testing") + + val e = intercept[SparkException] { + session1.range(3).analyze + } + + assert(e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]")) + assert(!session1.client.isSessionValid) + assert(SparkSession.getActiveSession.isEmpty) + assert(SparkSession.getDefaultSession.isEmpty) + + val session2 = SparkSession + .builder() + .remote(remote) + .getOrCreate() + + assert(session1 ne session2) + assert(session2.client.isSessionValid) + assert(session2 eq SparkSession.getActiveSession.get) + assert(session2 eq SparkSession.getDefaultSession.get) + assert(session2.range(3).collect().length == 3) + } + } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala index 29272c96132b..42c3387335be 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala @@ -16,7 +16,10 @@ */ package org.apache.spark.sql.connect.client +import java.util.concurrent.atomic.AtomicBoolean + import com.google.protobuf.GeneratedMessageV3 +import io.grpc.{Status, StatusRuntimeException} import io.grpc.stub.StreamObserver import org.apache.spark.internal.Logging @@ -30,6 +33,12 @@ class ResponseValidator extends Logging { // do not use server-side streaming. private var serverSideSessionId: Option[String] = None + // Indicates whether the client and the client information on the server correspond to each other + // This flag being false means that the server has restarted and lost the client information, or + // there is a logic error in the code; both cases, the user should establish a new connection to + // the server. Access to the value has to be synchronized since it can be shared. + private val isSessionActive: AtomicBoolean = new AtomicBoolean(true) + // Returns the server side session ID, used to send it back to the server in the follow-up // requests so the server can validate it session id against the previous requests. def getServerSideSessionId: Option[String] = serverSideSessionId @@ -42,8 +51,25 @@ class ResponseValidator extends Logging { serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix) } + /** + * Returns true if the session is valid on both the client and the server. + */ + private[sql] def isSessionValid: Boolean = { + // An active session is considered valid. + isSessionActive.getAcquire + } + def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = { - val response = fn + val response = + try { + fn + } catch { + case e: StatusRuntimeException + if e.getStatus.getCode == Status.Code.INTERNAL && + e.getMessage.contains("[INVALID_HANDLE.SESSION_CHANGED]") => + isSessionActive.setRelease(false) + throw e + } val field = response.getDescriptorForType.findFieldByName("server_side_session_id") // If the field does not exist, we ignore it. New / Old message might not contain it and this // behavior allows us to be compatible. @@ -54,6 +80,7 @@ class ResponseValidator extends Logging { serverSideSessionId match { case Some(id) => if (value != id) { + isSessionActive.setRelease(false) throw new IllegalStateException( s"Server side session ID changed from $id to $value") } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index b5eda024bfb3..7c3108fdb1b0 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -71,6 +71,17 @@ private[sql] class SparkConnectClient( stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix) } + /** + * Returns true if the session is valid on both the client and the server. A session becomes + * invalid if the server side information about the client, e.g., session ID, does not + * correspond to the actual client state. + */ + private[sql] def isSessionValid: Boolean = { + // The last known state of the session is store in `responseValidator`, because it is where the + // client gets responses from the server. + stubState.responseValidator.isSessionValid + } + private[sql] val artifactManager: ArtifactManager = { new ArtifactManager(configuration, sessionId, bstub, stub) }