64 changes: 55 additions & 9 deletions python/pyspark/sql/connect/client/core.py
@@ -544,6 +544,25 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
)


class ForbidRecursion:
    """Guard that raises RecursionError on re-entrant use within a single thread."""

    def __init__(self):
        self._local = threading.local()
        self._local.in_recursion = False

    @property
    def can_enter(self):
        # Attributes on threading.local() exist only for threads that set them,
        # so treat a missing attribute as "not currently in recursion".
        return not getattr(self._local, "in_recursion", False)

    def __enter__(self):
        if getattr(self._local, "in_recursion", False):
            raise RecursionError

        self._local.in_recursion = True

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._local.in_recursion = False

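For illustration, a minimal usage sketch of the guard above (hypothetical call site, not part of this diff; assumes the `ForbidRecursion` class and `import threading` are in scope):

guard = ForbidRecursion()

def handle(error):
    with guard:
        # Simulate a handler whose own processing fails and re-enters
        # error handling on the same thread: the nested `with guard`
        # raises RecursionError instead of recursing indefinitely.
        handle(error)

try:
    handle(ValueError("boom"))
except RecursionError:
    print("re-entrant error handling blocked")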

class SparkConnectClient(object):
    """
    Conceptually the remote spark session that communicates with the server
@@ -631,6 +650,8 @@ def __init__(
        # be updated on the first response received.
        self._server_session_id: Optional[str] = None

        self._inside_error_handling: bool = False

    def _retrying(self) -> "Retrying":
        return Retrying(self._retry_policies)

@@ -1494,14 +1515,25 @@ def _handle_error(self, error: Exception) -> NoReturn:
        -------
        Throws the appropriate internal Python exception.
        """
        if isinstance(error, grpc.RpcError):
            self._handle_rpc_error(error)
        elif isinstance(error, ValueError):
            if "Cannot invoke RPC" in str(error) and "closed" in str(error):
                raise SparkConnectException(
                    error_class="NO_ACTIVE_SESSION", message_parameters=dict()
                ) from None
        raise error

        if self._inside_error_handling:
            # We are already inside the error handling routine;
            # avoid recursive error processing (and potentially infinite recursion).
            raise error

        try:
            self._inside_error_handling = True
@heyihong (Contributor) commented on Dec 7, 2023:
I know Python has Global Interpreter Lock but is this thread-safe?


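To make the review question concrete, here is a standalone sketch (a hypothetical `Client` stand-in, not the real SparkConnectClient) showing how a plain boolean attribute is shared across threads: once one thread sets the flag, a second thread wrongly skips its own error handling even though it never entered:

import threading

class Client:
    def __init__(self):
        self._inside_error_handling = False  # one flag shared by all threads

    def handle(self, name, started, release):
        if self._inside_error_handling:
            print(f"{name}: flag already set by another thread, handling skipped")
            return
        self._inside_error_handling = True
        started.set()   # signal that this thread is now "inside" error handling
        release.wait()  # stay inside until the other thread has run
        self._inside_error_handling = False

started, release = threading.Event(), threading.Event()
client = Client()
t = threading.Thread(target=client.handle, args=("t1", started, release))
t.start()
started.wait()
client.handle("t2", threading.Event(), threading.Event())  # wrongly skipped
release.set()
t.join()

A threading.local-based guard such as ForbidRecursion above keeps one flag per thread and avoids this interaction.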
            if isinstance(error, grpc.RpcError):
                self._handle_rpc_error(error)
            elif isinstance(error, ValueError):
                if "Cannot invoke RPC" in str(error) and "closed" in str(error):
                    raise SparkConnectException(
                        error_class="NO_ACTIVE_SESSION", message_parameters=dict()
                    ) from None
            raise error
        finally:
            self._inside_error_handling = False

    def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
        if "errorId" not in info.metadata:
@@ -1520,6 +1552,20 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
        except grpc.RpcError:
            return None

    def _display_server_stack_trace(self) -> bool:
        from pyspark.sql.connect.conf import RuntimeConf

        conf = RuntimeConf(self)
        try:
            if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
                return True
            return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
        except Exception as e:  # noqa: F841
            # Fall back to True if reading the config raises. Without this,
            # a consistently failing conf fetch would re-enter error handling
            # on every attempt and end in a `RecursionError`.
            return True

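For context, the two configs consulted above can be toggled per session; a hedged sketch (the connection URL is a placeholder):

from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
# Prefer the Spark Connect-specific server stack trace flag...
spark.conf.set("spark.sql.connect.serverStacktrace.enabled", "true")
# ...or the JVM stack trace flag shared with classic PySpark.
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")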
    def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
        """
        Error handling helper for dealing with GRPC Errors. On the server side, certain
@@ -1553,7 +1599,7 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
                        info,
                        status.message,
                        self._fetch_enriched_error(info),
                        True,
                        self._display_server_stack_trace(),
                    ) from None

        raise SparkConnectGrpcException(status.message) from None
1 change: 1 addition & 0 deletions python/pyspark/sql/tests/connect/client/test_client.py
@@ -20,6 +20,7 @@
from collections.abc import Generator
from typing import Optional, Any

from pyspark.sql.connect.client.core import ForbidRecursion
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import eventually
