python/pyspark/sql/connect/client/core.py (49 changes: 39 additions & 10 deletions)
@@ -587,7 +587,12 @@ def __init__(
         use_reattachable_execute: bool
             Enable reattachable execution.
         """
-        self.thread_local = threading.local()
+
+        class ClientThreadLocals(threading.local):
+            tags: set = set()
+            inside_error_handling: bool = False
+
+        self.thread_local = ClientThreadLocals()
 
         # Parse the connection string.
         self._builder = (
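Side note on the `ClientThreadLocals` pattern: subclassing `threading.local` gives every thread its own view of these attributes, with the class-level values acting as per-thread defaults. A minimal standalone sketch (the variable names below are illustrative, not from the patch):

```python
import threading

class ClientThreadLocals(threading.local):
    # Class attributes serve as the default until a thread assigns its own.
    inside_error_handling: bool = False

tl = ClientThreadLocals()
seen = {}

def worker():
    seen["before"] = tl.inside_error_handling  # default: False
    tl.inside_error_handling = True            # rebinding is thread-private

t = threading.Thread(target=worker)
t.start()
t.join()

print(seen["before"])            # False: the worker saw the class default
print(tl.inside_error_handling)  # False: the worker's write did not leak here
```

One caveat worth flagging in review: `tags: set = set()` is a single mutable object at class level, so a thread only gets private state once it rebinds the attribute; mutating the default set in place would be visible across threads.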
@@ -1494,14 +1499,24 @@ def _handle_error(self, error: Exception) -> NoReturn:
         -------
         Throws the appropriate internal Python exception.
         """
-        if isinstance(error, grpc.RpcError):
-            self._handle_rpc_error(error)
-        elif isinstance(error, ValueError):
-            if "Cannot invoke RPC" in str(error) and "closed" in str(error):
-                raise SparkConnectException(
-                    error_class="NO_ACTIVE_SESSION", message_parameters=dict()
-                ) from None
-        raise error
+
+        if self.thread_local.inside_error_handling:
+            # We are already inside the error handling routine;
+            # avoid recursive error processing (and potentially infinite recursion).
+            raise error
+
+        try:
+            self.thread_local.inside_error_handling = True
+            if isinstance(error, grpc.RpcError):
+                self._handle_rpc_error(error)
+            elif isinstance(error, ValueError):
+                if "Cannot invoke RPC" in str(error) and "closed" in str(error):
+                    raise SparkConnectException(
+                        error_class="NO_ACTIVE_SESSION", message_parameters=dict()
+                    ) from None
+            raise error
+        finally:
+            self.thread_local.inside_error_handling = False
 
     def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
         if "errorId" not in info.metadata:
@@ -1520,6 +1535,20 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
         except grpc.RpcError:
             return None
 
+    def _display_server_stack_trace(self) -> bool:
+        from pyspark.sql.connect.conf import RuntimeConf
+
+        conf = RuntimeConf(self)
+        try:
+            if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
+                return True
+            return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
+        except Exception as e:  # noqa: F841
+            # Fall back to True if reading the config fails; otherwise a
+            # consistently failing config read would re-enter error handling
+            # recursively and end in a `RecursionError`.
+            return True
+
     def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
         Error handling helper for dealing with GRPC Errors. On the server side, certain
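For reference, the two flags `_display_server_stack_trace` consults are ordinary runtime confs, so users can opt in from a session (this snippet assumes a live Spark Connect session named `spark`; it is illustrative, not part of the patch):

```python
# Show full server-side stack traces on client errors:
spark.conf.set("spark.sql.connect.serverStacktrace.enabled", "true")
# Or rely on the older JVM stack trace flag, checked as a fallback:
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")
```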
@@ -1553,7 +1582,7 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
                         info,
                         status.message,
                         self._fetch_enriched_error(info),
-                        True,
+                        self._display_server_stack_trace(),
                     ) from None
 
             raise SparkConnectGrpcException(status.message) from None
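One more detail: both raises use `from None`, which sets `__suppress_context__` so the user-facing traceback is not chained to the underlying `grpc.RpcError`. A quick illustration of the mechanism:

```python
try:
    try:
        raise KeyError("low-level failure")  # stand-in for grpc.RpcError
    except KeyError:
        raise RuntimeError("user-facing error") from None
except RuntimeError as e:
    assert e.__suppress_context__  # "During handling of ..." chain suppressed
    print("traceback shows only the user-facing error")
```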