From 4097b054d573c7a5be90cfca834cbb7543abed1b Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Mon, 6 Nov 2023 00:50:42 +0100
Subject: [PATCH 1/7] [SPARK-XXX][CONNECT][PYTHON] Better error handling

---
 ...SparkConnectFetchErrorDetailsHandler.scala |  6 +-
 .../spark/sql/connect/utils/ErrorUtils.scala  | 14 +++
 python/pyspark/errors/exceptions/connect.py   | 93 ++++++++++++++-----
 python/pyspark/sql/connect/client/core.py     | 13 ++-
 .../sql/tests/connect/test_connect_basic.py   | 25 ++---
 5 files changed, 112 insertions(+), 39 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
index 17a6e9e434f3..b5a3c986d169 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
@@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.FetchErrorDetailsResponse
-import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler
@@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler(
 
         ErrorUtils.throwableToFetchErrorDetailsResponse(
           st = error,
-          serverStackTraceEnabled = sessionHolder.session.conf.get(
-            Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get(
-            SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
+          serverStackTraceEnabled = true)
       }
       .getOrElse(FetchErrorDetailsResponse.newBuilder().build())
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 744fa3c8aa1a..7cb555ca47ec 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging {
       "classes",
       JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))
 
+    // Add the SQL State and Error Class to the response metadata of the ErrorInfoObject.
+    st match {
+      case e: SparkThrowable =>
+        val state = e.getSqlState
+        if (state != null && state.nonEmpty) {
+          errorInfo.putMetadata("sqlState", state)
+        }
+        val errorClass = e.getErrorClass
+        if (errorClass != null && errorClass.nonEmpty) {
+          errorInfo.putMetadata("errorClass", errorClass)
+        }
+      case _ =>
+    }
+
     if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) {
       // Generate a new unique key for this exception.
val errorId = UUID.randomUUID().toString diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index 423fb2c6f0ac..29426b190ac9 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -16,7 +16,7 @@ # import pyspark.sql.connect.proto as pb2 import json -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING, overload from pyspark.errors.exceptions.base import ( AnalysisException as BaseAnalysisException, @@ -46,55 +46,68 @@ class SparkConnectException(PySparkException): def convert_exception( - info: "ErrorInfo", truncated_message: str, resp: Optional[pb2.FetchErrorDetailsResponse] + info: "ErrorInfo", + truncated_message: str, + resp: Optional[pb2.FetchErrorDetailsResponse], + display_stacktrace: bool = False ) -> SparkConnectException: classes = [] + sql_state = None + error_class = None + if "classes" in info.metadata: classes = json.loads(info.metadata["classes"]) + if "sqlState" in info.metadata: + sql_state = info.metadata["sqlState"] + + if "errorClass" in info.metadata: + error_class = info.metadata["errorClass"] + if resp is not None and resp.HasField("root_error_idx"): message = resp.errors[resp.root_error_idx].message stacktrace = _extract_jvm_stacktrace(resp) else: message = truncated_message - stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else "" - - if len(stacktrace) > 0: - message += f"\n\nJVM stacktrace:\n{stacktrace}" + stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None + display_stacktrace = display_stacktrace if stacktrace is not None else False if "org.apache.spark.sql.catalyst.parser.ParseException" in classes: - return ParseException(message) + return ParseException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) # Order matters. ParseException inherits AnalysisException. elif "org.apache.spark.sql.AnalysisException" in classes: - return AnalysisException(message) + return AnalysisException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes: - return StreamingQueryException(message) + return StreamingQueryException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.sql.execution.QueryExecutionException" in classes: - return QueryExecutionException(message) + return QueryExecutionException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) # Order matters. NumberFormatException inherits IllegalArgumentException. 
elif "java.lang.NumberFormatException" in classes: - return NumberFormatException(message) + return NumberFormatException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.lang.IllegalArgumentException" in classes: - return IllegalArgumentException(message) + return IllegalArgumentException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.lang.ArithmeticException" in classes: - return ArithmeticException(message) + return ArithmeticException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.lang.UnsupportedOperationException" in classes: - return UnsupportedOperationException(message) + return UnsupportedOperationException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.lang.ArrayIndexOutOfBoundsException" in classes: - return ArrayIndexOutOfBoundsException(message) + return ArrayIndexOutOfBoundsException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "java.time.DateTimeException" in classes: - return DateTimeException(message) + return DateTimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.SparkRuntimeException" in classes: - return SparkRuntimeException(message) + return SparkRuntimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.SparkUpgradeException" in classes: - return SparkUpgradeException(message) + return SparkUpgradeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) elif "org.apache.spark.api.python.PythonException" in classes: return PythonException( "\n An exception was thrown from the Python worker. " "Please see the stack trace below.\n%s" % message ) + # Make sure that the generic SparkException is handled last. 
+ elif "org.apache.spark.SparkException" in classes: + return SparkException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) else: - return SparkConnectGrpcException(message, reason=info.reason) + return SparkConnectGrpcException(message, reason=info.reason, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str: @@ -106,7 +119,7 @@ def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str: def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None: message = f"{error.error_type_hierarchy[0]}: {error.message}" if len(lines) == 0: - lines.append(message) + lines.append(error.error_type_hierarchy[0]) else: lines.append(f"Caused by: {message}") for elem in error.stack_trace: @@ -135,16 +148,48 @@ def __init__( error_class: Optional[str] = None, message_parameters: Optional[Dict[str, str]] = None, reason: Optional[str] = None, + sql_state: Optional[str] = None, + stacktrace: Optional[str] = None, + display_stacktrace: bool = False ) -> None: self.message = message # type: ignore[assignment] if reason is not None: self.message = f"({reason}) {self.message}" + # PySparkException has the assumption that error_class and message_parameters are + # only occurring together. If only one is set, we assume the message to be fully + # parsed. + tmp_error_class = error_class + tmp_message_parameters = message_parameters + if error_class is not None and message_parameters is None: + tmp_error_class = None + elif error_class is None and message_parameters is not None: + tmp_message_parameters = None + super().__init__( message=self.message, - error_class=error_class, - message_parameters=message_parameters, + error_class=tmp_error_class, + message_parameters=tmp_message_parameters ) + self.error_class = error_class + self._sql_state: Optional[str] = sql_state + self._stacktrace: Optional[str] = stacktrace + self._display_stacktrace: bool = display_stacktrace + + def getSqlState(self) -> None: + if self._sql_state is not None: + return self._sql_state + else: + return super().getSqlState() + + def getStackTrace(self) -> Optional[str]: + return self._stacktrace + + def __str__(self): + desc = self.message + if self._display_stacktrace: + desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace + return desc class AnalysisException(SparkConnectGrpcException, BaseAnalysisException): @@ -223,3 +268,7 @@ class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException """ Exception thrown because of Spark upgrade from Spark Connect. 
""" + +class SparkException(SparkConnectGrpcException): + """ + """ \ No newline at end of file diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 11a1112ad1fe..69afef992c34 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1564,6 +1564,14 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet except grpc.RpcError: return None + def _display_stack_trace(self) -> bool: + from pyspark.sql.connect.conf import RuntimeConf + + conf = RuntimeConf(self) + if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true": + return True + return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true" + def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: """ Error handling helper for dealing with GRPC Errors. On the server side, certain @@ -1594,7 +1602,10 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: d.Unpack(info) raise convert_exception( - info, status.message, self._fetch_enriched_error(info) + info, + status.message, + self._fetch_enriched_error(info), + self._display_stack_trace(), ) from None raise SparkConnectGrpcException(status.message) from None diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f024a03c2686..daf6772e52bf 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3378,35 +3378,37 @@ def test_error_enrichment_jvm_stacktrace(self): """select from_json( '{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""" ).collect() - self.assertTrue("JVM stacktrace" in e.exception.message) - self.assertTrue("org.apache.spark.SparkUpgradeException:" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertTrue("org.apache.spark.SparkUpgradeException" in str(e.exception)) self.assertTrue( "at org.apache.spark.sql.errors.ExecutionErrors" - ".failToParseDateTimeInNewParserError" in e.exception.message + ".failToParseDateTimeInNewParserError" in str(e.exception) ) - self.assertTrue("Caused by: java.time.DateTimeException:" in e.exception.message) + self.assertTrue("Caused by: java.time.DateTimeException:" in str(e.exception)) def test_not_hitting_netty_header_limit(self): with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): with self.assertRaises(AnalysisException): - self.spark.sql("select " + "test" * 10000).collect() + self.spark.sql("select " + "test" * 1).collect() def test_error_stack_trace(self): with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}): with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}): with self.assertRaises(AnalysisException) as e: self.spark.sql("select x").collect() - self.assertTrue("JVM stacktrace" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertIsNotNone(e.exception.getStackTrace()) self.assertTrue( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}): with self.assertRaises(AnalysisException) as e: self.spark.sql("select x").collect() - self.assertFalse("JVM stacktrace" in e.exception.message) + self.assertFalse("JVM stacktrace" in str(e.exception)) + self.assertIsNone(e.exception.getStackTrace()) self.assertFalse( 
- "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) # Create a new session with a different stack trace size. @@ -3421,9 +3423,10 @@ def test_error_stack_trace(self): spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True) with self.assertRaises(AnalysisException) as e: spark.sql("select x").collect() - self.assertTrue("JVM stacktrace" in e.exception.message) + self.assertTrue("JVM stacktrace" in str(e.exception)) + self.assertIsNotNone(e.exception.getStackTrace()) self.assertFalse( - "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message + "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception) ) spark.stop() From ae1514917d9dd61ccb6b89ba903f11bd459e5ad9 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 6 Nov 2023 01:04:01 +0100 Subject: [PATCH 2/7] empty From d9dcd81e7847978d186a04c084e40cba4700ef01 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 6 Nov 2023 06:02:18 +0100 Subject: [PATCH 3/7] fix UT --- .../service/FetchErrorDetailsHandlerSuite.scala | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala index 40439a217230..ebcd1de60057 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala @@ -103,15 +103,11 @@ class FetchErrorDetailsHandlerSuite extends SharedSparkSession with ResourceHelp assert(response.getErrors(1).getErrorTypeHierarchy(1) == classOf[Throwable].getName) assert(response.getErrors(1).getErrorTypeHierarchy(2) == classOf[Object].getName) assert(!response.getErrors(1).hasCauseIdx) - if (serverStacktraceEnabled) { - assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length) - assert( - response.getErrors(1).getStackTraceCount == - testError.getCause.getStackTrace.length) - } else { - assert(response.getErrors(0).getStackTraceCount == 0) - assert(response.getErrors(1).getStackTraceCount == 0) - } + assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length) + assert( + response.getErrors(1).getStackTraceCount == + testError.getCause.getStackTrace.length) + } finally { sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key) sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key) From 624ceeff4875a8cc4a2fb4a4fed6fdce18807856 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 6 Nov 2023 10:54:24 +0100 Subject: [PATCH 4/7] fix UT --- python/pyspark/errors/exceptions/connect.py | 127 ++++++++++++++++---- 1 file changed, 106 insertions(+), 21 deletions(-) diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index 29426b190ac9..625ea01889f6 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -47,9 +47,9 @@ class SparkConnectException(PySparkException): def convert_exception( info: "ErrorInfo", - truncated_message: str, - resp: Optional[pb2.FetchErrorDetailsResponse], - display_stacktrace: bool = False + truncated_message: str, + resp: 
Optional[pb2.FetchErrorDetailsResponse], + display_stacktrace: bool = False, ) -> SparkConnectException: classes = [] sql_state = None @@ -73,31 +73,103 @@ def convert_exception( display_stacktrace = display_stacktrace if stacktrace is not None else False if "org.apache.spark.sql.catalyst.parser.ParseException" in classes: - return ParseException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return ParseException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) # Order matters. ParseException inherits AnalysisException. elif "org.apache.spark.sql.AnalysisException" in classes: - return AnalysisException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return AnalysisException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes: - return StreamingQueryException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return StreamingQueryException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "org.apache.spark.sql.execution.QueryExecutionException" in classes: - return QueryExecutionException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return QueryExecutionException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) # Order matters. NumberFormatException inherits IllegalArgumentException. 
elif "java.lang.NumberFormatException" in classes: - return NumberFormatException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return NumberFormatException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "java.lang.IllegalArgumentException" in classes: - return IllegalArgumentException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return IllegalArgumentException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "java.lang.ArithmeticException" in classes: - return ArithmeticException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return ArithmeticException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "java.lang.UnsupportedOperationException" in classes: - return UnsupportedOperationException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return UnsupportedOperationException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "java.lang.ArrayIndexOutOfBoundsException" in classes: - return ArrayIndexOutOfBoundsException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return ArrayIndexOutOfBoundsException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "java.time.DateTimeException" in classes: - return DateTimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return DateTimeException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "org.apache.spark.SparkRuntimeException" in classes: - return SparkRuntimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return SparkRuntimeException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "org.apache.spark.SparkUpgradeException" in classes: - return SparkUpgradeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return SparkUpgradeException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) elif "org.apache.spark.api.python.PythonException" in classes: return PythonException( "\n An exception was thrown from the Python worker. " @@ -105,9 +177,22 @@ def convert_exception( ) # Make sure that the generic SparkException is handled last. 
elif "org.apache.spark.SparkException" in classes: - return SparkException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return SparkException( + message, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) else: - return SparkConnectGrpcException(message, reason=info.reason, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace) + return SparkConnectGrpcException( + message, + reason=info.reason, + error_class=error_class, + sql_state=sql_state, + stacktrace=stacktrace, + display_stacktrace=display_stacktrace, + ) def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str: @@ -150,7 +235,7 @@ def __init__( reason: Optional[str] = None, sql_state: Optional[str] = None, stacktrace: Optional[str] = None, - display_stacktrace: bool = False + display_stacktrace: bool = False, ) -> None: self.message = message # type: ignore[assignment] if reason is not None: @@ -169,7 +254,7 @@ def __init__( super().__init__( message=self.message, error_class=tmp_error_class, - message_parameters=tmp_message_parameters + message_parameters=tmp_message_parameters, ) self.error_class = error_class self._sql_state: Optional[str] = sql_state @@ -269,6 +354,6 @@ class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException Exception thrown because of Spark upgrade from Spark Connect. """ + class SparkException(SparkConnectGrpcException): - """ - """ \ No newline at end of file + """ """ From b2ee6acf95daa4082b303d2bcdac0b5b643d1db2 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Mon, 6 Nov 2023 20:55:45 +0100 Subject: [PATCH 5/7] fix and review --- .../SparkConnectSessionHolderSuite.scala | 102 +++++++++--------- python/pyspark/errors/exceptions/connect.py | 70 ++++++------ python/pyspark/sql/connect/client/core.py | 4 +- 3 files changed, 89 insertions(+), 87 deletions(-) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala index 910c2a2650c6..9845cee31037 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala @@ -169,6 +169,56 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { accumulator = null) } + test("python listener process: process terminates after listener is removed") { + // scalastyle:off assume + assume(IntegratedUDFTestUtils.shouldTestPandasUDFs) + // scalastyle:on assume + + val sessionHolder = SessionHolder.forTesting(spark) + try { + SparkConnectService.start(spark.sparkContext) + + val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction) + + val id1 = "listener_removeListener_test_1" + val id2 = "listener_removeListener_test_2" + val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder) + val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder) + + sessionHolder.cacheListenerById(id1, listener1) + spark.streams.addListener(listener1) + sessionHolder.cacheListenerById(id2, listener2) + spark.streams.addListener(listener2) + + val (runner1, runner2) = (listener1.runner, listener2.runner) + + // assert both python processes are 
running + assert(!runner1.isWorkerStopped().get) + assert(!runner2.isWorkerStopped().get) + + // remove listener1 + spark.streams.removeListener(listener1) + sessionHolder.removeCachedListener(id1) + // assert listener1's python process is not running + eventually(timeout(30.seconds)) { + assert(runner1.isWorkerStopped().get) + assert(!runner2.isWorkerStopped().get) + } + + // remove listener2 + spark.streams.removeListener(listener2) + sessionHolder.removeCachedListener(id2) + eventually(timeout(30.seconds)) { + // assert listener2's python process is not running + assert(runner2.isWorkerStopped().get) + // all listeners are removed + assert(spark.streams.listListeners().isEmpty) + } + } finally { + SparkConnectService.stop() + } + } + test("python foreachBatch process: process terminates after query is stopped") { // scalastyle:off assume assume(IntegratedUDFTestUtils.shouldTestPandasUDFs) @@ -232,58 +282,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { assert(spark.streams.listListeners().length == 1) // only process termination listener } finally { SparkConnectService.stop() + // Wait for things to calm down. + Thread.sleep(4.seconds.toMillis) // remove process termination listener spark.streams.listListeners().foreach(spark.streams.removeListener) } } - - test("python listener process: process terminates after listener is removed") { - // scalastyle:off assume - assume(IntegratedUDFTestUtils.shouldTestPandasUDFs) - // scalastyle:on assume - - val sessionHolder = SessionHolder.forTesting(spark) - try { - SparkConnectService.start(spark.sparkContext) - - val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction) - - val id1 = "listener_removeListener_test_1" - val id2 = "listener_removeListener_test_2" - val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder) - val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder) - - sessionHolder.cacheListenerById(id1, listener1) - spark.streams.addListener(listener1) - sessionHolder.cacheListenerById(id2, listener2) - spark.streams.addListener(listener2) - - val (runner1, runner2) = (listener1.runner, listener2.runner) - - // assert both python processes are running - assert(!runner1.isWorkerStopped().get) - assert(!runner2.isWorkerStopped().get) - - // remove listener1 - spark.streams.removeListener(listener1) - sessionHolder.removeCachedListener(id1) - // assert listener1's python process is not running - eventually(timeout(30.seconds)) { - assert(runner1.isWorkerStopped().get) - assert(!runner2.isWorkerStopped().get) - } - - // remove listener2 - spark.streams.removeListener(listener2) - sessionHolder.removeCachedListener(id2) - eventually(timeout(30.seconds)) { - // assert listener2's python process is not running - assert(runner2.isWorkerStopped().get) - // all listeners are removed - assert(spark.streams.listListeners().isEmpty) - } - } finally { - SparkConnectService.stop() - } - } } diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py index 625ea01889f6..9bca69c19c90 100644 --- a/python/pyspark/errors/exceptions/connect.py +++ b/python/pyspark/errors/exceptions/connect.py @@ -16,7 +16,7 @@ # import pyspark.sql.connect.proto as pb2 import json -from typing import Dict, List, Optional, TYPE_CHECKING, overload +from typing import Dict, List, Optional, TYPE_CHECKING from pyspark.errors.exceptions.base import ( AnalysisException as BaseAnalysisException, @@ -49,7 +49,7 @@ def convert_exception( info: "ErrorInfo", 
truncated_message: str, resp: Optional[pb2.FetchErrorDetailsResponse], - display_stacktrace: bool = False, + display_server_stacktrace: bool = False, ) -> SparkConnectException: classes = [] sql_state = None @@ -70,15 +70,15 @@ def convert_exception( else: message = truncated_message stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None - display_stacktrace = display_stacktrace if stacktrace is not None else False + display_server_stacktrace = display_server_stacktrace if stacktrace is not None else False if "org.apache.spark.sql.catalyst.parser.ParseException" in classes: return ParseException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) # Order matters. ParseException inherits AnalysisException. elif "org.apache.spark.sql.AnalysisException" in classes: @@ -86,24 +86,24 @@ def convert_exception( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes: return StreamingQueryException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "org.apache.spark.sql.execution.QueryExecutionException" in classes: return QueryExecutionException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) # Order matters. NumberFormatException inherits IllegalArgumentException. 
elif "java.lang.NumberFormatException" in classes: @@ -111,64 +111,64 @@ def convert_exception( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "java.lang.IllegalArgumentException" in classes: return IllegalArgumentException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "java.lang.ArithmeticException" in classes: return ArithmeticException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "java.lang.UnsupportedOperationException" in classes: return UnsupportedOperationException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "java.lang.ArrayIndexOutOfBoundsException" in classes: return ArrayIndexOutOfBoundsException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "java.time.DateTimeException" in classes: return DateTimeException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "org.apache.spark.SparkRuntimeException" in classes: return SparkRuntimeException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "org.apache.spark.SparkUpgradeException" in classes: return SparkUpgradeException( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) elif "org.apache.spark.api.python.PythonException" in classes: return PythonException( @@ -181,8 +181,8 @@ def convert_exception( message, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) else: return SparkConnectGrpcException( @@ -190,8 +190,8 @@ def convert_exception( reason=info.reason, error_class=error_class, sql_state=sql_state, - stacktrace=stacktrace, - display_stacktrace=display_stacktrace, + server_stacktrace=stacktrace, + display_server_stacktrace=display_server_stacktrace, ) @@ -234,8 +234,8 @@ def __init__( message_parameters: Optional[Dict[str, str]] = None, reason: Optional[str] = None, sql_state: Optional[str] = None, - stacktrace: Optional[str] = None, - display_stacktrace: bool = False, + server_stacktrace: Optional[str] = None, + display_server_stacktrace: bool = False, ) -> None: self.message = message # type: ignore[assignment] if reason is not None: @@ -258,8 +258,8 @@ def __init__( ) self.error_class = error_class self._sql_state: Optional[str] = sql_state - self._stacktrace: 
Optional[str] = stacktrace - self._display_stacktrace: bool = display_stacktrace + self._stacktrace: Optional[str] = server_stacktrace + self._display_stacktrace: bool = display_server_stacktrace def getSqlState(self) -> None: if self._sql_state is not None: diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 69afef992c34..cef0ea4f305d 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1564,7 +1564,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet except grpc.RpcError: return None - def _display_stack_trace(self) -> bool: + def _display_server_stack_trace(self) -> bool: from pyspark.sql.connect.conf import RuntimeConf conf = RuntimeConf(self) @@ -1605,7 +1605,7 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: info, status.message, self._fetch_enriched_error(info), - self._display_stack_trace(), + self._display_server_stack_trace(), ) from None raise SparkConnectGrpcException(status.message) from None From 4b13fa70bb7c3536fe913a5b8f5c7fd9b055cbd3 Mon Sep 17 00:00:00 2001 From: Martin Grund Date: Tue, 7 Nov 2023 07:09:18 +0100 Subject: [PATCH 6/7] fix and review --- python/pyspark/errors/exceptions/base.py | 2 +- python/pyspark/errors/exceptions/captured.py | 2 +- python/pyspark/errors/exceptions/connect.py | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py index 1d09a68dffbf..518a2d99ce88 100644 --- a/python/pyspark/errors/exceptions/base.py +++ b/python/pyspark/errors/exceptions/base.py @@ -75,7 +75,7 @@ def getMessageParameters(self) -> Optional[Dict[str, str]]: """ return self.message_parameters - def getSqlState(self) -> None: + def getSqlState(self) -> Optional[str]: """ Returns an SQLSTATE as a string. 
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index d62b7d24347e..55ed7ab3a6d5 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -107,7 +107,7 @@ def getMessageParameters(self) -> Optional[Dict[str, str]]:
         else:
             return None
 
-    def getSqlState(self) -> Optional[str]:  # type: ignore[override]
+    def getSqlState(self) -> Optional[str]:
         assert SparkContext._gateway is not None
         gw = SparkContext._gateway
         if self._origin is not None and is_instance_of(
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index 9bca69c19c90..2558c425469a 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -55,6 +55,8 @@ def convert_exception(
     sql_state = None
     error_class = None
 
+    stacktrace: Optional[str] = None
+
     if "classes" in info.metadata:
         classes = json.loads(info.metadata["classes"])
 
@@ -261,7 +263,7 @@ def __init__(
         self._stacktrace: Optional[str] = server_stacktrace
         self._display_stacktrace: bool = display_server_stacktrace
 
-    def getSqlState(self) -> None:
+    def getSqlState(self) -> Optional[str]:
         if self._sql_state is not None:
             return self._sql_state
         else:
@@ -270,7 +272,7 @@ def getSqlState(self) -> None:
     def getStackTrace(self) -> Optional[str]:
         return self._stacktrace
 
-    def __str__(self):
+    def __str__(self) -> str:
         desc = self.message
         if self._display_stacktrace:
             desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
         return desc

From 3fafafc10ea12804c2e553351def7628ff639632 Mon Sep 17 00:00:00 2001
From: Martin Grund
Date: Tue, 7 Nov 2023 12:07:03 +0100
Subject: [PATCH 7/7] fix

---
 .../test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index d9a77f2830b9..3423075ab4b4 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -136,8 +136,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
       assert(
         ex.getStackTrace
           .find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis"))
-          .isDefined
-          == isServerStackTraceEnabled)
+          .isDefined)
     }
   }
 }
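
As an illustration of the client-side behavior this series aims for, here is a minimal sketch that mirrors the flow exercised by test_error_stack_trace in test_connect_basic.py. It is not part of the patch: the remote address and the failing query are placeholders, and the conf names are the ones read by _display_server_stack_trace in core.py.

from pyspark.errors import AnalysisException
from pyspark.sql import SparkSession

# Placeholder endpoint; any Spark Connect session behaves the same way.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Either conf makes str(e) include the server-side JVM stack trace.
spark.conf.set("spark.sql.connect.serverStacktrace.enabled", True)
# spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True)

try:
    spark.sql("select x").collect()  # unresolved column -> AnalysisException
except AnalysisException as e:
    print(e.getSqlState())    # SQLSTATE from the error metadata, if the server sent one
    print(e.getStackTrace())  # raw server-side stack trace string, or None
    print(e)                  # message plus "JVM stacktrace: ..." when display is enabled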
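
The server and client halves meet in the ErrorInfo metadata map: ErrorUtils.scala fills it, convert_exception consumes it. A sketch of its shape follows, with illustrative values only; the SQLSTATE, error class, and class hierarchy shown are examples, not taken from this patch.

import json

# Approximate shape of info.metadata after the ErrorUtils change (values illustrative).
metadata = {
    "classes": json.dumps(
        [
            "org.apache.spark.sql.AnalysisException",
            "java.lang.Exception",
            "java.lang.Throwable",
            "java.lang.Object",
        ]
    ),
    "sqlState": "42703",  # only present when the throwable is a SparkThrowable
    "errorClass": "UNRESOLVED_COLUMN.WITH_SUGGESTION",
    "stackTrace": "...",  # non-enriched path only, when stack traces are enabled
}

# convert_exception walks "classes" from most to least specific (ParseException before
# AnalysisException, the generic SparkException last) and falls back to
# SparkConnectGrpcException when nothing matches.
classes = json.loads(metadata["classes"])
assert "org.apache.spark.sql.AnalysisException" in classes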