Changes from 1 commit
Commits
24 commits
2c1d5d8
Prototype
itholic Mar 4, 2024
376fc46
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 1, 2024
174a929
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 2, 2024
8ab1edf
Support query context testing and added UTs
itholic Apr 2, 2024
5906852
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 3, 2024
f3a7bd4
resolve comments
itholic Apr 3, 2024
bbaa399
Add JIRA pointer for testing
itholic Apr 3, 2024
b9f54f1
Silence the linter
itholic Apr 3, 2024
c8d98ea
Adjusted comments
itholic Apr 3, 2024
ef7f1df
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 4, 2024
cc52aab
Update displayed string and add comment for PySparkCurrentOrigin
itholic Apr 5, 2024
9c323d4
Using queue to ensure multiple call sites can be logged in order and …
itholic Apr 5, 2024
f5ad1c4
remove unnecessary comment
itholic Apr 5, 2024
4f12dc7
Extends Origin and WithOrigin to PySpark context support
itholic Apr 8, 2024
001c71e
Reusing fn for PySpark logging
itholic Apr 9, 2024
daa08cd
Add document for extended PySpark specific logging functions
itholic Apr 9, 2024
92faffe
remove unused code
itholic Apr 9, 2024
2514afb
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 9, 2024
672c176
Address None properly
itholic Apr 9, 2024
1304c2b
Simplifying
itholic Apr 9, 2024
ff4037b
Merge branch 'master' of https://github.com/apache/spark into error_c…
itholic Apr 10, 2024
1d8df34
Respect spark.sql.stackTracesInDataFrameContext
itholic Apr 10, 2024
95f7848
Add captureStackTrace to remove duplication
itholic Apr 10, 2024
1dd53ed
pysparkLoggingInfo -> pysparkErrorContext
itholic Apr 10, 2024
Reusing fn for PySpark logging
itholic committed Apr 9, 2024
commit 001c71e3a2a6f9519fcecdc18961e6226acbf4b1
23 changes: 21 additions & 2 deletions python/pyspark/sql/column.py
@@ -169,6 +169,25 @@ def _bin_op(
["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]], "Column"
]:
"""Create a method for given binary operator"""
binary_operator_map = {
"plus": "+",
"minus": "-",
"divide": "/",
"multiply": "*",
"mod": "%",
"equalTo": "=",
"lt": "<",
"leq": "<=",
"geq": ">=",
"gt": ">",
"eqNullSafe": "<=>",
"bitwiseOR": "|",
"bitwiseAND": "&",
"bitwiseXOR": "^",
# Just following JVM rule even if the names of source and target are the same.
"and": "and",
"or": "or",
}

def _(
self: "Column",
@@ -178,7 +197,7 @@ def _(

logging_info = {}
spark = SparkSession.getActiveSession()
if name in SUPPORTED_WITH_PYSPARK_LOGGING_INFO_FUNCTIONS and spark is not None:
if name in binary_operator_map and spark is not None:
assert spark._jvm is not None

stack = inspect.stack()
@@ -190,7 +209,7 @@

jc = other._jc if isinstance(other, Column) else other
if logging_info:
njc = getattr(self._jc, f"{name}WithPySparkLoggingInfo")(jc, logging_info)
njc = getattr(self._jc, "fn")(binary_operator_map[name], jc, logging_info)
else:
njc = getattr(self._jc, name)(jc)
return Column(njc)
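To make the new path concrete, here is a small usage sketch (not part of this diff): with ANSI mode enabled, a binary operation that fails at runtime should raise an error whose DataFrame query context points back at the PySpark call site, with the fragment set to the operator name ("divide" in this case), which is exactly what the tests below assert via check_error.

```python
# Illustrative sketch only (not part of this commit): exercising the
# error-context path that _bin_op above now routes through `fn`.
from pyspark.sql import SparkSession
from pyspark.errors import ArithmeticException

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", "true")  # ANSI mode makes the division fail

df = spark.range(10)
try:
    # The call site of this `/` is what the captured logging info should point at.
    df.withColumn("div_zero", df.id / 0).collect()
except ArithmeticException as e:
    # Expectation under this change: the error carries a DataFrameQueryContext
    # whose fragment is "divide" (see the test assertions below).
    print(e)
```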
208 changes: 199 additions & 9 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -833,7 +833,7 @@ def test_dataframe_error_context(self):
with self.sql_conf({"spark.sql.ansi.enabled": True}):
df = self.spark.range(10)

# DataFrameQueryContext with pysparkCallSite - divide
# DataFrameQueryContext with pysparkLoggingInfo - divide
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("div_zero", df.id / 0).collect()
self.check_error(
@@ -844,7 +844,7 @@ def test_dataframe_error_context(self):
pyspark_fragment="divide",
)

# DataFrameQueryContext with pysparkCallSite - plus
# DataFrameQueryContext with pysparkLoggingInfo - plus
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("plus_invalid_type", df.id + "string").collect()
self.check_error(
@@ -860,7 +860,7 @@ def test_dataframe_error_context(self):
pyspark_fragment="plus",
)

# DataFrameQueryContext with pysparkCallSite - minus
# DataFrameQueryContext with pysparkLoggingInfo - minus
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("minus_invalid_type", df.id - "string").collect()
self.check_error(
@@ -876,7 +876,7 @@ def test_dataframe_error_context(self):
pyspark_fragment="minus",
)

# DataFrameQueryContext with pysparkCallSite - multiply
# DataFrameQueryContext with pysparkLoggingInfo - multiply
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_invalid_type", df.id * "string").collect()
self.check_error(
@@ -892,7 +892,197 @@ def test_dataframe_error_context(self):
pyspark_fragment="multiply",
)

# DataFrameQueryContext with pysparkCallSite - chained (`divide` is problematic)
# DataFrameQueryContext with pysparkLoggingInfo - mod
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("mod_invalid_type", df.id % "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="mod",
)

# DataFrameQueryContext with pysparkLoggingInfo - equalTo
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("equalTo_invalid_type", df.id == "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="equalTo",
)

# DataFrameQueryContext with pysparkLoggingInfo - lt
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("lt_invalid_type", df.id < "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="lt",
)

# DataFrameQueryContext with pysparkLoggingInfo - leq
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("leq_invalid_type", df.id <= "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="leq",
)

# DataFrameQueryContext with pysparkLoggingInfo - geq
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("geq_invalid_type", df.id >= "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="geq",
)

# DataFrameQueryContext with pysparkLoggingInfo - gt
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("gt_invalid_type", df.id > "string").collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="gt",
)

# DataFrameQueryContext with pysparkLoggingInfo - eqNullSafe
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("eqNullSafe_invalid_type", df.id.eqNullSafe("string")).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="eqNullSafe",
)

# DataFrameQueryContext with pysparkLoggingInfo - and
with self.assertRaises(AnalysisException) as pe:
df.withColumn("and_invalid_type", df.id & "string").collect()
self.check_error(
exception=pe.exception,
error_class="DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE",
message_parameters={
"inputType": '"BOOLEAN"',
"actualDataType": '"BIGINT"',
"sqlExpr": '"(id AND string)"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="and",
)

# DataFrameQueryContext with pysparkLoggingInfo - or
with self.assertRaises(AnalysisException) as pe:
df.withColumn("or_invalid_type", df.id | "string").collect()
self.check_error(
exception=pe.exception,
error_class="DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE",
message_parameters={
"inputType": '"BOOLEAN"',
"actualDataType": '"BIGINT"',
"sqlExpr": '"(id OR string)"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="or",
)

# DataFrameQueryContext with pysparkLoggingInfo - bitwiseOR
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("bitwiseOR_invalid_type", df.id.bitwiseOR("string")).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="bitwiseOR",
)

# DataFrameQueryContext with pysparkLoggingInfo - bitwiseAND
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("bitwiseAND_invalid_type", df.id.bitwiseAND("string")).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="bitwiseAND",
)

# DataFrameQueryContext with pysparkLoggingInfo - bitwiseXOR
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("bitwiseXOR_invalid_type", df.id.bitwiseXOR("string")).collect()
self.check_error(
exception=pe.exception,
error_class="CAST_INVALID_INPUT",
message_parameters={
"expression": "'string'",
"sourceType": '"STRING"',
"targetType": '"BIGINT"',
"ansiConfig": '"spark.sql.ansi.enabled"',
},
query_context_type=QueryContextType.DataFrame,
pyspark_fragment="bitwiseXOR",
)

# DataFrameQueryContext with pysparkLoggingInfo - chained (`divide` is problematic)
with self.assertRaises(ArithmeticException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_zero", df.id / 0
@@ -905,7 +1095,7 @@ def test_dataframe_error_context(self):
pyspark_fragment="divide",
)

# DataFrameQueryContext with pysparkCallSite - chained (`plus` is problematic)
# DataFrameQueryContext with pysparkLoggingInfo - chained (`plus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_ten", df.id / 10
@@ -925,7 +1115,7 @@ def test_dataframe_error_context(self):
pyspark_fragment="plus",
)

# DataFrameQueryContext with pysparkCallSite - chained (`minus` is problematic)
# DataFrameQueryContext with pysparkLoggingInfo - chained (`minus` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_ten", df.id * 10).withColumn(
"divide_ten", df.id / 10
@@ -945,7 +1135,7 @@ def test_dataframe_error_context(self):
pyspark_fragment="minus",
)

# DataFrameQueryContext with pysparkCallSite - chained (`multiply` is problematic)
# DataFrameQueryContext with pysparkLoggingInfo - chained (`multiply` is problematic)
with self.assertRaises(NumberFormatException) as pe:
df.withColumn("multiply_string", df.id * "string").withColumn(
"divide_ten", df.id / 10
@@ -1089,7 +1279,7 @@ def test_dataframe_error_context(self):
pyspark_fragment="multiply",
)

# DataFrameQueryContext without pysparkCallSite
# DataFrameQueryContext without pysparkLoggingInfo
with self.assertRaises(AnalysisException) as pe:
df.select("non-existing-column")
self.check_error(
48 changes: 20 additions & 28 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -171,6 +171,26 @@ class Column(val expr: Expression) extends Logging {
Column.fn(name, this, lit(other))
}

// For PySpark logging
private def fn(
name: String, pysparkLoggingInfo: java.util.Map[String, String]): Column = {
withOrigin(Some(pysparkLoggingInfo)) {
Column.fn(name, this)
}
}
private def fn(
name: String, other: Column, pysparkLoggingInfo: java.util.Map[String, String]): Column = {
withOrigin(Some(pysparkLoggingInfo)) {
Column.fn(name, this, other)
}
}
private def fn(

@cloud-fan (Contributor), Apr 15, 2024:

@HyukjinKwon This probably can't cover all the cases, and we may need to add more overloads for certain functions that require non-expression parameters, but it shouldn't be many.

I think it's better than using ThreadLocal which can be quite fragile to pass values between Python and JVM.

name: String, other: Any, pysparkLoggingInfo: java.util.Map[String, String]): Column = {
withOrigin(Some(pysparkLoggingInfo)) {
Column.fn(name, this, lit(other))
}
}

override def toString: String = toPrettySQL(expr)

override def equals(that: Any): Boolean = that match {
@@ -699,13 +719,6 @@
*/
def plus(other: Any): Column = this + other

def plusWithPySparkLoggingInfo(
other: Any, loggingInfo: java.util.Map[String, String]): Column = {
withOrigin(Some(loggingInfo)) {
this + other
}
}

/**
* Subtraction. Subtract the other expression from this expression.
* {{{
@@ -736,13 +749,6 @@
*/
def minus(other: Any): Column = this - other

def minusWithPySparkLoggingInfo(
other: Any, loggingInfo: java.util.Map[String, String]): Column = {
withOrigin(Some(loggingInfo)) {
this - other
}
}

/**
* Multiplication of this expression and another expression.
* {{{
@@ -773,13 +779,6 @@
*/
def multiply(other: Any): Column = this * other

def multiplyWithPySparkLoggingInfo(
other: Any, loggingInfo: java.util.Map[String, String]): Column = {
withOrigin(Some(loggingInfo)) {
this * other
}
}

/**
* Division this expression by another expression.
* {{{
@@ -810,13 +809,6 @@
*/
def divide(other: Any): Column = this / other

def divideWithPySparkLoggingInfo(
other: Any, loggingInfo: java.util.Map[String, String]): Column = {
withOrigin(Some(loggingInfo)) {
this / other
}
}

/**
* Modulo (a.k.a. remainder) expression.
*
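As a side note on cloud-fan's review comment above (preferring explicit `fn` overloads over a ThreadLocal for carrying call-site info across the Python/JVM boundary), the trade-off can be shown with a deliberately toy, non-Spark sketch; every name below is invented purely for illustration.

```python
# Toy illustration of the design choice discussed in the review comment above.
# Nothing here is Spark code; the function names are made up for the example.
import threading

_local = threading.local()

def op_with_threadlocal(a: int, b: int) -> str:
    # Fragile: correctness depends on the caller having set _local.call_site
    # beforehand, which is easy to get wrong across the Python/JVM boundary.
    site = getattr(_local, "call_site", None)
    return f"{a + b} (origin={site})"

def op_with_explicit_context(a: int, b: int, call_site: str) -> str:
    # The approach this PR takes: the call-site info travels with the call
    # itself, as an extra argument to a dedicated overload.
    return f"{a + b} (origin={call_site})"

_local.call_site = "example.py:42"
print(op_with_threadlocal(1, 2))                        # works only if the state was set
print(op_with_explicit_context(1, 2, "example.py:42"))  # explicit and self-describing
```

Either way, as the comment notes, a few extra overloads may still be needed for functions whose parameters are not expressions.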