Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,6 +1763,9 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
info = error_details_pb2.ErrorInfo()
d.Unpack(info)

if info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED":
self._closed = True

raise convert_exception(
info,
status.message,
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ def create(self) -> "SparkSession":
def getOrCreate(self) -> "SparkSession":
with SparkSession._lock:
session = SparkSession.getActiveSession()
if session is None:
if session is None or session.is_stopped:
session = SparkSession._default_session
if session is None:
if session is None or session.is_stopped:
session = self.create()
self._apply_options(session)
return session
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,20 @@ def toChannel(self):
session = RemoteSparkSession.builder.channelBuilder(CustomChannelBuilder()).create()
session.sql("select 1 + 1")

def test_reset_when_server_session_changes(self):
session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
# run a simple query so the session id is synchronized.
session.range(3).collect()

# trigger a mismatch between client session id and server session id.
session._client._session_id = str(uuid.uuid4())
with self.assertRaises(SparkConnectException):
session.range(3).collect()

# assert that getOrCreate() generates a new session
session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
session.range(3).collect()


class SparkConnectSessionWithOptionsTest(unittest.TestCase):
def setUp(self) -> None:
Expand Down