diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index badd9a33397e..5e3462c2d0c1 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1829,6 +1829,7 @@ def _verify_response_integrity(
             response.server_side_session_id
             and response.server_side_session_id != self._server_session_id
         ):
+            self._closed = True
             raise PySparkAssertionError(
                 "Received incorrect server side session identifier for request. "
                 "Please create a new Spark Session to reconnect. ("
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index b688ca022c8c..bec3c5b579a0 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -237,9 +237,9 @@ def create(self) -> "SparkSession":
     def getOrCreate(self) -> "SparkSession":
         with SparkSession._lock:
             session = SparkSession.getActiveSession()
-            if session is None or session.is_stopped:
-                session = SparkSession._default_session
-                if session is None or session.is_stopped:
+            if session is None:
+                session = SparkSession._get_default_session()
+                if session is None:
                     session = self.create()
             self._apply_options(session)
             return session
@@ -285,9 +285,19 @@ def _set_default_and_active_session(cls, session: "SparkSession") -> None:
             if getattr(cls._active_session, "session", None) is None:
                 cls._active_session.session = session
 
+    @classmethod
+    def _get_default_session(cls) -> Optional["SparkSession"]:
+        s = cls._default_session
+        if s is not None and not s.is_stopped:
+            return s
+        return None
+
     @classmethod
     def getActiveSession(cls) -> Optional["SparkSession"]:
-        return getattr(cls._active_session, "session", None)
+        s = getattr(cls._active_session, "session", None)
+        if s is not None and not s.is_stopped:
+            return s
+        return None
 
     @classmethod
     def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession":
@@ -315,7 +325,7 @@ def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession":
     def active(cls) -> "SparkSession":
         session = cls.getActiveSession()
         if session is None:
-            session = cls._default_session
+            session = cls._get_default_session()
         if session is None:
             raise PySparkRuntimeError(
                 error_class="NO_ACTIVE_OR_DEFAULT_SESSION",
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py
index c5ce697a9561..1dd5cde0dff5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -242,7 +242,7 @@ def toChannel(self):
         session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
         session.sql("select 1 + 1")
 
-    def test_reset_when_server_session_changes(self):
+    def test_reset_when_server_and_client_sessionids_mismatch(self):
         session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
         # run a simple query so the session id is synchronized.
         session.range(3).collect()
@@ -256,6 +256,20 @@ def test_reset_when_server_session_changes(self):
         session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
         session.range(3).collect()
 
+    def test_reset_when_server_session_id_mismatch(self):
+        session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
+        # run a simple query so the session id is synchronized.
+        session.range(3).collect()
+
+        # trigger a mismatch
+        session._client._server_session_id = str(uuid.uuid4())
+        with self.assertRaises(SparkConnectException):
+            session.range(3).collect()
+
+        # assert that getOrCreate() generates a new session
+        session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
+        session.range(3).collect()
+
 
 class SparkConnectSessionWithOptionsTest(unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py
index 5184b9f06171..820f54b83327 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -77,6 +77,34 @@ def test_session_create_sets_active_session(self):
         self.assertIs(session, session2)
         session.stop()
 
+    def test_active_session_expires_when_client_closes(self):
+        s1 = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
+        s2 = RemoteSparkSession.getActiveSession()
+
+        self.assertIs(s1, s2)
+
+        # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator
+        s1._client._closed = True
+
+        self.assertIsNone(RemoteSparkSession.getActiveSession())
+        s3 = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
+
+        self.assertIsNot(s1, s3)
+
+    def test_default_session_expires_when_client_closes(self):
+        s1 = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
+        s2 = RemoteSparkSession.getDefaultSession()
+
+        self.assertIs(s1, s2)
+
+        # We don't call close() to avoid executing ExecutePlanResponseReattachableIterator
+        s1._client._closed = True
+
+        self.assertIsNone(RemoteSparkSession.getDefaultSession())
+        s3 = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
+
+        self.assertIsNot(s1, s3)
+
 
 class JobCancellationTests(ReusedConnectTestCase):
     def test_tags(self):