Commits (24)
2c1d5d8
Prototype
itholic Mar 4, 2024
376fc46
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 1, 2024
174a929
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 2, 2024
8ab1edf
Support query context testing and added UTs
itholic Apr 2, 2024
5906852
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 3, 2024
f3a7bd4
resolve comments
itholic Apr 3, 2024
bbaa399
Add JIRA pointer for testing
itholic Apr 3, 2024
b9f54f1
Silence the linter
itholic Apr 3, 2024
c8d98ea
Adjusted comments
itholic Apr 3, 2024
ef7f1df
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 4, 2024
cc52aab
Update displayed string and add comment for PySparkCurrentOrigin
itholic Apr 5, 2024
9c323d4
Using queue to ensure multiple call sites can be logged in order and …
itholic Apr 5, 2024
f5ad1c4
remove unnecessary comment
itholic Apr 5, 2024
4f12dc7
Extends Origin and WithOrigin to PySpark context support
itholic Apr 8, 2024
001c71e
Reusing fn for PySpark logging
itholic Apr 9, 2024
daa08cd
Add document for extended PySpark specific logging functions
itholic Apr 9, 2024
92faffe
remove unused code
itholic Apr 9, 2024
2514afb
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 9, 2024
672c176
Address None properly
itholic Apr 9, 2024
1304c2b
Simplifying
itholic Apr 9, 2024
ff4037b
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 10, 2024
1d8df34
Respect spark.sql.stackTracesInDataFrameContext
itholic Apr 10, 2024
95f7848
Add captureStackTrace to remove duplication
itholic Apr 10, 2024
1dd53ed
pysparkLoggingInfo -> pysparkErrorContext
itholic Apr 10, 2024
12 changes: 12 additions & 0 deletions python/pyspark/errors/exceptions/captured.py
@@ -379,5 +379,17 @@ def fragment(self) -> str:
def callSite(self) -> str:
return str(self._q.callSite())

def pysparkFragment(self) -> Optional[str]:
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkFragment())
else:
return None

def pysparkCallSite(self) -> Optional[str]:
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkCallSite())
else:
return None

def summary(self) -> str:
return str(self._q.summary())
7 changes: 3 additions & 4 deletions python/pyspark/errors/utils.py
@@ -126,7 +126,7 @@ def get_message_template(self, error_class: str) -> str:
return message_template


def _capture_call_site(func_name: str) -> None:
def _capture_call_site(fragment: str) -> None:
"""
Capture the call site information including file name, line number, and function name.

@@ -150,13 +150,12 @@ def _capture_call_site(func_name: str) -> None:

stack = inspect.stack()
frame_info = stack[-1]
function = func_name
filename = frame_info.filename
lineno = frame_info.lineno
call_site = f'"{function}" was called from\n{filename}:{lineno}'
call_site = f"{filename}:{lineno}"

pyspark_origin = spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
pyspark_origin.set(call_site)
pyspark_origin.set(fragment, call_site)


def with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -30,6 +30,10 @@ def test_help_command(self):
def test_toDF_with_schema_string(self):
super().test_toDF_with_schema_string()

@unittest.skip("Spark Connect does not support DataFrameQueryContext currently.")
def test_dataframe_error_context(self):
super().test_dataframe_error_context()


if __name__ == "__main__":
import unittest
168 changes: 168 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -37,6 +37,9 @@
AnalysisException,
IllegalArgumentException,
PySparkTypeError,
ArithmeticException,
QueryContextType,
NumberFormatException,
)
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -825,6 +828,171 @@ def test_duplicate_field_names(self):
self.assertEqual(df.schema, schema)
self.assertEqual(df.collect(), data)

def test_dataframe_error_context(self):
with self.sql_conf({"spark.sql.ansi.enabled": True}):
df = self.spark.range(10)

# DataFrameQueryContext with pysparkCallSite - divide
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("div_zero", df.id / 0).collect()
self.check_error(
exception=pe.exception,
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="divide",
)

# DataFrameQueryContext with pysparkCallSite - plus
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("plus_invalid_type", df.id + "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="plus",
)

# DataFrameQueryContext with pysparkCallSite - minus
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("minus_invalid_type", df.id - "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="minus",
)

# DataFrameQueryContext with pysparkCallSite - multiply
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_invalid_type", df.id * "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="multiply",
)

# DataFrameQueryContext with pysparkCallSite - chained (`divide` is problematic)
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_zero", df.id / 0
).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect()
self.check_error(
exception=pe.exception,
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="divide",
)

# DataFrameQueryContext with pysparkCallSite - chained (`plus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_ten", df.id / 10
).withColumn("plus_string", df.id + "string").withColumn(
"minus_ten", df.id - 10
).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="plus",
)

# DataFrameQueryContext with pysparkCallSite - chained (`minus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_ten", df.id / 10
).withColumn("plus_ten", df.id + 10).withColumn(
"minus_string", df.id - "string"
).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="minus",
)

# DataFrameQueryContext with pysparkCallSite - chained (`multiply` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_string", df.id * "string").withColumn(
"divide_ten", df.id / 10
).withColumn("plus_ten", df.id + 10).withColumn("minus_ten", df.id - 10).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="multiply",
)

# DataFrameQueryContext without pysparkCallSite
with self.assertRaises(AnalysisException) as pe:
df.select("non-existing-column")
self.check_error(
exception=pe.exception,
error_class="UNRESOLVED_COLUMN.WITH_SUGGESTION",
message_parameters={"objectName": "`non-existing-column`", "proposal": "`id`"},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="",
)

# SQLQueryContext
with self.assertRaises(ArithmeticException) as pe:
self.spark.sql("select 10/0").collect()
self.check_error(
exception=pe.exception,
error_class="DIVIDE_BY_ZERO",
message_parameters={"config": '"spark.sql.ansi.enabled"'},
query_context_type=QueryContextType.SQL,
)

# No QueryContext
with self.assertRaises(AnalysisException) as pe:
self.spark.sql("select * from non-existing-table")
self.check_error(
exception=pe.exception,
error_class="INVALID_IDENTIFIER",
message_parameters={"ident": "non-existing-table"},
query_context_type=None,
itholic (Contributor, Author) commented on Apr 3, 2024:
FYI: None is the default, so we don't need to pass it when no QueryContext exists, but I spelled it out here as an explicit example.
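
For illustration only, a minimal sketch of the two equivalent call styles against the check_error signature added below (error class and message parameters borrowed from the surrounding test; not part of the diff):

    # Explicit form, as in the test above: query_context_type=None means "no QueryContext expected".
    self.check_error(
        exception=pe.exception,
        error_class="INVALID_IDENTIFIER",
        message_parameters={"ident": "non-existing-table"},
        query_context_type=None,  # explicit, though None is already the default
    )

    # Equivalent shorter form that simply relies on the default value.
    self.check_error(
        exception=pe.exception,
        error_class="INVALID_IDENTIFIER",
        message_parameters={"ident": "non-existing-table"},
    )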

)


class DataFrameTests(DataFrameTestsMixin, ReusedSQLTestCase):
pass
30 changes: 30 additions & 0 deletions python/pyspark/testing/utils.py
@@ -54,6 +54,8 @@

from pyspark import SparkContext, SparkConf
from pyspark.errors import PySparkAssertionError, PySparkException
from pyspark.errors.exceptions.captured import CapturedException
from pyspark.errors.exceptions.base import QueryContextType
from pyspark.find_spark_home import _find_spark_home
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import Row
@@ -280,7 +282,14 @@ def check_error(
exception: PySparkException,
error_class: str,
message_parameters: Optional[Dict[str, str]] = None,
query_context_type: Optional[QueryContextType] = None,
pyspark_fragment: Optional[str] = None,
):
query_context = exception.getQueryContext()
assert bool(query_context) == (query_context_type is not None), (
f"`query_context_type` is required when QueryContext exists. "
f"QueryContext: {query_context}."
)
# Test if given error is an instance of PySparkException.
self.assertIsInstance(
exception,
@@ -302,6 +311,27 @@
expected, actual, f"Expected message parameters was '{expected}', got '{actual}'"
)

# Test query context
if query_context:
expected = query_context_type
actual_contexts = exception.getQueryContext()
for actual_context in actual_contexts:
actual = actual_context.contextType()
self.assertEqual(
expected, actual, f"Expected QueryContext was '{expected}', got '{actual}'"
)
if actual == QueryContextType.DataFrame:
assert (
pyspark_fragment is not None
), "`pyspark_fragment` is required when QueryContextType is DataFrame."
expected = pyspark_fragment
actual = actual_context.pysparkFragment()
self.assertEqual(
expected,
actual,
f"Expected PySpark fragment was '{expected}', got '{actual}'",
)


def assertSchemaEqual(
actual: StructType,
Original file line number Diff line number Diff line change
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.trees

import scala.collection.mutable

import org.apache.spark.{QueryContext, QueryContextType, SparkUnsupportedOperationException}

/** The class represents error context of a SQL query. */
@@ -134,7 +136,7 @@ case class SQLQueryContext(
override def callSite: String = throw SparkUnsupportedOperationException()
}

case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement], pysparkCallSite: String)
case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement])
extends QueryContext {
override val contextType = QueryContextType.DataFrame

@@ -156,6 +158,11 @@ case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement], pysparkCall

override val callSite: String = stackTrace.tail.mkString("\n")

val pysparkOriginInfo: mutable.Map[String, String] = PySparkCurrentOrigin.get()

val pysparkFragment: String = pysparkOriginInfo.getOrElse("fragment", "")
val pysparkCallSite: String = pysparkOriginInfo.getOrElse("callSite", "")

override lazy val summary: String = {
val builder = new StringBuilder
builder ++= "== DataFrame ==\n"
@@ -167,11 +174,19 @@ case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement], pysparkCall
builder ++= callSite
builder += '\n'

if (pysparkCallSite.nonEmpty) {
if (pysparkOriginInfo.nonEmpty) {
builder ++= "\n== PySpark call site ==\n"
builder ++= "\""

builder ++= pysparkFragment
builder ++= "\""
builder ++= " was called from\n"
builder ++= pysparkCallSite
builder += '\n'
}

PySparkCurrentOrigin.clear()

builder.result()
}
}
Original file line number Diff line number Diff line change
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.trees

import scala.collection.mutable

import org.apache.spark.QueryContext
import org.apache.spark.util.ArrayImplicits._

@@ -35,8 +37,7 @@ case class Origin(
stackTrace: Option[Array[StackTraceElement]] = None) {

lazy val context: QueryContext = if (stackTrace.isDefined) {
val pysparkCallSite = PySparkCurrentOrigin.get()
DataFrameQueryContext(stackTrace.get.toImmutableArraySeq, pysparkCallSite)
DataFrameQueryContext(stackTrace.get.toImmutableArraySeq)
} else {
SQLQueryContext(
line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName)
@@ -88,16 +89,19 @@ object CurrentOrigin {

/**
* Provides detailed call site information on PySpark.
* This information is generated in PySpark in the form of a String.
* This information is generated in PySpark and stored in the form of a Map.
*/
object PySparkCurrentOrigin {
private val pysparkCallSite = new ThreadLocal[String]() {
override def initialValue(): String = ""
private val pysparkCallSite = new ThreadLocal[mutable.Map[String, String]]() {
override def initialValue(): mutable.Map[String, String] = mutable.Map.empty
}

def set(value: String): Unit = pysparkCallSite.set(value)
def set(fragment: String, callSite: String): Unit = {
pysparkCallSite.get().put("fragment", fragment)
pysparkCallSite.get().put("callSite", callSite)
}

def get(): String = pysparkCallSite.get()
def get(): mutable.Map[String, String] = pysparkCallSite.get()

def clear(): Unit = pysparkCallSite.remove()
}
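
As a closing illustration, a minimal end-to-end sketch (not part of the diff; the file path and line number in the comments are purely illustrative) of how the new PySpark call-site information could surface to a user, assuming the pieces above are wired together as shown in this PR:

    from pyspark.errors import ArithmeticException
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", True)  # DIVIDE_BY_ZERO is only raised under ANSI mode

    df = spark.range(10)
    try:
        df.withColumn("div_zero", df.id / 0).collect()
    except ArithmeticException as e:
        for ctx in e.getQueryContext():
            # For a DataFrame query context, the new accessors expose the PySpark-side origin.
            print(ctx.pysparkFragment())   # "divide"
            print(ctx.pysparkCallSite())   # e.g. "/path/to/app.py:9" (illustrative)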