[SPARK-XXX][CONNECT][PYTHON] Better error handling
grundprinzip committed Nov 5, 2023
commit 4097b054d573c7a5be90cfca834cbb7543abed1b
@@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.sql.internal.SQLConf

/**
* Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler
@@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler(

ErrorUtils.throwableToFetchErrorDetailsResponse(
st = error,
serverStackTraceEnabled = sessionHolder.session.conf.get(
Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get(
Contributor: Are we deprecating this config Connect.CONNECT_SERVER_STACKTRACE_ENABLED?

Author: It's still used, but it now only controls the display behavior rather than the stack trace generation.

Contributor: Should we also make Connect.CONNECT_SERVER_STACKTRACE_ENABLED work for the Scala client in this PR?

Author: This is a bit more awkward: in contrast to Python, the server stack trace is always present in Scala, but the user can decide how to print it. (A short Python sketch of the client-side behavior follows this hunk.)

SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
serverStackTraceEnabled = true)
}
.getOrElse(FetchErrorDetailsResponse.newBuilder().build())

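For context, a minimal client-side sketch of the behavior discussed in the thread above. It assumes an existing Spark Connect session named spark, error enrichment left at its default, and an invented failing query; the comments describe the intended outcome, not verified output.

from pyspark.errors import AnalysisException

# The server now always builds the enriched error details; this config only
# controls whether str(e) appends the "JVM stacktrace:" section on the client.
spark.conf.set("spark.sql.connect.serverStacktrace.enabled", "true")
try:
    spark.sql("select x").collect()  # hypothetical query that fails analysis
except AnalysisException as e:
    print(e)              # message followed by the server-side stack trace
    e.getStackTrace()     # the raw stack trace string, also available directly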
@@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging {
"classes",
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))

// Add the SQL state and error class to the response metadata of the ErrorInfo object.
st match {
case e: SparkThrowable =>
val state = e.getSqlState
if (state != null && state.nonEmpty) {
errorInfo.putMetadata("sqlState", state)
}
val errorClass = e.getErrorClass
if (errorClass != null && errorClass.nonEmpty) {
errorInfo.putMetadata("errorClass", errorClass)
}
case _ =>
}

if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) {
// Generate a new unique key for this exception.
val errorId = UUID.randomUUID().toString
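Illustratively, the ErrorInfo metadata that reaches the Python client would then carry entries along these lines (all values are invented for the example; only the keys come from the code above):

# Hypothetical ErrorInfo.metadata contents after this change (values invented):
metadata = {
    "classes": '["org.apache.spark.sql.AnalysisException"]',
    "sqlState": "42703",
    "errorClass": "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION",
}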
93 changes: 71 additions & 22 deletions python/pyspark/errors/exceptions/connect.py
@@ -16,7 +16,7 @@
#
import pyspark.sql.connect.proto as pb2
import json
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING, overload

from pyspark.errors.exceptions.base import (
AnalysisException as BaseAnalysisException,
@@ -46,55 +46,68 @@ class SparkConnectException(PySparkException):


def convert_exception(
info: "ErrorInfo", truncated_message: str, resp: Optional[pb2.FetchErrorDetailsResponse]
info: "ErrorInfo",
truncated_message: str,
resp: Optional[pb2.FetchErrorDetailsResponse],
display_stacktrace: bool = False
) -> SparkConnectException:
classes = []
sql_state = None
error_class = None

if "classes" in info.metadata:
classes = json.loads(info.metadata["classes"])

if "sqlState" in info.metadata:
sql_state = info.metadata["sqlState"]

if "errorClass" in info.metadata:
error_class = info.metadata["errorClass"]

if resp is not None and resp.HasField("root_error_idx"):
message = resp.errors[resp.root_error_idx].message
stacktrace = _extract_jvm_stacktrace(resp)
else:
message = truncated_message
stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else ""

if len(stacktrace) > 0:
message += f"\n\nJVM stacktrace:\n{stacktrace}"
stacktrace = info.metadata["stackTrace"] if "stackTrace" in info.metadata else None
display_stacktrace = display_stacktrace if stacktrace is not None else False

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(message)
return ParseException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
# Order matters. ParseException inherits AnalysisException.
elif "org.apache.spark.sql.AnalysisException" in classes:
return AnalysisException(message)
return AnalysisException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
return StreamingQueryException(message)
return StreamingQueryException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
return QueryExecutionException(message)
return QueryExecutionException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
# Order matters. NumberFormatException inherits IllegalArgumentException.
elif "java.lang.NumberFormatException" in classes:
return NumberFormatException(message)
return NumberFormatException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.lang.IllegalArgumentException" in classes:
return IllegalArgumentException(message)
return IllegalArgumentException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.lang.ArithmeticException" in classes:
return ArithmeticException(message)
return ArithmeticException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.lang.UnsupportedOperationException" in classes:
return UnsupportedOperationException(message)
return UnsupportedOperationException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
return ArrayIndexOutOfBoundsException(message)
return ArrayIndexOutOfBoundsException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "java.time.DateTimeException" in classes:
return DateTimeException(message)
return DateTimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.SparkRuntimeException" in classes:
return SparkRuntimeException(message)
return SparkRuntimeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.SparkUpgradeException" in classes:
return SparkUpgradeException(message)
return SparkUpgradeException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
elif "org.apache.spark.api.python.PythonException" in classes:
return PythonException(
"\n An exception was thrown from the Python worker. "
"Please see the stack trace below.\n%s" % message
)
# Make sure that the generic SparkException is handled last.
elif "org.apache.spark.SparkException" in classes:
return SparkException(message, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)
else:
return SparkConnectGrpcException(message, reason=info.reason)
return SparkConnectGrpcException(message, reason=info.reason, error_class=error_class, sql_state=sql_state, stacktrace=stacktrace, display_stacktrace=display_stacktrace)


def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str:
@@ -106,7 +119,7 @@ def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str:
def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None:
message = f"{error.error_type_hierarchy[0]}: {error.message}"
if len(lines) == 0:
lines.append(message)
lines.append(error.error_type_hierarchy[0])
else:
lines.append(f"Caused by: {message}")
for elem in error.stack_trace:
@@ -135,16 +148,48 @@ def __init__(
error_class: Optional[str] = None,
message_parameters: Optional[Dict[str, str]] = None,
reason: Optional[str] = None,
sql_state: Optional[str] = None,
stacktrace: Optional[str] = None,
display_stacktrace: bool = False
) -> None:
self.message = message # type: ignore[assignment]
if reason is not None:
self.message = f"({reason}) {self.message}"

# PySparkException has the assumption that error_class and message_parameters are
# only occurring together. If only one is set, we assume the message to be fully
# parsed.
tmp_error_class = error_class
tmp_message_parameters = message_parameters
if error_class is not None and message_parameters is None:
tmp_error_class = None
elif error_class is None and message_parameters is not None:
tmp_message_parameters = None

super().__init__(
message=self.message,
error_class=error_class,
message_parameters=message_parameters,
error_class=tmp_error_class,
message_parameters=tmp_message_parameters
)
self.error_class = error_class
self._sql_state: Optional[str] = sql_state
self._stacktrace: Optional[str] = stacktrace
self._display_stacktrace: bool = display_stacktrace

def getSqlState(self) -> Optional[str]:
if self._sql_state is not None:
return self._sql_state
else:
return super().getSqlState()

def getStackTrace(self) -> Optional[str]:
return self._stacktrace

def __str__(self):
desc = self.message
if self._display_stacktrace:
desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
return desc


class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
@@ -223,3 +268,7 @@ class SparkUpgradeException(SparkConnectGrpcException, BaseSparkUpgradeException
"""
Exception thrown because of Spark upgrade from Spark Connect.
"""

class SparkException(SparkConnectGrpcException):
"""
"""
13 changes: 12 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -1564,6 +1564,14 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
except grpc.RpcError:
return None

def _display_stack_trace(self) -> bool:
from pyspark.sql.connect.conf import RuntimeConf

conf = RuntimeConf(self)
if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
return True
return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"

def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
Error handling helper for dealing with GRPC Errors. On the server side, certain
@@ -1594,7 +1602,10 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
d.Unpack(info)

raise convert_exception(
info, status.message, self._fetch_enriched_error(info)
info,
status.message,
self._fetch_enriched_error(info),
self._display_stack_trace(),
) from None

raise SparkConnectGrpcException(status.message) from None
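For reference, either of the two configurations read by _display_stack_trace enables the printed stack trace; a user would toggle them client-side roughly like this (sketch; session setup omitted):

# Either setting makes str(exception) include the "JVM stacktrace:" section.
spark.conf.set("spark.sql.connect.serverStacktrace.enabled", "true")
# or
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "true")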
25 changes: 14 additions & 11 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3378,35 +3378,37 @@ def test_error_enrichment_jvm_stacktrace(self):
"""select from_json(
'{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))"""
).collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertTrue("org.apache.spark.SparkUpgradeException:" in e.exception.message)
self.assertTrue("JVM stacktrace" in str(e.exception))
self.assertTrue("org.apache.spark.SparkUpgradeException" in str(e.exception))
self.assertTrue(
"at org.apache.spark.sql.errors.ExecutionErrors"
".failToParseDateTimeInNewParserError" in e.exception.message
".failToParseDateTimeInNewParserError" in str(e.exception)
)
self.assertTrue("Caused by: java.time.DateTimeException:" in e.exception.message)
self.assertTrue("Caused by: java.time.DateTimeException:" in str(e.exception))

def test_not_hitting_netty_header_limit(self):
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
with self.assertRaises(AnalysisException):
self.spark.sql("select " + "test" * 10000).collect()
self.spark.sql("select " + "test" * 1).collect()

def test_error_stack_trace(self):
with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}):
with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertTrue("JVM stacktrace" in str(e.exception))
self.assertIsNotNone(e.exception.getStackTrace())
self.assertTrue(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception)
)

with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": False}):
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select x").collect()
self.assertFalse("JVM stacktrace" in e.exception.message)
self.assertFalse("JVM stacktrace" in str(e.exception))
self.assertIsNone(e.exception.getStackTrace())
self.assertFalse(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception)
)

# Create a new session with a different stack trace size.
@@ -3421,9 +3423,10 @@ def test_error_stack_trace(self):
spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True)
with self.assertRaises(AnalysisException) as e:
spark.sql("select x").collect()
self.assertTrue("JVM stacktrace" in e.exception.message)
self.assertTrue("JVM stacktrace" in str(e.exception))
self.assertIsNotNone(e.exception.getStackTrace())
self.assertFalse(
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in e.exception.message
"at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in str(e.exception)
)
spark.stop()
