diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 0b502494f781..58ec00641c2c 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -587,7 +587,12 @@ def __init__(
         use_reattachable_execute: bool
             Enable reattachable execution.
         """
-        self.thread_local = threading.local()
+
+        class ClientThreadLocals(threading.local):
+            tags: set = set()
+            inside_error_handling: bool = False
+
+        self.thread_local = ClientThreadLocals()
 
         # Parse the connection string.
         self._builder = (
@@ -1494,14 +1499,24 @@ def _handle_error(self, error: Exception) -> NoReturn:
         -------
         Throws the appropriate internal Python exception.
         """
-        if isinstance(error, grpc.RpcError):
-            self._handle_rpc_error(error)
-        elif isinstance(error, ValueError):
-            if "Cannot invoke RPC" in str(error) and "closed" in str(error):
-                raise SparkConnectException(
-                    error_class="NO_ACTIVE_SESSION", message_parameters=dict()
-                ) from None
-        raise error
+
+        if self.thread_local.inside_error_handling:
+            # We are already inside the error handling routine; re-raise as-is to
+            # avoid recursive error processing (and potentially infinite recursion).
+            raise error
+
+        try:
+            self.thread_local.inside_error_handling = True
+            if isinstance(error, grpc.RpcError):
+                self._handle_rpc_error(error)
+            elif isinstance(error, ValueError):
+                if "Cannot invoke RPC" in str(error) and "closed" in str(error):
+                    raise SparkConnectException(
+                        error_class="NO_ACTIVE_SESSION", message_parameters=dict()
+                    ) from None
+            raise error
+        finally:
+            self.thread_local.inside_error_handling = False
 
     def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
         if "errorId" not in info.metadata:
@@ -1520,6 +1535,20 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
         except grpc.RpcError:
             return None
 
+    def _display_server_stack_trace(self) -> bool:
+        from pyspark.sql.connect.conf import RuntimeConf
+
+        conf = RuntimeConf(self)
+        try:
+            if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
+                return True
+            return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
+        except Exception as e:  # noqa: F841
+            # Fall back to true if an exception occurs while reading the config.
+            # Otherwise, a consistently failing config read would recursively try
+            # to fetch the conf and end up with a `RecursionError`.
+            return True
+
     def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
         Error handling helper for dealing with GRPC Errors. On the server side, certain
@@ -1553,7 +1582,7 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
                     info,
                     status.message,
                     self._fetch_enriched_error(info),
-                    True,
+                    self._display_server_stack_trace(),
                 ) from None
 
         raise SparkConnectGrpcException(status.message) from None
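
For context, here is a minimal standalone sketch (not part of the patch) of the re-entrancy guard the diff introduces: a `threading.local` subclass whose class-level attribute supplies a per-thread default, plus a flag that stops `_handle_error` from recursing when error handling itself fails (e.g. the new `_display_server_stack_trace` config read erroring out). The module-level `_local` and the simplified free function are illustrative stand-ins for the client's `self.thread_local` and its `_handle_error` method.

```python
import threading


class ClientThreadLocals(threading.local):
    # An attribute defined on the subclass acts as the default value each
    # thread observes until that thread assigns its own per-thread value.
    inside_error_handling: bool = False


_local = ClientThreadLocals()


def _handle_error(error: Exception) -> None:
    if _local.inside_error_handling:
        # Already handling an error on this thread: re-raise immediately
        # instead of running the handler again.
        raise error
    try:
        _local.inside_error_handling = True
        # ... error translation that may itself raise or re-enter this
        # function (e.g. while reading a config over the same channel) ...
        raise error
    finally:
        # Reset the flag even if translation raised, so the next error on
        # this thread is handled normally.
        _local.inside_error_handling = False
```

Because the flag lives in thread-local storage, one thread hitting a nested failure falls back to a plain re-raise while other threads continue to get full error translation.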