Closed

Changes from 8 commits (24 commits in total)
2c1d5d8
Prototype
itholic Mar 4, 2024
376fc46
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 1, 2024
174a929
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 2, 2024
8ab1edf
Support query context testing and added UTs
itholic Apr 2, 2024
5906852
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 3, 2024
f3a7bd4
resolve comments
itholic Apr 3, 2024
bbaa399
Add JIRA pointer for testing
itholic Apr 3, 2024
b9f54f1
Silence the linter
itholic Apr 3, 2024
c8d98ea
Adjusted comments
itholic Apr 3, 2024
ef7f1df
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 4, 2024
cc52aab
Update displayed string and add comment for PySparkCurrentOrigin
itholic Apr 5, 2024
9c323d4
Using queue to ensure multiple call sites can be logged in order and …
itholic Apr 5, 2024
f5ad1c4
remove unnecessary comment
itholic Apr 5, 2024
4f12dc7
Extends Origin and WithOrigin to PySpark context support
itholic Apr 8, 2024
001c71e
Reusing fn for PySpark logging
itholic Apr 9, 2024
daa08cd
Add document for extended PySpark specific logging functions
itholic Apr 9, 2024
92faffe
remove unused code
itholic Apr 9, 2024
2514afb
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 9, 2024
672c176
Adress None properly
itholic Apr 9, 2024
1304c2b
Simplifying
itholic Apr 9, 2024
ff4037b
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 10, 2024
1d8df34
Respect spark.sql.stackTracesInDataFrameContext
itholic Apr 10, 2024
95f7848
Add captureStackTrace to remove duplication
itholic Apr 10, 2024
1dd53ed
pysparkLoggingInfo -> pysparkErrorContext
itholic Apr 10, 2024
8 changes: 8 additions & 0 deletions python/pyspark/errors/exceptions/captured.py
@@ -379,5 +379,13 @@ def fragment(self) -> str:
def callSite(self) -> str:
return str(self._q.callSite())

def pysparkFragment(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkFragment())

def pysparkCallSite(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkCallSite())

def summary(self) -> str:
return str(self._q.summary())
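
The two accessors added above expose the PySpark-side fragment and call site of a DataFrame query context. Below is a minimal sketch, not part of this diff, of how a caller might read them from a captured exception; it assumes a local SparkSession with ANSI mode enabled so the division raises DIVIDE_BY_ZERO.

from pyspark.sql import SparkSession
from pyspark.errors import ArithmeticException, QueryContextType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", True)
df = spark.range(10)

try:
    df.withColumn("div_zero", df.id / 0).collect()
except ArithmeticException as e:
    for ctx in e.getQueryContext():
        if ctx.contextType() == QueryContextType.DataFrame:
            # New in this change: where in user code the failing expression was built.
            print(ctx.pysparkFragment(), ctx.pysparkCallSite())
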
65 changes: 64 additions & 1 deletion python/pyspark/errors/utils.py
@@ -16,11 +16,16 @@
#

import re
from typing import Dict, Match
import functools
import inspect
from typing import Any, Callable, Dict, Match, TypeVar, Type

from pyspark.errors.error_classes import ERROR_CLASSES_MAP


T = TypeVar("T")


class ErrorClassesReader:
"""
A reader to load error information from error_classes.py.
@@ -119,3 +124,61 @@ def get_message_template(self, error_class: str) -> str:
message_template = main_message_template + " " + sub_message_template

return message_template


def _capture_call_site(fragment: str) -> None:
"""
Capture the call site information including file name, line number, and function name.

This function updates the server-side thread-local storage (PySparkCurrentOrigin)
with the current call site information whenever a PySpark API function is called.

Parameters
----------
fragment : str
The name of the PySpark API function being captured; it is recorded as the error-context fragment.

Notes
-----
The call site information is used to enhance error messages with the exact location
in the user code that led to the error.
"""
from pyspark.sql.session import SparkSession

spark = SparkSession._getActiveSessionOrCreate()
assert spark._jvm is not None

stack = inspect.stack()
frame_info = stack[-1]
filename = frame_info.filename
lineno = frame_info.lineno
call_site = f"{filename}:{lineno}"

pyspark_origin = spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
pyspark_origin.set(fragment, call_site)


def with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator to capture and provide the call site information to the server side
when PySpark API functions are invoked.
"""

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# Update call site when the function is called
_capture_call_site(func.__name__)

return func(*args, **kwargs)

return wrapper


def with_origin_to_class(cls: Type[T]) -> Type[T]:
"""
Decorate all methods of a class with `with_origin` to capture call site information.
"""
for name, method in cls.__dict__.items():
if callable(method) and name != "__init__":
setattr(cls, name, with_origin(method))
return cls
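
Taken together, _capture_call_site, with_origin, and with_origin_to_class record the user-code location each time a decorated PySpark API is invoked and hand it to the JVM side. The following is a standalone, Spark-free sketch of the same decorator pattern, for illustration only; demo_origin is a hypothetical stand-in for the JVM-side PySparkCurrentOrigin used above.

import functools
import inspect
from typing import Any, Callable, Dict

demo_origin: Dict[str, str] = {}  # fragment -> call site, mimicking PySparkCurrentOrigin

def with_origin_demo(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # stack()[-1] is the outermost frame, i.e. the user's top-level call site.
        frame = inspect.stack()[-1]
        demo_origin[func.__name__] = f"{frame.filename}:{frame.lineno}"
        return func(*args, **kwargs)
    return wrapper

@with_origin_demo
def divide(a: int, b: int) -> float:
    return a / b

divide(10, 2)
print(demo_origin)  # e.g. {'divide': '/path/to/script.py:21'}
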
3 changes: 3 additions & 0 deletions python/pyspark/sql/column.py
@@ -35,6 +35,7 @@

from pyspark.context import SparkContext
from pyspark.errors import PySparkAttributeError, PySparkTypeError, PySparkValueError
from pyspark.errors.utils import with_origin_to_class
from pyspark.sql.types import DataType
from pyspark.sql.utils import get_active_spark_context

@@ -177,6 +178,7 @@ def _(
return Column(njc)

_.__doc__ = doc
_.__name__ = name
return _


@@ -195,6 +197,7 @@ def _(self: "Column", other: Union["LiteralType", "DecimalLiteral"]) -> "Column"
return _


@with_origin_to_class
class Column:

"""
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -30,6 +30,10 @@ def test_help_command(self):
def test_toDF_with_schema_string(self):
super().test_toDF_with_schema_string()

@unittest.skip("Spark Connect does not support DataFrameQueryContext currently.")
def test_dataframe_error_context(self):
super().test_dataframe_error_context()


if __name__ == "__main__":
import unittest
169 changes: 169 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -37,6 +37,9 @@
AnalysisException,
IllegalArgumentException,
PySparkTypeError,
ArithmeticException,
QueryContextType,
NumberFormatException,
)
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -825,6 +828,172 @@ def test_duplicate_field_names(self):
self.assertEqual(df.schema, schema)
self.assertEqual(df.collect(), data)

def test_dataframe_error_context(self):
# SPARK-47274: Add more useful contexts for PySpark DataFrame API errors.
with self.sql_conf({"spark.sql.ansi.enabled": True}):
df = self.spark.range(10)

# DataFrameQueryContext with pysparkCallSite - divide
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("div_zero", df.id / 0).collect()
self.check_error(
exception=pe.exception,
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="divide",
)

# DataFrameQueryContext with pysparkCallSite - plus
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("plus_invalid_type", df.id + "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="plus",
)

# DataFrameQueryContext with pysparkCallSite - minus
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("minus_invalid_type", df.id - "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="minus",
)

# DataFrameQueryContext with pysparkCallSite - multiply
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_invalid_type", df.id * "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="multiply",
)

# DataFrameQueryContext with pysparkCallSite - chained (`divide` is problematic)
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_zero", df.id / 0
).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect()
self.check_error(
exception=pe.exception,
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="divide",
)

# DataFrameQueryContext with pysparkCallSite - chained (`plus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_ten", df.id / 10
).withColumn("plus_string", df.id + "string").withColumn(
"minus_ten", df.id - 10
).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="plus",
)

# DataFrameQueryContext with pysparkCallSite - chained (`minus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_ten", df.id / 10
).withColumn("plus_ten", df.id + 10).withColumn(
"minus_string", df.id - "string"
).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="minus",
)

# DataFrameQueryContext with pysparkCallSite - chained (`multiply` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_string", df.id * "string").withColumn(
"divide_ten", df.id / 10
).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="multiply",
)

# DataFrameQueryContext without pysparkCallSite
with self.assertRaises(AnalysisException) as pe:
df.select("non-existing-column")
self.check_error(
exception=pe.exception,
error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
message_parameters={"objectName": "`non-existing-column`", "proposal": "`id`"},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="",
)

# SQLQueryContext
with self.assertRaises(ArithmeticException) as pe:
self.spark.sql("select 10/0").collect()
self.check_error(
exception=pe.exception,
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.SQL,
)

# No QueryContext
with self.assertRaises(AnalysisException) as pe:
self.spark.sql("select * from non-existing-table")
self.check_error(
exception=pe.exception,
error_class="INVALID_IDENTIFIER",
message_parameters={"ident": "non-existing-table"},
query_context_type=None,
itholic (Contributor, Author) commented on Apr 3, 2024:

FYI: None is the default, so it does not need to be passed when no QueryContext exists, but it is spelled out in this test as an explicit example.

)


class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
pass
30 changes: 30 additions & 0 deletions python/pyspark/testing/utils.py
@@ -54,6 +54,8 @@

from pyspark import SparkContext, SparkConf
from pyspark.errors import PySparkAssertionError, PySparkException
from pyspark.errors.exceptions.captured import CapturedException
from pyspark.errors.exceptions.base import QueryContextType
from pyspark.find_spark_home import _find_spark_home
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import Row
@@ -280,7 +282,14 @@ def check_error(
exception: PySparkException,
error_class: str,
message_parameters: Optional[Dict[str, str]] = None,
query_context_type: Optional[QueryContextType] = None,
pyspark_fragment: Optional[str] = None,
):
query_context = exception.getQueryContext()
assert bool(query_context) == (query_context_type is not None), (
"`query_context_type` is required when QueryContext exists. "
f"QueryContext: {query_context}."
)
# Test if given error is an instance of PySparkException.
self.assertIsInstance(
exception,
@@ -302,6 +311,27 @@
expected, actual, f"Expected message parameters was '{expected}', got '{actual}'"
)

# Test query context
if query_context:
expected = query_context_type
actual_contexts = exception.getQueryContext()
for actual_context in actual_contexts:
actual = actual_context.contextType()
self.assertEqual(
expected, actual, f"Expected QueryContext was '{expected}', got '{actual}'"
)
if actual == QueryContextType.DataFrame:
assert (
pyspark_fragment is not None
), "`pyspark_fragment` is required when QueryContextType is DataFrame."
expected = pyspark_fragment
actual = actual_context.pysparkFragment()
self.assertEqual(
expected,
actual,
f"Expected PySpark fragment was '{expected}', got '{actual}'",
)


def assertSchemaEqual(
actual: StructType,