Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
test
  • Loading branch information
cdkrot committed Dec 4, 2023
commit 285b85ccae02517b520c221e02068aa88ef5687e
37 changes: 30 additions & 7 deletions python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,25 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
)


class ForbidRecursion:
def __init__(self):
self._local = threading.local()
self._local.in_recursion = False

@property
def can_enter(self):
return self._local.in_recursion

def __enter__(self):
if self._local.in_recursion:
raise RecursionError

self._local.in_recursion = True

def __exit__(self, exc_type, exc_val, exc_tb):
self._local.in_recursion = False


class SparkConnectClient(object):
"""
Conceptually the remote spark session that communicates with the server
Expand Down Expand Up @@ -631,6 +650,8 @@ def __init__(
# be updated on the first response received.
self._server_session_id: Optional[str] = None

self._forbid_recursive_error_handling = ForbidRecursion()

def _retrying(self) -> "Retrying":
return Retrying(self._retry_policies)

Expand Down Expand Up @@ -1494,13 +1515,15 @@ def _handle_error(self, error: Exception) -> NoReturn:
-------
Throws the appropriate internal Python exception.
"""
if isinstance(error, grpc.RpcError):
self._handle_rpc_error(error)
elif isinstance(error, ValueError):
if "Cannot invoke RPC" in str(error) and "closed" in str(error):
raise SparkConnectException(
error_class="NO_ACTIVE_SESSION", message_parameters=dict()
) from None
if True or self._forbid_recursive_error_handling.can_enter:
with self._forbid_recursive_error_handling:
if isinstance(error, grpc.RpcError):
self._handle_rpc_error(error)
elif isinstance(error, ValueError):
if "Cannot invoke RPC" in str(error) and "closed" in str(error):
raise SparkConnectException(
error_class="NO_ACTIVE_SESSION", message_parameters=dict()
) from None
raise error

def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
Expand Down