Skip to content
Closed
Next Next commit
Test added / not working correctly
  • Loading branch information
Changgyoo Park committed Jun 17, 2024
commit 3f095101168fd5d4840e73a7e45f8ee74624c8a9
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ object SparkSession extends Logging {
def appName(name: String): Builder = this

private def tryCreateSessionFromClient(): Option[SparkSession] = {
if (client != null) {
if (client != null && !client.hasSessionChanged) {
Option(new SparkSession(client, planIdGenerator))
} else {
None
Expand Down Expand Up @@ -1024,7 +1024,14 @@ object SparkSession extends Logging {
*/
def getOrCreate(): SparkSession = {
val session = tryCreateSessionFromClient()
.getOrElse(sessions.get(builder.configuration))
.getOrElse({
var existingSession = sessions.get(builder.configuration)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a lock here for sessions?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so as Cache is backed by ConcurrentMap which allows concurrent access to the data. Source: https://guava.dev/releases/17.0/api/docs/com/google/common/cache/CacheBuilder.html.

while (existingSession.client != null && existingSession.client.hasSessionChanged) {
sessions.refresh(builder.configuration)
existingSession = sessions.get(builder.configuration)
}
existingSession
})
setDefaultAndActiveSession(session)
applyOptions(session)
session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,29 @@ class SparkSessionE2ESuite extends ConnectFunSuite with RemoteSparkSession {
.create()
}
}

test("get or create") {
val remote = s"sc://localhost:$serverPort"
val session1 = SparkSession
.builder()
.remote(remote)
.getOrCreate()
assert(session1.range(3).collect().length == 3)

session1.client.hijackServerSideSessionIdForTesting("-testing")

try {
session1.range(3).collect()
fail("unreachable")
} catch {
case t: Throwable => assert(t.getMessage.contains("INVALID_HANDLE.SESSION_CHANGED"))
}

val session2 = SparkSession
.builder()
.remote(remote)
.getOrCreate()
assert(session1 ne session2)
assert(session2.range(3).collect().length == 3)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class ResponseValidator extends Logging {
// do not use server-side streaming.
private var serverSideSessionId: Option[String] = None

// Indicates whether the server side session ID has changed. This flag being true usually means
// that the session is unusable and the user should establish a new connection to the server.
private var hasServerSideSessionIDChanged: Boolean = false

// Returns the server side session ID, used to send it back to the server in the follow-up
// requests so the server can validate it session id against the previous requests.
def getServerSideSessionId: Option[String] = serverSideSessionId
Expand All @@ -42,6 +46,13 @@ class ResponseValidator extends Logging {
serverSideSessionId = Some(serverSideSessionId.getOrElse("") + suffix)
}

/**
* Returns true if the server side session ID has changed.
*/
private[sql] def hasSessionChanged: Boolean = {
hasServerSideSessionIDChanged
}

def verifyResponse[RespT <: GeneratedMessageV3](fn: => RespT): RespT = {
val response = fn
val field = response.getDescriptorForType.findFieldByName("server_side_session_id")
Expand All @@ -54,6 +65,7 @@ class ResponseValidator extends Logging {
serverSideSessionId match {
case Some(id) =>
if (value != id) {
hasServerSideSessionIDChanged = true
throw new IllegalStateException(
s"Server side session ID changed from $id to $value")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ private[sql] class SparkConnectClient(
stubState.responseValidator.hijackServerSideSessionIdForTesting(suffix)
}

/**
* Checks if the session has received an `INVALID_HANDLE.SESSION_CHANGED` error.
*/
private[sql] def hasSessionChanged: Boolean = {
stubState.responseValidator.hasSessionChanged
}

private[sql] val artifactManager: ArtifactManager = {
new ArtifactManager(configuration, sessionId, bstub, stub)
}
Expand Down