Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 50 additions & 8 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,25 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
)


class ForbidRecursion:
    """
    Per-thread, non-reentrant guard usable as a context manager.

    Entering the guard marks the current thread as being "inside" the guarded
    section; trying to enter it again on the same thread before leaving raises
    ``RecursionError``. Callers that prefer to avoid the exception can check
    :attr:`can_enter` first.

    The state is kept in a ``threading.local``, so each thread tracks its own
    nesting independently of the others.
    """

    def __init__(self) -> None:
        # Attributes assigned on a ``threading.local`` are only visible to the
        # assigning thread, so the flag is deliberately NOT initialised here:
        # every read goes through ``_entered`` which supplies the default for
        # threads that have never touched the guard.
        self._local = threading.local()

    def _entered(self) -> bool:
        """Return True if the current thread is already inside the guard."""
        return getattr(self._local, "in_recursion", False)

    @property
    def can_enter(self) -> bool:
        """True when the current thread may enter without ``RecursionError``."""
        return not self._entered()

    def __enter__(self) -> None:
        """
        Mark the current thread as inside the guarded section.

        Raises
        ------
        RecursionError
            If the current thread is already inside the guarded section.
        """
        if self._entered():
            raise RecursionError

        self._local.in_recursion = True

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the guarded section; exceptions from the body are not suppressed."""
        self._local.in_recursion = False


class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
Expand Down Expand Up @@ -631,6 +650,8 @@ def __init__(
# be updated on the first response received.
self._server_session_id: Optional[str] = None

self._forbid_recursive_error_handling = ForbidRecursion()

def _retrying(self) -> "Retrying":
return Retrying(self._retry_policies)

Expand Down Expand Up @@ -1494,13 +1515,20 @@ def _handle_error(self, error: Exception) -> NoReturn:
-------
Throws the appropriate internal Python exception.
"""
if isinstance(error, grpc.RpcError):
self._handle_rpc_error(error)
elif isinstance(error, ValueError):
if "Cannot invoke RPC" in str(error) and "closed" in str(error):
raise SparkConnectException(
error_class="NO_ACTIVE_SESSION", message_parameters=dict()
) from None

if not self._forbid_recursive_error_handling.can_enter:
# We are already inside error handling routine,
# avoid recursive error processing (with potentially infinite recursion)
raise error

with self._forbid_recursive_error_handling:
if isinstance(error, grpc.RpcError):
self._handle_rpc_error(error)
elif isinstance(error, ValueError):
if "Cannot invoke RPC" in str(error) and "closed" in str(error):
raise SparkConnectException(
error_class="NO_ACTIVE_SESSION", message_parameters=dict()
) from None
raise error

def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
Expand All @@ -1520,6 +1548,20 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
except grpc.RpcError:
return None

def _display_server_stack_trace(self) -> bool:
    """
    Decide whether the server-side stack trace should be shown with an error.

    Returns
    -------
    bool
        True if either ``spark.sql.connect.serverStacktrace.enabled`` or
        ``spark.sql.pyspark.jvmStacktrace.enabled`` reads back as ``"true"``,
        or if reading the configuration raises; False otherwise.
    """
    from pyspark.sql.connect.conf import RuntimeConf

    runtime_conf = RuntimeConf(self)
    stacktrace_keys = (
        "spark.sql.connect.serverStacktrace.enabled",
        "spark.sql.pyspark.jvmStacktrace.enabled",
    )
    try:
        # Keys are consulted in order; the second one is only read when the
        # first is not enabled.
        for key in stacktrace_keys:
            if runtime_conf.get(key) == "true":
                return True
        return False
    except Exception:
        # Default to showing the stack trace when the config cannot be read.
        # A conf lookup that keeps failing would otherwise be retried
        # recursively and end in `RecursionError`.
        return True

def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
Error handling helper for dealing with GRPC Errors. On the server side, certain
Expand Down Expand Up @@ -1553,7 +1595,7 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
info,
status.message,
self._fetch_enriched_error(info),
True,
self._display_server_stack_trace(),
) from None

raise SparkConnectGrpcException(status.message) from None
Expand Down
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from collections.abc import Generator
from typing import Optional, Any

from pyspark.sql.connect.client.core import ForbidRecursion
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import eventually

Expand Down Expand Up @@ -133,6 +134,26 @@ def test_channel_builder_with_session(self):
client = SparkConnectClient(chan)
self.assertEqual(client._session_id, chan.session_id)

def test_forbid_recursion(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test does not directly test the scenario we're talking about. Ideally, you can just use the existing mock tests to fail any query and check that the recursion guard works.

Copy link
Contributor Author

@cdkrot cdkrot Dec 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried, but it seems hard to write a mock test for this because it needs to pass through this piece of code:

status = rpc_status.from_call(cast(grpc.Call, rpc_error))

It seems hard to create a mock exception that would pass through this without poking at grpc's internals significantly. Alternatively, we could introduce some testing crutches here, e.g. check whether the exception comes from testing code, but that's not great either.

@grundprinzip

guard = ForbidRecursion()
max_depth = 0

def g(n):
nonlocal max_depth
with guard:
max_depth = n
g(n + 1)

with self.assertRaises(RecursionError):
g(1)
self.assertEqual(max_depth, 1)

# Do the same test again to check that guard resets.
max_depth = 0
with self.assertRaises(RecursionError):
g(1)
self.assertEqual(max_depth, 1)


class TestPolicy(DefaultPolicy):
def __init__(self):
Expand Down