Commit 7adf30e

[SPARK-45022][SQL] Provide context for dataset API errors
1 parent 8e3e600 commit 7adf30e

54 files changed: +748, -296 lines

common/utils/src/main/java/org/apache/spark/QueryContext.java

Lines changed: 12 additions & 0 deletions
@@ -27,6 +27,9 @@
  */
 @Evolving
 public interface QueryContext {
+  // The type of this query context.
+  QueryContextType contextType();
+
   // The object type of the query which throws the exception.
   // If the exception is directly from the main query, it should be an empty string.
   // Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
@@ -45,4 +48,13 @@ public interface QueryContext {
 
   // The corresponding fragment of the query which throws the exception.
   String fragment();
+
+  // The Spark code (API) that caused throwing the exception.
+  String code();
+
+  // The user code (call site of the API) that caused throwing the exception.
+  String callSite();
+
+  // Summary of the exception cause.
+  String summary();
 }
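
The interface now covers both flavors, and an implementation is expected to support only the accessors matching its contextType(); the Scala implementations further down throw UnsupportedOperationException for the rest. A minimal consumer-side sketch in Scala (describe is a hypothetical helper, not part of this commit):

    import org.apache.spark.{QueryContext, QueryContextType}

    // Hypothetical helper: read only the accessors that are valid for the
    // context's type; the remaining ones may throw.
    def describe(ctx: QueryContext): String = ctx.contextType() match {
      case QueryContextType.SQL =>
        s"${ctx.objectType()} ${ctx.objectName()}: ${ctx.fragment()}"
      case QueryContextType.Dataset =>
        s"${ctx.code()} was called from ${ctx.callSite()}"
    }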
common/utils/src/main/java/org/apache/spark/QueryContextType.java

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * The type of {@link QueryContext}.
+ *
+ * @since 3.5.0
+ */
+@Evolving
+public enum QueryContextType {
+  SQL,
+  Dataset
+}

common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala

Lines changed: 13 additions & 7 deletions
@@ -104,13 +104,19 @@ private[spark] object SparkThrowableHelper {
       g.writeArrayFieldStart("queryContext")
       e.getQueryContext.foreach { c =>
         g.writeStartObject()
-        g.writeStringField("objectType", c.objectType())
-        g.writeStringField("objectName", c.objectName())
-        val startIndex = c.startIndex() + 1
-        if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
-        val stopIndex = c.stopIndex() + 1
-        if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
-        g.writeStringField("fragment", c.fragment())
+        c.contextType() match {
+          case QueryContextType.SQL =>
+            g.writeStringField("objectType", c.objectType())
+            g.writeStringField("objectName", c.objectName())
+            val startIndex = c.startIndex() + 1
+            if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
+            val stopIndex = c.stopIndex() + 1
+            if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
+            g.writeStringField("fragment", c.fragment())
+          case QueryContextType.Dataset =>
+            g.writeStringField("code", c.code())
+            g.writeStringField("callSite", c.callSite())
+        }
         g.writeEndObject()
       }
       g.writeEndArray()
core/src/test/scala/org/apache/spark/SparkFunSuite.scala

Lines changed: 44 additions & 17 deletions
@@ -342,7 +342,7 @@ abstract class SparkFunSuite
       sqlState: Option[String] = None,
       parameters: Map[String, String] = Map.empty,
       matchPVals: Boolean = false,
-      queryContext: Array[QueryContext] = Array.empty): Unit = {
+      queryContext: Array[ExpectedContext] = Array.empty): Unit = {
     assert(exception.getErrorClass === errorClass)
     sqlState.foreach(state => assert(exception.getSqlState === state))
     val expectedParameters = exception.getMessageParameters.asScala
@@ -364,16 +364,25 @@ abstract class SparkFunSuite
     val actualQueryContext = exception.getQueryContext()
     assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context")
     actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
-      assert(actual.objectType() === expected.objectType(),
-        "Invalid objectType of a query context Actual:" + actual.toString)
-      assert(actual.objectName() === expected.objectName(),
-        "Invalid objectName of a query context. Actual:" + actual.toString)
-      assert(actual.startIndex() === expected.startIndex(),
-        "Invalid startIndex of a query context. Actual:" + actual.toString)
-      assert(actual.stopIndex() === expected.stopIndex(),
-        "Invalid stopIndex of a query context. Actual:" + actual.toString)
-      assert(actual.fragment() === expected.fragment(),
-        "Invalid fragment of a query context. Actual:" + actual.toString)
+      assert(actual.contextType() === expected.contextType,
+        "Invalid contextType of a query context Actual:" + actual.toString)
+      if (actual.contextType() == QueryContextType.SQL) {
+        assert(actual.objectType() === expected.objectType,
+          "Invalid objectType of a query context Actual:" + actual.toString)
+        assert(actual.objectName() === expected.objectName,
+          "Invalid objectName of a query context. Actual:" + actual.toString)
+        assert(actual.startIndex() === expected.startIndex,
+          "Invalid startIndex of a query context. Actual:" + actual.toString)
+        assert(actual.stopIndex() === expected.stopIndex,
+          "Invalid stopIndex of a query context. Actual:" + actual.toString)
+        assert(actual.fragment() === expected.fragment,
+          "Invalid fragment of a query context. Actual:" + actual.toString)
+      } else if (actual.contextType() == QueryContextType.Dataset) {
+        assert(actual.code() === expected.code,
+          "Invalid code of a query context. Actual:" + actual.toString)
+        assert(actual.callSite().matches(expected.callSitePattern),
+          "Invalid callSite of a query context. Actual:" + actual.toString)
+      }
     }
   }
 
@@ -389,29 +398,29 @@ abstract class SparkFunSuite
       errorClass: String,
       sqlState: String,
       parameters: Map[String, String],
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context))
 
   protected def checkError(
       exception: SparkThrowable,
       errorClass: String,
       parameters: Map[String, String],
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, None, parameters, false, Array(context))
 
   protected def checkError(
       exception: SparkThrowable,
       errorClass: String,
       sqlState: String,
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, None, Map.empty, false, Array(context))
 
   protected def checkError(
       exception: SparkThrowable,
       errorClass: String,
       sqlState: Option[String],
       parameters: Map[String, String],
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, sqlState, parameters,
       false, Array(context))
 
@@ -426,7 +435,7 @@ abstract class SparkFunSuite
       errorClass: String,
       sqlState: Option[String],
       parameters: Map[String, String],
-      context: QueryContext): Unit =
+      context: ExpectedContext): Unit =
     checkError(exception, errorClass, sqlState, parameters,
       matchPVals = true, Array(context))
 
@@ -453,16 +462,34 @@ abstract class SparkFunSuite
       parameters = Map("relationName" -> tableName))
 
   case class ExpectedContext(
+      contextType: QueryContextType,
       objectType: String,
       objectName: String,
       startIndex: Int,
       stopIndex: Int,
-      fragment: String) extends QueryContext
+      fragment: String,
+      code: String,
+      callSitePattern: String
+  )
 
   object ExpectedContext {
     def apply(fragment: String, start: Int, stop: Int): ExpectedContext = {
       ExpectedContext("", "", start, stop, fragment)
     }
+
+    def apply(
+        objectType: String,
+        objectName: String,
+        startIndex: Int,
+        stopIndex: Int,
+        fragment: String): ExpectedContext = {
+      new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex,
+        fragment, "", "")
+    }
+
+    def apply(code: String, callSitePattern: String): ExpectedContext = {
+      new ExpectedContext(QueryContextType.Dataset, "", "", -1, -1, "", code, callSitePattern)
+    }
   }
 
 class LogAppender(msg: String = "", maxEvents: Int = 1000)
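
The ExpectedContext companion keeps the SQL-oriented constructors and adds a Dataset-flavored one taking the API name plus a regex for the call site. A hypothetical test usage (the DataFrame df, the expression, the config value, and the call-site pattern are all illustrative):

    // Illustrative only: assert that a Dataset API error reports its context.
    checkError(
      exception = intercept[SparkArithmeticException] {
        df.select(col("a") / 0).collect()
      },
      errorClass = "DIVIDE_BY_ZERO",
      parameters = Map("config" -> "\"spark.sql.ansi.enabled\""),
      context = ExpectedContext("div", callSitePattern = ".*MySuite.*"))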

core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala

Lines changed: 53 additions & 0 deletions
@@ -460,11 +460,15 @@ class SparkThrowableSuite extends SparkFunSuite {
   test("Get message in the specified format") {
     import ErrorMessageFormat._
     class TestQueryContext extends QueryContext {
+      override val contextType = QueryContextType.SQL
       override val objectName = "v1"
       override val objectType = "VIEW"
       override val startIndex = 2
       override val stopIndex = -1
       override val fragment = "1 / 0"
+      override def code: String = throw new UnsupportedOperationException
+      override def callSite: String = throw new UnsupportedOperationException
+      override val summary = ""
     }
     val e = new SparkArithmeticException(
       errorClass = "DIVIDE_BY_ZERO",
@@ -532,6 +536,55 @@ class SparkThrowableSuite extends SparkFunSuite {
        |    "message" : "Test message"
        |  }
        |}""".stripMargin)
+
+    class TestQueryContext2 extends QueryContext {
+      override val contextType = QueryContextType.Dataset
+      override def objectName: String = throw new UnsupportedOperationException
+      override def objectType: String = throw new UnsupportedOperationException
+      override def startIndex: Int = throw new UnsupportedOperationException
+      override def stopIndex: Int = throw new UnsupportedOperationException
+      override def fragment: String = throw new UnsupportedOperationException
+      override val code: String = "div"
+      override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)"
+      override val summary = ""
+    }
+    val e4 = new SparkArithmeticException(
+      errorClass = "DIVIDE_BY_ZERO",
+      messageParameters = Map("config" -> "CONFIG"),
+      context = Array(new TestQueryContext2),
+      summary = "Query summary")
+
+    assert(SparkThrowableHelper.getMessage(e4, PRETTY) ===
+      "[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " +
+      "and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." +
+      "\nQuery summary")
+    // scalastyle:off line.size.limit
+    assert(SparkThrowableHelper.getMessage(e4, MINIMAL) ===
+      """{
+        |  "errorClass" : "DIVIDE_BY_ZERO",
+        |  "sqlState" : "22012",
+        |  "messageParameters" : {
+        |    "config" : "CONFIG"
+        |  },
+        |  "queryContext" : [ {
+        |    "code" : "div",
+        |    "callSite" : "SimpleApp$.main(SimpleApp.scala:9)"
+        |  } ]
+        |}""".stripMargin)
+    assert(SparkThrowableHelper.getMessage(e4, STANDARD) ===
+      """{
+        |  "errorClass" : "DIVIDE_BY_ZERO",
+        |  "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set <config> to \"false\" to bypass this error.",
+        |  "sqlState" : "22012",
+        |  "messageParameters" : {
+        |    "config" : "CONFIG"
+        |  },
+        |  "queryContext" : [ {
+        |    "code" : "div",
+        |    "callSite" : "SimpleApp$.main(SimpleApp.scala:9)"
+        |  } ]
+        |}""".stripMargin)
+    // scalastyle:on line.size.limit
   }
 
   test("overwrite error classes") {

sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala

Lines changed: 3 additions & 4 deletions
@@ -26,7 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl
 import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin}
 import org.apache.spark.sql.catalyst.util.SparkParserUtils
 import org.apache.spark.sql.errors.QueryParsingErrors
 import org.apache.spark.sql.internal.SqlApiConf
@@ -229,7 +229,7 @@ class ParseException(
     val builder = new StringBuilder
     builder ++= "\n" ++= message
     start match {
-      case Origin(Some(l), Some(p), _, _, _, _, _) =>
+      case Origin(Some(l), Some(p), _, _, _, _, _, _) =>
         builder ++= s"(line $l, pos $p)\n"
         command.foreach { cmd =>
           val (above, below) = cmd.split("\n").splitAt(l)
@@ -262,8 +262,7 @@ class ParseException(
 
 object ParseException {
   def getQueryContext(): Array[QueryContext] = {
-    val context = CurrentOrigin.get.context
-    if (context.isValid) Array(context) else Array.empty
+    Some(CurrentOrigin.get.context).collect { case b: SQLQueryContext if b.isValid => b }.toArray
   }
 }
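
getQueryContext now filters by runtime type: CurrentOrigin.get.context may be either flavor, and only a valid SQL context is meaningful for a parse error. The new one-liner desugars to roughly:

    // Long-hand equivalent of the collect, for clarity (illustrative):
    CurrentOrigin.get.context match {
      case c: SQLQueryContext if c.isValid => Array[QueryContext](c)
      case _ => Array.empty[QueryContext] // Dataset or invalid contexts are dropped
    }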

sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala renamed to sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala

Lines changed: 49 additions & 7 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.trees
 
-import org.apache.spark.QueryContext
+import org.apache.spark.{QueryContext, QueryContextType}
 
 /** The class represents error context of a SQL query. */
 case class SQLQueryContext(
@@ -28,19 +28,20 @@ case class SQLQueryContext(
     sqlText: Option[String],
     originObjectType: Option[String],
     originObjectName: Option[String]) extends QueryContext {
+  override val contextType = QueryContextType.SQL
 
-  override val objectType = originObjectType.getOrElse("")
-  override val objectName = originObjectName.getOrElse("")
-  override val startIndex = originStartIndex.getOrElse(-1)
-  override val stopIndex = originStopIndex.getOrElse(-1)
+  val objectType = originObjectType.getOrElse("")
+  val objectName = originObjectName.getOrElse("")
+  val startIndex = originStartIndex.getOrElse(-1)
+  val stopIndex = originStopIndex.getOrElse(-1)
 
   /**
    * The SQL query context of current node. For example:
    * == SQL of VIEW v1(line 1, position 25) ==
    * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
    *                          ^^^^^^^^^^^^^^^
    */
-  lazy val summary: String = {
+  override lazy val summary: String = {
     // If the query context is missing or incorrect, simply return an empty string.
     if (!isValid) {
       ""
@@ -116,7 +117,7 @@ case class SQLQueryContext(
   }
 
   /** Gets the textual fragment of a SQL query. */
-  override lazy val fragment: String = {
+  lazy val fragment: String = {
     if (!isValid) {
       ""
     } else {
@@ -128,6 +129,47 @@ case class SQLQueryContext(
     sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined &&
       originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length &&
       originStartIndex.get <= originStopIndex.get
+  }
+
+  override def code: String = throw new UnsupportedOperationException
+  override def callSite: String = throw new UnsupportedOperationException
+}
+
+case class DatasetQueryContext(
+    override val code: String,
+    override val callSite: String) extends QueryContext {
+  override val contextType = QueryContextType.Dataset
+
+  override def objectType: String = throw new UnsupportedOperationException
+  override def objectName: String = throw new UnsupportedOperationException
+  override def startIndex: Int = throw new UnsupportedOperationException
+  override def stopIndex: Int = throw new UnsupportedOperationException
+  override def fragment: String = throw new UnsupportedOperationException
+
+  override lazy val summary: String = {
+    val builder = new StringBuilder
+    builder ++= "== Dataset ==\n"
+    builder ++= "\""
+
+    builder ++= code
+    builder ++= "\""
+    builder ++= " was called from "
+    builder ++= callSite
+    builder += '\n'
+    builder.result()
+  }
+}
+
+object DatasetQueryContext {
+  def apply(elements: Array[StackTraceElement]): DatasetQueryContext = {
+    val methodName = elements(0).getMethodName
+    val code = if (methodName.length > 1 && methodName(0) == '$') {
+      methodName.substring(1)
+    } else {
+      methodName
+    }
+    val callSite = elements(1).toString
 
+    DatasetQueryContext(code, callSite)
   }
 }
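
The companion apply derives code from the first stack frame, stripping a single leading '$' that Scala name mangling can add, and uses the second frame's toString as the call site. An illustrative invocation with hand-built frames (class, file, and line values are made up):

    import org.apache.spark.sql.catalyst.trees.DatasetQueryContext

    val ctx = DatasetQueryContext(Array(
      new StackTraceElement("org.apache.spark.sql.Dataset", "$div", "Dataset.scala", 100),
      new StackTraceElement("SimpleApp$", "main", "SimpleApp.scala", 9)))

    assert(ctx.code == "div") // leading '$' stripped
    assert(ctx.callSite == "SimpleApp$.main(SimpleApp.scala:9)")
    // ctx.summary renders as:
    //   == Dataset ==
    //   "div" was called from SimpleApp$.main(SimpleApp.scala:9)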
