
Commit 7556ffa

[SPARK-45022][SQL] Provide context for dataset API errors
Parent: 6c885a7

46 files changed: 1,300 additions, 657 deletions


common/utils/src/main/java/org/apache/spark/QueryContext.java

Lines changed: 12 additions & 0 deletions
@@ -27,6 +27,9 @@
  */
 @Evolving
 public interface QueryContext {
+  // The type of this query context.
+  QueryContextType contextType();
+
   // The object type of the query which throws the exception.
   // If the exception is directly from the main query, it should be an empty string.
   // Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
@@ -45,4 +48,13 @@ public interface QueryContext {
 
   // The corresponding fragment of the query which throws the exception.
   String fragment();
+
+  // The Spark code (API) that caused throwing the exception.
+  String code();
+
+  // The user code (call site of the API) that caused throwing the exception.
+  String callSite();
+
+  // Summary of the exception cause.
+  String summary();
 }
common/utils/src/main/java/org/apache/spark/QueryContextType.java (new file)

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * The type of {@link QueryContext}.
+ *
+ * @since 3.5.0
+ */
+@Evolving
+public enum QueryContextType {
+  SQL,
+  Dataset
+}
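
Taken together, the interface and the enum form a tagged union: a consumer must branch on contextType() before touching the type-specific accessors, since an implementation is expected to throw for the fields of the other flavor. A minimal Scala sketch of such a consumer (the describe helper and its message strings are illustrative assumptions, not part of this commit):

import org.apache.spark.{QueryContext, QueryContextType}

// Branch on the context type first; the non-applicable accessors of a
// QueryContext implementation may throw UnsupportedOperationException.
def describe(ctx: QueryContext): String = ctx.contextType() match {
  case QueryContextType.SQL =>
    s"SQL fragment '${ctx.fragment()}' in ${ctx.objectType()} '${ctx.objectName()}'"
  case QueryContextType.Dataset =>
    s"API '${ctx.code()}' called from ${ctx.callSite()}"
}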

common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala

Lines changed: 13 additions & 7 deletions
@@ -104,13 +104,19 @@ private[spark] object SparkThrowableHelper {
       g.writeArrayFieldStart("queryContext")
       e.getQueryContext.foreach { c =>
         g.writeStartObject()
-        g.writeStringField("objectType", c.objectType())
-        g.writeStringField("objectName", c.objectName())
-        val startIndex = c.startIndex() + 1
-        if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
-        val stopIndex = c.stopIndex() + 1
-        if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
-        g.writeStringField("fragment", c.fragment())
+        c.contextType() match {
+          case QueryContextType.SQL =>
+            g.writeStringField("objectType", c.objectType())
+            g.writeStringField("objectName", c.objectName())
+            val startIndex = c.startIndex() + 1
+            if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
+            val stopIndex = c.stopIndex() + 1
+            if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
+            g.writeStringField("fragment", c.fragment())
+          case QueryContextType.Dataset =>
+            g.writeStringField("code", c.code())
+            g.writeStringField("callSite", c.callSite())
+        }
         g.writeEndObject()
       }
       g.writeEndArray()
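
For intuition, a self-contained sketch of the branch above, trimmed to the string fields for brevity. It uses only jackson-core (already on Spark's classpath); writeContext is a hypothetical helper and the div/SimpleApp values are sample data, not the commit's API:

import java.io.StringWriter
import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.{QueryContext, QueryContextType}

// Serialize a single queryContext entry the way the patched helper does.
def writeContext(c: QueryContext): String = {
  val out = new StringWriter()
  val g = new JsonFactory().createGenerator(out)
  g.writeStartObject()
  c.contextType() match {
    case QueryContextType.SQL =>
      g.writeStringField("objectType", c.objectType())
      g.writeStringField("objectName", c.objectName())
      g.writeStringField("fragment", c.fragment())
    case QueryContextType.Dataset =>
      // Dataset errors have no SQL text to point into; report the failing
      // API and the user call site instead.
      g.writeStringField("code", c.code())
      g.writeStringField("callSite", c.callSite())
  }
  g.writeEndObject()
  g.flush()
  out.toString // e.g. {"code":"div","callSite":"SimpleApp$.main(SimpleApp.scala:9)"}
}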

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala

Lines changed: 2 additions & 0 deletions
@@ -209,6 +209,8 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
 
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.TypedColumn.withExprTyped"),
+
       // KeyValueGroupedDataset
       ProblemFilters.exclude[Problem](
         "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"),

core/src/test/scala/org/apache/spark/SparkFunSuite.scala

Lines changed: 40 additions & 17 deletions
@@ -318,7 +318,7 @@ abstract class SparkFunSuite
       sqlState: Option[String] = None,
       parameters: Map[String, String] = Map.empty,
       matchPVals: Boolean = false,
-      queryContext: Array[QueryContext] = Array.empty): Unit = {
+      queryContext: Array[ExpectedContext] = Array.empty): Unit = {
     assert(exception.getErrorClass === errorClass)
     sqlState.foreach(state => assert(exception.getSqlState === state))
     val expectedParameters = exception.getMessageParameters.asScala
@@ -340,16 +340,23 @@ abstract class SparkFunSuite
     val actualQueryContext = exception.getQueryContext()
     assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context")
     actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
-      assert(actual.objectType() === expected.objectType(),
-        "Invalid objectType of a query context Actual:" + actual.toString)
-      assert(actual.objectName() === expected.objectName(),
-        "Invalid objectName of a query context. Actual:" + actual.toString)
-      assert(actual.startIndex() === expected.startIndex(),
-        "Invalid startIndex of a query context. Actual:" + actual.toString)
-      assert(actual.stopIndex() === expected.stopIndex(),
-        "Invalid stopIndex of a query context. Actual:" + actual.toString)
-      assert(actual.fragment() === expected.fragment(),
-        "Invalid fragment of a query context. Actual:" + actual.toString)
+      if (actual.contextType() == QueryContextType.SQL) {
+        assert(actual.objectType() === expected.objectType,
+          "Invalid objectType of a query context Actual:" + actual.toString)
+        assert(actual.objectName() === expected.objectName,
+          "Invalid objectName of a query context. Actual:" + actual.toString)
+        assert(actual.startIndex() === expected.startIndex,
+          "Invalid startIndex of a query context. Actual:" + actual.toString)
+        assert(actual.stopIndex() === expected.stopIndex,
+          "Invalid stopIndex of a query context. Actual:" + actual.toString)
+        assert(actual.fragment() === expected.fragment,
+          "Invalid fragment of a query context. Actual:" + actual.toString)
+      } else if (actual.contextType() == QueryContextType.Dataset) {
+        assert(actual.code() === expected.code,
+          "Invalid code of a query context. Actual:" + actual.toString)
+        assert(actual.callSite().matches(expected.callSitePattern),
+          "Invalid callSite of a query context. Actual:" + actual.toString)
+      }
     }
   }
 
@@ -365,29 +372,29 @@ abstract class SparkFunSuite
       errorClass: String,
       sqlState: String,
       parameters: Map[String, String],
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context))
 
   protected def checkError(
       exception: SparkThrowable,
       errorClass: String,
       parameters: Map[String, String],
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, None, parameters, false, Array(context))
 
   protected def checkError(
       exception: SparkThrowable,
       errorClass: String,
       sqlState: String,
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, None, Map.empty, false, Array(context))
 
   protected def checkError(
       exception: SparkThrowable,
       errorClass: String,
       sqlState: Option[String],
       parameters: Map[String, String],
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, sqlState, parameters,
       false, Array(context))
 
@@ -402,7 +409,7 @@ abstract class SparkFunSuite
       errorClass: String,
       sqlState: Option[String],
       parameters: Map[String, String],
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, sqlState, parameters,
       matchPVals = true, Array(context))
 
@@ -433,12 +440,28 @@ abstract class SparkFunSuite
       objectName: String,
       startIndex: Int,
       stopIndex: Int,
-      fragment: String) extends QueryContext
+      fragment: String,
+      code: String,
+      callSitePattern: String
+  )
 
   object ExpectedContext {
     def apply(fragment: String, start: Int, stop: Int): ExpectedContext = {
       ExpectedContext("", "", start, stop, fragment)
     }
+
+    def apply(
+        objectType: String,
+        objectName: String,
+        startIndex: Int,
+        stopIndex: Int,
+        fragment: String): ExpectedContext = {
+      new ExpectedContext(objectType, objectName, startIndex, stopIndex, fragment, "", "")
+    }
+
+    def apply(code: String, callSitePattern: String): ExpectedContext = {
+      new ExpectedContext("", "", -1, -1, "", code, callSitePattern)
+    }
   }
 
 class LogAppender(msg: String = "", maxEvents: Int = 1000)
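
With the new overloads, a suite can assert on either context flavor. A hedged example of both styles (the intercepted queries, the error parameters, and the call-site pattern are hypothetical; it assumes a SparkSession `spark`, a DataFrame `df`, and ANSI mode enabled):

// SQL context: assert on the offending fragment and its character range.
checkError(
  exception = intercept[SparkArithmeticException] {
    spark.sql("select 1/0").collect()
  },
  errorClass = "DIVIDE_BY_ZERO",
  parameters = Map("config" -> "\"spark.sql.ansi.enabled\""),
  context = ExpectedContext(fragment = "1/0", start = 7, stop = 9))

// Dataset context: assert on the failing API and a call-site regex, since
// hard-coding the test file's line numbers would make the assertion brittle.
checkError(
  exception = intercept[SparkArithmeticException] {
    df.select(col("a") / lit(0)).collect()
  },
  errorClass = "DIVIDE_BY_ZERO",
  parameters = Map("config" -> "\"spark.sql.ansi.enabled\""),
  context = ExpectedContext(code = "div", callSitePattern = "MySuite\\.scala:\\d+"))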

core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala

Lines changed: 53 additions & 0 deletions
@@ -460,11 +460,15 @@ class SparkThrowableSuite extends SparkFunSuite {
   test("Get message in the specified format") {
     import ErrorMessageFormat._
     class TestQueryContext extends QueryContext {
+      override val contextType = QueryContextType.SQL
       override val objectName = "v1"
       override val objectType = "VIEW"
       override val startIndex = 2
       override val stopIndex = -1
       override val fragment = "1 / 0"
+      override def code: String = throw new UnsupportedOperationException
+      override def callSite: String = throw new UnsupportedOperationException
+      override val summary = ""
     }
     val e = new SparkArithmeticException(
       errorClass = "DIVIDE_BY_ZERO",
@@ -532,6 +536,55 @@ class SparkThrowableSuite extends SparkFunSuite {
         |    "message" : "Test message"
         |  }
         |}""".stripMargin)
+
+    class TestQueryContext2 extends QueryContext {
+      override val contextType = QueryContextType.Dataset
+      override def objectName: String = throw new UnsupportedOperationException
+      override def objectType: String = throw new UnsupportedOperationException
+      override def startIndex: Int = throw new UnsupportedOperationException
+      override def stopIndex: Int = throw new UnsupportedOperationException
+      override def fragment: String = throw new UnsupportedOperationException
+      override val code: String = "div"
+      override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)"
+      override val summary = ""
+    }
+    val e4 = new SparkArithmeticException(
+      errorClass = "DIVIDE_BY_ZERO",
+      messageParameters = Map("config" -> "CONFIG"),
+      context = Array(new TestQueryContext2),
+      summary = "Query summary")
+
+    assert(SparkThrowableHelper.getMessage(e4, PRETTY) ===
+      "[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " +
+      "and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." +
+      "\nQuery summary")
+    // scalastyle:off line.size.limit
+    assert(SparkThrowableHelper.getMessage(e4, MINIMAL) ===
+      """{
+        |  "errorClass" : "DIVIDE_BY_ZERO",
+        |  "sqlState" : "22012",
+        |  "messageParameters" : {
+        |    "config" : "CONFIG"
+        |  },
+        |  "queryContext" : [ {
+        |    "code" : "div",
+        |    "callSite" : "SimpleApp$.main(SimpleApp.scala:9)"
+        |  } ]
+        |}""".stripMargin)
+    assert(SparkThrowableHelper.getMessage(e4, STANDARD) ===
+      """{
+        |  "errorClass" : "DIVIDE_BY_ZERO",
+        |  "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set <config> to \"false\" to bypass this error.",
+        |  "sqlState" : "22012",
+        |  "messageParameters" : {
+        |    "config" : "CONFIG"
+        |  },
+        |  "queryContext" : [ {
+        |    "code" : "div",
+        |    "callSite" : "SimpleApp$.main(SimpleApp.scala:9)"
+        |  } ]
+        |}""".stripMargin)
+    // scalastyle:on line.size.limit
   }
 
   test("overwrite error classes") {

sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala

Lines changed: 3 additions & 4 deletions
@@ -26,7 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl
 import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin}
 import org.apache.spark.sql.catalyst.util.SparkParserUtils
 import org.apache.spark.sql.errors.QueryParsingErrors
 import org.apache.spark.sql.internal.SqlApiConf
@@ -229,7 +229,7 @@ class ParseException(
     val builder = new StringBuilder
     builder ++= "\n" ++= message
     start match {
-      case Origin(Some(l), Some(p), _, _, _, _, _) =>
+      case Origin(Some(l), Some(p), _, _, _, _, _, _) =>
         builder ++= s"(line $l, pos $p)\n"
         command.foreach { cmd =>
           val (above, below) = cmd.split("\n").splitAt(l)
@@ -262,8 +262,7 @@ class ParseException(
 
 object ParseException {
   def getQueryContext(): Array[QueryContext] = {
-    val context = CurrentOrigin.get.context
-    if (context.isValid) Array(context) else Array.empty
+    Some(CurrentOrigin.get.context).collect { case b: SQLQueryContext if b.isValid => b }.toArray
  }
 }
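The rewritten getQueryContext() is now type-selective as well as validity-selective: since an Origin can carry either context flavor, a parse error only ever reports a valid SQLQueryContext. A small sketch of that filtering (the values are made up; DatasetQueryContext is the class added in QueryContexts.scala below):

import org.apache.spark.QueryContext
import org.apache.spark.sql.catalyst.trees.{DatasetQueryContext, SQLQueryContext}

val ds: QueryContext = DatasetQueryContext("div", "SimpleApp$.main(SimpleApp.scala:9)")
// The collect keeps the value only when it is a *valid* SQL context, so a
// Dataset context attached to the current origin yields no parse context.
val kept = Some(ds).collect { case b: SQLQueryContext if b.isValid => b }.toArray
assert(kept.isEmpty)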

sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala renamed to sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala

Lines changed: 45 additions & 7 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.trees
 
-import org.apache.spark.QueryContext
+import org.apache.spark.{QueryContext, QueryContextType}
 
 /** The class represents error context of a SQL query. */
 case class SQLQueryContext(
@@ -28,19 +28,20 @@ case class SQLQueryContext(
     sqlText: Option[String],
     originObjectType: Option[String],
     originObjectName: Option[String]) extends QueryContext {
+  override val contextType = QueryContextType.SQL
 
-  override val objectType = originObjectType.getOrElse("")
-  override val objectName = originObjectName.getOrElse("")
-  override val startIndex = originStartIndex.getOrElse(-1)
-  override val stopIndex = originStopIndex.getOrElse(-1)
+  val objectType = originObjectType.getOrElse("")
+  val objectName = originObjectName.getOrElse("")
+  val startIndex = originStartIndex.getOrElse(-1)
+  val stopIndex = originStopIndex.getOrElse(-1)
 
   /**
    * The SQL query context of current node. For example:
    * == SQL of VIEW v1(line 1, position 25) ==
    * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
    *                          ^^^^^^^^^^^^^^^
    */
-  lazy val summary: String = {
+  override lazy val summary: String = {
     // If the query context is missing or incorrect, simply return an empty string.
     if (!isValid) {
       ""
@@ -116,7 +117,7 @@ case class SQLQueryContext(
   }
 
   /** Gets the textual fragment of a SQL query. */
-  override lazy val fragment: String = {
+  lazy val fragment: String = {
     if (!isValid) {
       ""
     } else {
@@ -128,6 +129,43 @@ case class SQLQueryContext(
     sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined &&
       originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length &&
       originStartIndex.get <= originStopIndex.get
+  }
+
+  override def code: String = throw new UnsupportedOperationException
+  override def callSite: String = throw new UnsupportedOperationException
+}
+
+case class DatasetQueryContext(
+    override val code: String,
+    override val callSite: String) extends QueryContext {
+  override val contextType = QueryContextType.Dataset
+
+  override def objectType: String = throw new UnsupportedOperationException
+  override def objectName: String = throw new UnsupportedOperationException
+  override def startIndex: Int = throw new UnsupportedOperationException
+  override def stopIndex: Int = throw new UnsupportedOperationException
+  override def fragment: String = throw new UnsupportedOperationException
+
+  override lazy val summary: String = {
+    val builder = new StringBuilder
+    builder ++= "== Dataset ==\n"
+    builder ++= "\""
+
+    builder ++= code
+    builder ++= "\""
+    builder ++= " was called from "
+    builder ++= callSite
+    builder += '\n'
+    builder.result()
+  }
+}
+
+object DatasetQueryContext {
+  def apply(elements: Array[StackTraceElement]): DatasetQueryContext = {
+    val methodName = elements(0).getMethodName
+    val code = (if (methodName.startsWith("$")) methodName.substring(1) else methodName)
+    val callSite = elements(1).toString
 
+    DatasetQueryContext(code, callSite)
   }
 }
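
DatasetQueryContext.apply expects a pre-trimmed stack slice: elements(0) is the Dataset API method (with a leading "$" from Scala name mangling stripped) and elements(1) is the user code that called it. How a caller obtains that slice is its own concern; a hedged sketch of the capture (currentDatasetContext is a hypothetical helper, and the frame offsets are an assumption for illustration, not this commit's exact frame arithmetic):

import org.apache.spark.sql.catalyst.trees.DatasetQueryContext

// Drop the getStackTrace and currentDatasetContext frames so that frames(0)
// is the Dataset API method and frames(1) is the user call site.
def currentDatasetContext(): DatasetQueryContext = {
  val frames = Thread.currentThread().getStackTrace.drop(2)
  DatasetQueryContext(frames)
}

For the sample values used in the tests above, the resulting summary renders as:
== Dataset ==
"div" was called from SimpleApp$.main(SimpleApp.scala:9)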
