diff --git a/common/utils/src/main/java/org/apache/spark/QueryContext.java b/common/utils/src/main/java/org/apache/spark/QueryContext.java index de5b29d02951..de79c80ffb83 100644 --- a/common/utils/src/main/java/org/apache/spark/QueryContext.java +++ b/common/utils/src/main/java/org/apache/spark/QueryContext.java @@ -27,6 +27,9 @@ */ @Evolving public interface QueryContext { + // The type of this query context. + QueryContextType contextType(); + // The object type of the query which throws the exception. // If the exception is directly from the main query, it should be an empty string. // Otherwise, it should be the exact object type in upper case. For example, a "VIEW". @@ -45,4 +48,13 @@ public interface QueryContext { // The corresponding fragment of the query which throws the exception. String fragment(); + + // The Spark code (API) that caused throwing the exception. + String code(); + + // The user code (call site of the API) that caused throwing the exception. + String callSite(); + + // Summary of the exception cause. + String summary(); } diff --git a/common/utils/src/main/java/org/apache/spark/QueryContextType.java b/common/utils/src/main/java/org/apache/spark/QueryContextType.java new file mode 100644 index 000000000000..d7a28e63b79b --- /dev/null +++ b/common/utils/src/main/java/org/apache/spark/QueryContextType.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.annotation.Evolving; + +/** + * The type of {@link QueryContext}. 
+ * + * @since 3.5.0 + */ +@Evolving +public enum QueryContextType { + SQL, + Dataset +} diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 0f329b5655b3..cb508be6db47 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -104,13 +104,19 @@ private[spark] object SparkThrowableHelper { g.writeArrayFieldStart("queryContext") e.getQueryContext.foreach { c => g.writeStartObject() - g.writeStringField("objectType", c.objectType()) - g.writeStringField("objectName", c.objectName()) - val startIndex = c.startIndex() + 1 - if (startIndex > 0) g.writeNumberField("startIndex", startIndex) - val stopIndex = c.stopIndex() + 1 - if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex) - g.writeStringField("fragment", c.fragment()) + c.contextType() match { + case QueryContextType.SQL => + g.writeStringField("objectType", c.objectType()) + g.writeStringField("objectName", c.objectName()) + val startIndex = c.startIndex() + 1 + if (startIndex > 0) g.writeNumberField("startIndex", startIndex) + val stopIndex = c.stopIndex() + 1 + if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex) + g.writeStringField("fragment", c.fragment()) + case QueryContextType.Dataset => + g.writeStringField("code", c.code()) + g.writeStringField("callSite", c.callSite()) + } g.writeEndObject() } g.writeEndArray() diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 1163088c82aa..b3af9d85ce5a 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -342,7 +342,7 @@ abstract class SparkFunSuite sqlState: Option[String] = None, parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, - queryContext: Array[QueryContext] = Array.empty): Unit = { + queryContext: Array[ExpectedContext] = Array.empty): Unit = { assert(exception.getErrorClass === errorClass) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala @@ -364,16 +364,25 @@ abstract class SparkFunSuite val actualQueryContext = exception.getQueryContext() assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context") actualQueryContext.zip(queryContext).foreach { case (actual, expected) => - assert(actual.objectType() === expected.objectType(), - "Invalid objectType of a query context Actual:" + actual.toString) - assert(actual.objectName() === expected.objectName(), - "Invalid objectName of a query context. Actual:" + actual.toString) - assert(actual.startIndex() === expected.startIndex(), - "Invalid startIndex of a query context. Actual:" + actual.toString) - assert(actual.stopIndex() === expected.stopIndex(), - "Invalid stopIndex of a query context. Actual:" + actual.toString) - assert(actual.fragment() === expected.fragment(), - "Invalid fragment of a query context. 
Actual:" + actual.toString) + assert(actual.contextType() === expected.contextType, + "Invalid contextType of a query context Actual:" + actual.toString) + if (actual.contextType() == QueryContextType.SQL) { + assert(actual.objectType() === expected.objectType, + "Invalid objectType of a query context Actual:" + actual.toString) + assert(actual.objectName() === expected.objectName, + "Invalid objectName of a query context. Actual:" + actual.toString) + assert(actual.startIndex() === expected.startIndex, + "Invalid startIndex of a query context. Actual:" + actual.toString) + assert(actual.stopIndex() === expected.stopIndex, + "Invalid stopIndex of a query context. Actual:" + actual.toString) + assert(actual.fragment() === expected.fragment, + "Invalid fragment of a query context. Actual:" + actual.toString) + } else if (actual.contextType() == QueryContextType.Dataset) { + assert(actual.code() === expected.code, + "Invalid code of a query context. Actual:" + actual.toString) + assert(actual.callSite().matches(expected.callSitePattern), + "Invalid callSite of a query context. Actual:" + actual.toString) + } } } @@ -389,21 +398,21 @@ abstract class SparkFunSuite errorClass: String, sqlState: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, sqlState: String, - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, Map.empty, false, Array(context)) protected def checkError( @@ -411,7 +420,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, false, Array(context)) @@ -426,7 +435,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, matchPVals = true, Array(context)) @@ -453,16 +462,34 @@ abstract class SparkFunSuite parameters = Map("relationName" -> tableName)) case class ExpectedContext( + contextType: QueryContextType, objectType: String, objectName: String, startIndex: Int, stopIndex: Int, - fragment: String) extends QueryContext + fragment: String, + code: String, + callSitePattern: String + ) object ExpectedContext { def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { ExpectedContext("", "", start, stop, fragment) } + + def apply( + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String): ExpectedContext = { + new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex, + fragment, "", "") + } + + def apply(code: String, callSitePattern: String): ExpectedContext = { + new ExpectedContext(QueryContextType.Dataset, "", "", -1, -1, "", code, callSitePattern) + } } class LogAppender(msg: String = "", maxEvents: Int = 1000) diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala 
b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 5c9009bf8fa4..bc8098c985ac 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -491,11 +491,15 @@ class SparkThrowableSuite extends SparkFunSuite { test("Get message in the specified format") { import ErrorMessageFormat._ class TestQueryContext extends QueryContext { + override val contextType = QueryContextType.SQL override val objectName = "v1" override val objectType = "VIEW" override val startIndex = 2 override val stopIndex = -1 override val fragment = "1 / 0" + override def code: String = throw new UnsupportedOperationException + override def callSite: String = throw new UnsupportedOperationException + override val summary = "" } val e = new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", @@ -563,6 +567,55 @@ class SparkThrowableSuite extends SparkFunSuite { | "message" : "Test message" | } |}""".stripMargin) + + class TestQueryContext2 extends QueryContext { + override val contextType = QueryContextType.Dataset + override def objectName: String = throw new UnsupportedOperationException + override def objectType: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + override def fragment: String = throw new UnsupportedOperationException + override val code: String = "div" + override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)" + override val summary = "" + } + val e4 = new SparkArithmeticException( + errorClass = "DIVIDE_BY_ZERO", + messageParameters = Map("config" -> "CONFIG"), + context = Array(new TestQueryContext2), + summary = "Query summary") + + assert(SparkThrowableHelper.getMessage(e4, PRETTY) === + "[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " + + "and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." + + "\nQuery summary") + // scalastyle:off line.size.limit + assert(SparkThrowableHelper.getMessage(e4, MINIMAL) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "code" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + assert(SparkThrowableHelper.getMessage(e4, STANDARD) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. 
If necessary set to \"false\" to bypass this error.", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "code" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + // scalastyle:on line.size.limit } test("overwrite error classes") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 52440ca7d17b..007c2547cabf 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -41,7 +41,14 @@ object MimaExcludes { // [SPARK-44705][PYTHON] Make PythonRunner single-threaded ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.python.BasePythonRunner#ReaderIterator.this"), // [SPARK-44198][CORE] Support propagation of the log level to the executors - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$") + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$"), + + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.contextType"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.code"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.callSite"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.summary"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI") ) // Default exclude rules diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index c3a051be89bc..dca111e55c28 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ -26,7 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin} import org.apache.spark.sql.catalyst.util.SparkParserUtils import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf @@ -229,7 +229,7 @@ class ParseException( val builder = new StringBuilder builder ++= "\n" ++= message start match { - case Origin(Some(l), Some(p), _, _, _, _, _) => + case Origin(Some(l), Some(p), _, _, _, _, _, _) => builder ++= s"(line $l, pos $p)\n" command.foreach { cmd => val (above, below) = cmd.split("\n").splitAt(l) @@ -262,8 +262,7 @@ class ParseException( object ParseException { def getQueryContext(): Array[QueryContext] = { - val context = CurrentOrigin.get.context - if (context.isValid) Array(context) else Array.empty + Some(CurrentOrigin.get.context).collect { case b: SQLQueryContext if b.isValid => b }.toArray } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala similarity index 73% rename from 
sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 99889cf7dae9..4d12341fd32b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.trees -import org.apache.spark.QueryContext +import org.apache.spark.{QueryContext, QueryContextType} /** The class represents error context of a SQL query. */ case class SQLQueryContext( @@ -28,11 +28,12 @@ case class SQLQueryContext( sqlText: Option[String], originObjectType: Option[String], originObjectName: Option[String]) extends QueryContext { + override val contextType = QueryContextType.SQL - override val objectType = originObjectType.getOrElse("") - override val objectName = originObjectName.getOrElse("") - override val startIndex = originStartIndex.getOrElse(-1) - override val stopIndex = originStopIndex.getOrElse(-1) + val objectType = originObjectType.getOrElse("") + val objectName = originObjectName.getOrElse("") + val startIndex = originStartIndex.getOrElse(-1) + val stopIndex = originStopIndex.getOrElse(-1) /** * The SQL query context of current node. For example: @@ -40,7 +41,7 @@ case class SQLQueryContext( * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i * ^^^^^^^^^^^^^^^ */ - lazy val summary: String = { + override lazy val summary: String = { // If the query context is missing or incorrect, simply return an empty string. if (!isValid) { "" @@ -116,7 +117,7 @@ case class SQLQueryContext( } /** Gets the textual fragment of a SQL query. */ - override lazy val fragment: String = { + lazy val fragment: String = { if (!isValid) { "" } else { @@ -128,6 +129,47 @@ case class SQLQueryContext( sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined && originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length && originStartIndex.get <= originStopIndex.get + } + + override def code: String = throw new UnsupportedOperationException + override def callSite: String = throw new UnsupportedOperationException +} + +case class DatasetQueryContext( + override val code: String, + override val callSite: String) extends QueryContext { + override val contextType = QueryContextType.Dataset + + override def objectType: String = throw new UnsupportedOperationException + override def objectName: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + override def fragment: String = throw new UnsupportedOperationException + + override lazy val summary: String = { + val builder = new StringBuilder + builder ++= "== Dataset ==\n" + builder ++= "\"" + + builder ++= code + builder ++= "\"" + builder ++= " was called from " + builder ++= callSite + builder += '\n' + builder.result() + } +} + +object DatasetQueryContext { + def apply(elements: Array[StackTraceElement]): DatasetQueryContext = { + val methodName = elements(0).getMethodName + val code = if (methodName.length > 1 && methodName(0) == '$') { + methodName.substring(1) + } else { + methodName + } + val callSite = elements(1).toString + DatasetQueryContext(code, callSite) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala 
index ec3e627ac958..7b1f49ded0dd 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -30,15 +30,21 @@ case class Origin( stopIndex: Option[Int] = None, sqlText: Option[String] = None, objectType: Option[String] = None, - objectName: Option[String] = None) { + objectName: Option[String] = None, + stackTrace: Option[Array[StackTraceElement]] = None) { - lazy val context: SQLQueryContext = SQLQueryContext( - line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) - - def getQueryContext: Array[QueryContext] = if (context.isValid) { - Array(context) + lazy val context: QueryContext = if (stackTrace.isDefined) { + DatasetQueryContext(stackTrace.get) } else { - Array.empty + SQLQueryContext( + line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) + } + + def getQueryContext: Array[QueryContext] = { + Some(context).filter { + case s: SQLQueryContext => s.isValid + case _ => true + }.toArray } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 7c1b37e9e581..99caef978bb4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.ExecutionErrors /** @@ -27,37 +27,37 @@ object MathUtils { def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b)) - def addExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def addExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b)) - def addExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def addExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def subtractExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def subtractExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def multiplyExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def multiplyExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } @@ -78,7 +78,7 @@ object MathUtils { def withOverflow[A]( f: => A, hint: String = "", - context: SQLQueryContext = null): A = { + context: QueryContext = null): 
A = { try { f } catch { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala index 698e7b37a9ef..f8a9274a5646 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import sun.util.calendar.ZoneInfo -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, rebaseJulianToGregorianDays, rebaseJulianToGregorianMicros} import org.apache.spark.sql.errors.ExecutionErrors @@ -355,7 +355,7 @@ trait SparkDateTimeUtils { def stringToDateAnsi( s: UTF8String, - context: SQLQueryContext = null): Int = { + context: QueryContext = null): Int = { stringToDate(s).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, DateType, context) } @@ -567,7 +567,7 @@ trait SparkDateTimeUtils { def stringToTimestampAnsi( s: UTF8String, timeZoneId: ZoneId, - context: SQLQueryContext = null): Long = { + context: QueryContext = null): Long = { stringToTimestamp(s, timeZoneId).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampType, context) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index 5e52e283338d..b30f7b7a00e9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.sql.errors -import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLSchema import org.apache.spark.sql.types.{DataType, Decimal, StringType} @@ -191,7 +191,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -199,7 +199,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -207,7 +207,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): 
ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -222,7 +222,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { val convertedValueStr = "'" + s.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala index aed3c681365d..7e039cec980c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.errors import java.util.Locale import org.apache.spark.QueryContext -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils} import org.apache.spark.sql.types.{AbstractDataType, DataType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -89,11 +88,11 @@ private[sql] trait DataTypeErrorsBase { "\"" + elem + "\"" } - def getSummary(sqlContext: SQLQueryContext): String = { - if (sqlContext == null) "" else sqlContext.summary + def getSummary(context: QueryContext): String = { + if (context == null) "" else context.summary } - def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = { - if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext]) + def getQueryContext(context: QueryContext): Array[QueryContext] = { + if (context == null) Array.empty else Array(context) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index c8321e81027b..394e56062071 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -21,9 +21,8 @@ import java.time.temporal.ChronoField import org.apache.arrow.vector.types.pojo.ArrowType -import org.apache.spark.{SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.sql.catalyst.WalkedTypePath -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{DataType, DoubleType, StringType, UserDefinedType} import org.apache.spark.unsafe.types.UTF8String @@ -83,14 +82,14 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def invalidInputInCastToDatetimeError( value: UTF8String, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), StringType, to, context) } def invalidInputInCastToDatetimeError( value: Double, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): 
SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), DoubleType, to, context) } @@ -98,7 +97,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { sqlValue: String, from: DataType, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { new SparkDateTimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -113,7 +112,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def arithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." } else "" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index afe73635a682..c1661038025c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -21,8 +21,8 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin import scala.util.Try +import org.apache.spark.QueryContext import org.apache.spark.annotation.Unstable -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.unsafe.types.UTF8String @@ -341,7 +341,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { scale: Int, roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, nullOnOverflow: Boolean = true, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { val copy = clone() if (copy.changePrecision(precision, scale, roundMode)) { copy @@ -617,7 +617,7 @@ object Decimal { def fromStringANSI( str: UTF8String, to: DecimalType = DecimalType.USER_DEFAULT, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { try { val bigDecimal = stringToJavaBigDecimal(str) // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. 
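The interface and serialization changes above let error consumers branch on the context flavor instead of assuming a SQL fragment is always present. A minimal consumer-side sketch, assuming a Spark build that contains this patch; `describeContexts` is an illustrative helper, not Spark API:

```scala
import org.apache.spark.{QueryContextType, SparkThrowable}

object QueryContextSketch {
  // Mirrors the branching in SparkThrowableHelper above: SQL contexts expose
  // object/fragment/offsets, Dataset contexts expose the API name and call site.
  def describeContexts(e: SparkThrowable): Seq[String] =
    e.getQueryContext.toSeq.map { ctx =>
      ctx.contextType() match {
        case QueryContextType.SQL =>
          s"SQL fragment '${ctx.fragment()}' at [${ctx.startIndex()}, ${ctx.stopIndex()}]"
        case QueryContextType.Dataset =>
          s"API '${ctx.code()}' called from ${ctx.callSite()}"
      }
    }
}
```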
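`Origin` (above) now picks the flavor itself: when a `stackTrace` is attached it yields a `DatasetQueryContext`, otherwise a `SQLQueryContext` that is only surfaced when valid. A rough usage sketch against the catalyst-internal `Origin`, again assuming this patch; the query text and stack frames are made up:

```scala
import org.apache.spark.sql.catalyst.trees.Origin

object OriginContextSketch extends App {
  // SQL flavor: built from query text and a character range; only surfaced when valid.
  val sqlOrigin = Origin(
    sqlText = Some("SELECT 1 / 0"), startIndex = Some(7), stopIndex = Some(11))
  sqlOrigin.getQueryContext.foreach(c => println(c.fragment()))  // prints: 1 / 0

  // Dataset flavor: built from a captured stack trace (frame 0 = API, frame 1 = call site).
  val frames = Array(
    new StackTraceElement("org.apache.spark.sql.Dataset", "div", "Dataset.scala", 123),
    new StackTraceElement("SimpleApp$", "main", "SimpleApp.scala", 9))
  val dsOrigin = Origin(stackTrace = Some(frames))
  dsOrigin.getQueryContext.foreach(c => println(c.summary()))
  // == Dataset ==
  // "div" was called from SimpleApp$.main(SimpleApp.scala:9)
}
```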
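`DatasetQueryContext.apply` above derives both fields from the captured stack: frame 0 is taken as the Dataset API method (with a leading `$` stripped) and frame 1 as the user call site. A self-contained sketch of that logic, using local stand-in names rather than the Spark classes:

```scala
object CallSiteSketch {
  final case class DatasetContext(code: String, callSite: String)

  def fromStackTrace(elements: Array[StackTraceElement]): DatasetContext = {
    // Frame 0: the Dataset API method; a leading '$' (synthetic name) is stripped.
    val methodName = elements(0).getMethodName
    val code =
      if (methodName.length > 1 && methodName(0) == '$') methodName.substring(1)
      else methodName
    // Frame 1: the caller of that API, i.e. the user code to report.
    DatasetContext(code, callSite = elements(1).toString)
  }

  def main(args: Array[String]): Unit = {
    // Simulated stack: frame 0 = Spark API, frame 1 = user call site.
    val frames = Array(
      new StackTraceElement("org.apache.spark.sql.Dataset", "$div", "Dataset.scala", 123),
      new StackTraceElement("SimpleApp$", "main", "SimpleApp.scala", 9))
    val ctx = fromStackTrace(frames)
    println("\"" + ctx.code + "\" was called from " + ctx.callSite)
    // prints: "div" was called from SimpleApp$.main(SimpleApp.scala:9)
  }
}
```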
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b975dc3c7a59..4925b87afdc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,13 +21,13 @@ import java.time.{ZoneId, ZoneOffset} import java.util.Locale import java.util.concurrent.TimeUnit._ -import org.apache.spark.SparkArithmeticException +import org.apache.spark.{QueryContext, SparkArithmeticException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.{PhysicalFractionalType, PhysicalIntegralType, PhysicalNumericType} import org.apache.spark.sql.catalyst.util._ @@ -524,7 +524,7 @@ case class Cast( } } - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -942,7 +942,7 @@ case class Cast( private[this] def toPrecision( value: Decimal, decimalType: DecimalType, - context: SQLQueryContext): Decimal = + context: QueryContext): Decimal = value.toPrecision( decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index bd7369e57b05..3870e6d39a34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.SparkException +import org.apache.spark.{QueryContext, SparkException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.truncatedString @@ -613,11 +613,11 @@ abstract class UnaryExpression extends Expression with UnaryLike[Expression] { * to executors. It will also be kept after rule transforms. 
*/ trait SupportQueryContext extends Expression with Serializable { - protected var queryContext: Option[SQLQueryContext] = initQueryContext() + protected var queryContext: Option[QueryContext] = initQueryContext() - def initQueryContext(): Option[SQLQueryContext] + def initQueryContext(): Option[QueryContext] - def getContextOrNull(): SQLQueryContext = queryContext.orNull + def getContextOrNull(): QueryContext = queryContext.orNull def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = { if (withErrorContext && queryContext.isDefined) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index fd6131f18560..fe30e2ea6f3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -134,7 +135,7 @@ case class Average( override protected def withNewChildInternal(newChild: Expression): Average = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index e3881520e490..dfd41ad12a28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{EvalMode, _} -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -186,7 +187,7 @@ case class Sum( // The flag `evalMode` won't be shown in the `toString` or `toAggString` methods override def flatArguments: Iterator[Any] = Iterator(child) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a556ac9f1294..e3c5184c5acc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions import scala.math.{max, min} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern, UNARY_POSITIVE} import org.apache.spark.sql.catalyst.types.{PhysicalDecimalType, PhysicalFractionalType, PhysicalIntegerType, PhysicalIntegralType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.{IntervalMathUtils, IntervalUtils, MathUtils, TypeUtils} @@ -266,7 +266,7 @@ abstract class BinaryArithmetic extends BinaryOperator final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 4a3c7bbc2beb..44e89a1b4a61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -22,13 +22,14 @@ import java.util.Comparator import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType} import org.apache.spark.sql.catalyst.util._ @@ -2525,7 +2526,7 @@ case class ElementAt( override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError && left.resolved && left.dataType.isInstanceOf[ArrayType]) { Some(origin.context) } else { @@ -5045,7 
+5046,7 @@ case class ArrayInsert( newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert = copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr) - override def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + override def initQueryContext(): Option[QueryContext] = Some(origin.context) } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index e22af21daaad..edd824b2d111 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -316,7 +316,7 @@ case class GetArrayItem( newLeft: Expression, newRight: Expression): GetArrayItem = copy(child = newLeft, ordinal = newRight) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 378920856eb1..5f13d397d1bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.types.PhysicalDecimalType import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors @@ -146,7 +146,7 @@ case class CheckOverflow( override protected def withNewChildInternal(newChild: Expression): CheckOverflow = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (!nullOnOverflow) { + override def initQueryContext(): Option[QueryContext] = if (!nullOnOverflow) { Some(origin.context) } else { None @@ -158,7 +158,7 @@ case class CheckOverflowInSum( child: Expression, dataType: DecimalType, nullOnOverflow: Boolean, - context: SQLQueryContext) extends UnaryExpression with SupportQueryContext { + context: QueryContext) extends UnaryExpression with SupportQueryContext { override def nullable: Boolean = true @@ -210,7 +210,7 @@ case class CheckOverflowInSum( 
override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) } /** @@ -256,12 +256,12 @@ case class DecimalDivideWithOverflowCheck( left: Expression, right: Expression, override val dataType: DecimalType, - context: SQLQueryContext, + context: QueryContext, nullOnOverflow: Boolean) extends BinaryExpression with ExpectsInputTypes with SupportQueryContext { override def nullable: Boolean = nullOnOverflow override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) def decimalMethod: String = "$div" override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index fec1df108bcc..6870ab5bfcee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ @@ -200,9 +200,11 @@ trait HigherOrderFunction extends Expression with ExpectsInputTypes { */ final def bind( f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction = { - val res = bindInternal(f) - res.copyTagsFrom(this) - res + CurrentOrigin.withOrigin(origin) { + val res = bindInternal(f) + res.copyTagsFrom(this) + res + } } protected def bindInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 5378639e6838..13676733a9ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -22,8 +22,8 @@ import java.util.Locale import com.google.common.math.{DoubleMath, IntMath, LongMath} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ @@ -604,7 +604,7 @@ trait IntervalDivide { minValue: Any, num: Expression, numValue: Any, - context: SQLQueryContext): Unit = { + context: QueryContext): Unit = { if (value == minValue && 
num.dataType.isInstanceOf[IntegralType]) { if (numValue.asInstanceOf[Number].longValue() == -1) { throw QueryExecutionErrors.intervalArithmeticOverflowError( @@ -616,7 +616,7 @@ trait IntervalDivide { def divideByZeroCheck( dataType: DataType, num: Any, - context: SQLQueryContext): Unit = dataType match { + context: QueryContext): Unit = dataType match { case _: DecimalType => if (num.asInstanceOf[Decimal].isZero) { throw QueryExecutionErrors.intervalDividedByZeroError(context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e87d0bc41412..8f6a615ebe49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} import java.util.Locale +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -480,7 +480,7 @@ case class Conv( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird) - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -1523,7 +1523,7 @@ abstract class RoundBase(child: Expression, scale: Expression, private lazy val scaleV: Any = scale.eval(EmptyRow) protected lazy val _scale: Int = scaleV.asInstanceOf[Int] - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (ansiEnabled) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 44ec403bf19a..e750aa9283ff 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import scala.collection.mutable.ArrayBuffer +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import 
org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -411,7 +412,7 @@ case class Elt( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt = copy(children = newChildren) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 23bbc91c16d5..8fabb4487620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit._ import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{Decimal, DoubleExactNumeric, TimestampNTZType, TimestampType} @@ -70,7 +70,7 @@ object DateTimeUtils extends SparkDateTimeUtils { // the "GMT" string. For example, it returns 2000-01-01T00:00+01:00 for 2000-01-01T00:00GMT+01:00. def cleanLegacyTimestampStr(s: UTF8String): UTF8String = s.replace(gmtUtf8, UTF8String.EMPTY_UTF8) - def doubleToTimestampAnsi(d: Double, context: SQLQueryContext): Long = { + def doubleToTimestampAnsi(d: Double, context: QueryContext): Long = { if (d.isNaN || d.isInfinite) { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(d, TimestampType, context) } else { @@ -91,7 +91,7 @@ object DateTimeUtils extends SparkDateTimeUtils { def stringToTimestampWithoutTimeZoneAnsi( s: UTF8String, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { stringToTimestampWithoutTimeZone(s, true).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampNTZType, context) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 59765cde1f92..2730ab8f4b89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.unsafe.types.UTF8String @@ -54,7 +54,7 @@ object NumberConverter { fromPos: Int, value: Array[Byte], ansiEnabled: Boolean, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { var v: Long = 0L // bound will always be positive since radix >= 2 // Note that: -1 is equivalent to 11111111...1111 which is the largest unsigned long value @@ -134,7 +134,7 @@ object NumberConverter { fromBase: Int, toBase: Int, 
ansiEnabled: Boolean, - context: SQLQueryContext): UTF8String = { + context: QueryContext): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index f7800469c352..1c3a5075dab2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, ShortType} import org.apache.spark.unsafe.types.UTF8String @@ -27,21 +27,21 @@ import org.apache.spark.unsafe.types.UTF8String */ object UTF8StringUtils { - def toLongExact(s: UTF8String, context: SQLQueryContext): Long = + def toLongExact(s: UTF8String, context: QueryContext): Long = withException(s.toLongExact, context, LongType, s) - def toIntExact(s: UTF8String, context: SQLQueryContext): Int = + def toIntExact(s: UTF8String, context: QueryContext): Int = withException(s.toIntExact, context, IntegerType, s) - def toShortExact(s: UTF8String, context: SQLQueryContext): Short = + def toShortExact(s: UTF8String, context: QueryContext): Short = withException(s.toShortExact, context, ShortType, s) - def toByteExact(s: UTF8String, context: SQLQueryContext): Byte = + def toByteExact(s: UTF8String, context: QueryContext): Byte = withException(s.toByteExact, context, ByteType, s) private def withException[A]( f: => A, - context: SQLQueryContext, + context: QueryContext, to: DataType, s: UTF8String): A = { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 84472490128b..fb0ddd22f4b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode} +import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode} import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -103,7 +103,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -117,7 +117,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def 
invalidInputSyntaxForBooleanError( s: UTF8String, - context: SQLQueryContext): SparkRuntimeException = { + context: QueryContext): SparkRuntimeException = { new SparkRuntimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -132,7 +132,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -193,15 +193,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = e) } - def divideByZeroError(context: SQLQueryContext): ArithmeticException = { + def divideByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", messageParameters = Map("config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = getQueryContext(context), + context = Array(context), summary = getSummary(context)) } - def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = { + def intervalDividedByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "INTERVAL_DIVIDED_BY_ZERO", messageParameters = Map.empty, @@ -212,7 +212,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidArrayIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX", messageParameters = Map( @@ -226,7 +226,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidElementAtIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", messageParameters = Map( @@ -291,15 +291,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE ansiIllegalArgumentError(e.getMessage) } - def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = { + def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in sum of decimals", context = context) } - def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = { + def overflowInIntegralDivideError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in integral divide", "try_divide", context) } - def overflowInConvError(context: SQLQueryContext): ArithmeticException = { + def overflowInConvError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in function conv()", context = context) } @@ -624,7 +624,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def intervalArithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." 
} else "" @@ -1390,7 +1390,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "prettyName" -> prettyName)) } - def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = { + def invalidIndexOfZeroError(context: QueryContext): RuntimeException = { new SparkRuntimeException( errorClass = "INVALID_INDEX_OF_ZERO", cause = null, @@ -2555,7 +2555,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def multipleRowScalarSubqueryError(context: SQLQueryContext): Throwable = { + def multipleRowScalarSubqueryError(context: QueryContext): Throwable = { new SparkException( errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", messageParameters = Map.empty, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 997308c6ef44..ba4e7b279f51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -35,8 +35,6 @@ import org.apache.spark.sql.types.StructType trait AnalysisTest extends PlanTest { - import org.apache.spark.QueryContext - protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil protected def createTempView( @@ -177,7 +175,7 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val analyzer = getAnalyzer diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index d91a080d8fe8..3fd0c1ee5de4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal} @@ -159,7 +158,7 @@ abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) { super.assertAnalysisErrorClass( @@ -196,7 +195,7 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) { 
super.assertAnalysisErrorClass( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index bb326119ab49..a28d9213bc89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -70,8 +70,10 @@ private[sql] object Column { name: String, isDistinct: Boolean, ignoreNulls: Boolean, - inputs: Column*): Column = Column { - UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + inputs: Column*): Column = withOrigin(1) { + Column { + UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + } } } @@ -148,12 +150,14 @@ class TypedColumn[-T, U]( @Stable class Column(val expr: Expression) extends Logging { - def this(name: String) = this(name match { - case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => - val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) - UnresolvedStar(Some(parts)) - case _ => UnresolvedAttribute.quotedString(name) + def this(name: String) = this(withOrigin() { + name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => + val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) + UnresolvedStar(Some(parts)) + case _ => UnresolvedAttribute.quotedString(name) + } }) private def fn(name: String): Column = { @@ -180,7 +184,9 @@ class Column(val expr: Expression) extends Logging { } /** Creates a column based on the given expression. */ - private def withExpr(newExpr: Expression): Column = new Column(newExpr) + private def withExpr(newExpr: => Expression): Column = withOrigin(1) { + new Column(newExpr) + } /** * Returns the expression for this column either with an existing or auto assigned name. @@ -1370,7 +1376,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + def over(window: expressions.WindowSpec): Column = withOrigin() { + window.withAggregate(this) + } /** * Defines an empty analytic clause. 
In this case the analytic function is applied diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 7511c21fa76d..540f90163cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -71,7 +71,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def approxQuantile( col: String, probabilities: Array[Double], - relativeError: Double): Array[Double] = { + relativeError: Double): Array[Double] = withOrigin() { approxQuantile(Array(col), probabilities, relativeError).head } @@ -96,7 +96,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def approxQuantile( cols: Array[String], probabilities: Array[Double], - relativeError: Double): Array[Array[Double]] = { + relativeError: Double): Array[Array[Double]] = withOrigin() { StatFunctions.multipleApproxQuantiles( df.select(cols.map(col): _*), cols, @@ -131,7 +131,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def cov(col1: String, col2: String): Double = { + def cov(col1: String, col2: String): Double = withOrigin() { StatFunctions.calculateCov(df, Seq(col1, col2)) } @@ -153,7 +153,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def corr(col1: String, col2: String, method: String): Double = { + def corr(col1: String, col2: String, method: String): Double = withOrigin() { require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + "coefficient is supported.") StatFunctions.pearsonCorrelation(df, Seq(col1, col2)) @@ -209,7 +209,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def crosstab(col1: String, col2: String): DataFrame = { + def crosstab(col1: String, col2: String): DataFrame = withOrigin() { StatFunctions.crossTabulate(df, col1, col2) } @@ -256,7 +256,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Array[String], support: Double): DataFrame = { + def freqItems(cols: Array[String], support: Double): DataFrame = withOrigin() { FrequentItems.singlePassFreqItems(df, cols, support) } @@ -275,7 +275,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Array[String]): DataFrame = { + def freqItems(cols: Array[String]): DataFrame = withOrigin () { FrequentItems.singlePassFreqItems(df, cols, 0.01) } @@ -319,7 +319,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Seq[String], support: Double): DataFrame = { + def freqItems(cols: Seq[String], support: Double): DataFrame = withOrigin() { FrequentItems.singlePassFreqItems(df, cols, support) } @@ -338,7 +338,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Seq[String]): DataFrame = { + def freqItems(cols: Seq[String]): DataFrame = withOrigin() { FrequentItems.singlePassFreqItems(df, cols, 0.01) } @@ -414,7 +414,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = withOrigin() { require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), 
s"Fractions must be in [0, 1], but got $fractions.") import org.apache.spark.sql.functions.{rand, udf} @@ -498,7 +498,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { countMinSketch(col, CountMinSketch.create(eps, confidence, seed)) } - private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = { + private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = withOrigin() { val singleCol = df.select(col) val colType = singleCol.schema.head.dataType @@ -570,7 +570,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param numBits expected number of bits of the filter. * @since 2.0.0 */ - def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin() { buildBloomFilter(col, expectedNumItems, numBits, Double.NaN) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f07496e64304..c12edeefe335 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -508,9 +508,11 @@ class Dataset[T] private[sql]( * @group basic * @since 3.4.0 */ - def to(schema: StructType): DataFrame = withPlan { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf) + def to(schema: StructType): DataFrame = withOrigin() { + withPlan { + val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] + Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf) + } } /** @@ -770,12 +772,14 @@ class Dataset[T] private[sql]( */ // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. - def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { - val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) - require(!IntervalUtils.isNegative(parsedDelay), - s"delay threshold ($delayThreshold) should not be negative.") - EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withOrigin() { + withTypedPlan { + val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) + require(!IntervalUtils.isNegative(parsedDelay), + s"delay threshold ($delayThreshold) should not be negative.") + EliminateEventTimeWatermark( + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + } } /** @@ -947,8 +951,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + def join(right: Dataset[_]): DataFrame = withOrigin() { + withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + } } /** @@ -1081,22 +1087,23 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { - // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right - // by creating a new instance for one of the branch. 
- val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) - .analyzed.asInstanceOf[Join] - - withPlan { - Join( - joined.left, - joined.right, - UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), - None, - JoinHint.NONE) + def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = + withOrigin() { + // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right + // by creating a new instance for one of the branch. + val joined = sparkSession.sessionState.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) + .analyzed.asInstanceOf[Join] + + withPlan { + Join( + joined.left, + joined.right, + UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), + None, + JoinHint.NONE) + } } - } /** * Inner join with another `DataFrame`, using the given join expression. @@ -1177,7 +1184,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { + def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = withOrigin() { withPlan { resolveSelfJoinCondition(right, Some(joinExprs), joinType) } @@ -1193,8 +1200,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.1.0 */ - def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) + def crossJoin(right: Dataset[_]): DataFrame = withOrigin() { + withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) + } } /** @@ -1218,27 +1227,28 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, - // etc. - val joined = sparkSession.sessionState.executePlan( - Join( - this.logicalPlan, - other.logicalPlan, - JoinType(joinType), - Some(condition.expr), - JoinHint.NONE)).analyzed.asInstanceOf[Join] - - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) - - withTypedPlan(JoinWith.typedJoinWith( - joined, - sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity, - sparkSession.sessionState.analyzer.resolver, - this.exprEnc.isSerializedAsStructForTopLevel, - other.exprEnc.isSerializedAsStructForTopLevel)) - } + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = + withOrigin() { + // Creates a Join node and resolve it first, to get join condition resolved, self-join + // resolved, etc. 
+ val joined = sparkSession.sessionState.executePlan( + Join( + this.logicalPlan, + other.logicalPlan, + JoinType(joinType), + Some(condition.expr), + JoinHint.NONE)).analyzed.asInstanceOf[Join] + + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) + + withTypedPlan(JoinWith.typedJoinWith( + joined, + sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity, + sparkSession.sessionState.analyzer.resolver, + this.exprEnc.isSerializedAsStructForTopLevel, + other.exprEnc.isSerializedAsStructForTopLevel)) + } /** * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair @@ -1421,14 +1431,16 @@ class Dataset[T] private[sql]( * @since 2.2.0 */ @scala.annotation.varargs - def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan { - val exprs = parameters.map { - case c: Column => c.expr - case s: Symbol => Column(s.name).expr - case e: Expression => e - case literal => Literal(literal) - }.toSeq - UnresolvedHint(name, exprs, logicalPlan) + def hint(name: String, parameters: Any*): Dataset[T] = withOrigin() { + withTypedPlan { + val exprs = parameters.map { + case c: Column => c.expr + case s: Symbol => Column(s.name).expr + case e: Expression => e + case literal => Literal(literal) + }.toSeq + UnresolvedHint(name, exprs, logicalPlan) + } } /** @@ -1511,8 +1523,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + def as(alias: String): Dataset[T] = withOrigin() { + withTypedPlan { + SubqueryAlias(alias, logicalPlan) + } } /** @@ -1549,25 +1563,28 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def select(cols: Column*): DataFrame = withPlan { - val untypedCols = cols.map { - case typedCol: TypedColumn[_, _] => - // Checks if a `TypedColumn` has been inserted with - // specific input type and schema by `withInputType`. - val needInputType = typedCol.expr.exists { - case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true - case _ => false - } + def select(cols: Column*): DataFrame = withOrigin() { + withPlan { + val untypedCols = cols.map { + case typedCol: TypedColumn[_, _] => + // Checks if a `TypedColumn` has been inserted with + // specific input type and schema by `withInputType`. + val needInputType = typedCol.expr.exists { + case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true + case _ => false + } - if (!needInputType) { - typedCol - } else { - throw QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString) - } + if (!needInputType) { + typedCol + } else { + throw + QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString) + } - case other => other + case other => other + } + Project(untypedCols.map(_.named), logicalPlan) } - Project(untypedCols.map(_.named), logicalPlan) } /** @@ -1584,7 +1601,9 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) + def select(col: String, cols: String*): DataFrame = withOrigin() { + select((col +: cols).map(Column(_)) : _*) + } /** * Selects a set of SQL expressions. 
This is a variant of `select` that accepts @@ -1600,10 +1619,12 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def selectExpr(exprs: String*): DataFrame = sparkSession.withActive { - select(exprs.map { expr => - Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) - }: _*) + def selectExpr(exprs: String*): DataFrame = withOrigin() { + sparkSession.withActive { + select(exprs.map { expr => + Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) + }: _*) + } } /** @@ -1617,7 +1638,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { + def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = withOrigin() { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) @@ -1701,8 +1722,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + def filter(condition: Column): Dataset[T] = withOrigin() { + withTypedPlan { + Filter(condition.expr, logicalPlan) + } } /** @@ -2061,15 +2084,17 @@ class Dataset[T] private[sql]( ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): DataFrame = withPlan { - Unpivot( - Some(ids.map(_.named)), - Some(values.map(v => Seq(v.named))), - None, - variableColumnName, - Seq(valueColumnName), - logicalPlan - ) + valueColumnName: String): DataFrame = withOrigin() { + withPlan { + Unpivot( + Some(ids.map(_.named)), + Some(values.map(v => Seq(v.named))), + None, + variableColumnName, + Seq(valueColumnName), + logicalPlan + ) + } } /** @@ -2092,15 +2117,17 @@ class Dataset[T] private[sql]( def unpivot( ids: Array[Column], variableColumnName: String, - valueColumnName: String): DataFrame = withPlan { - Unpivot( - Some(ids.map(_.named)), - None, - None, - variableColumnName, - Seq(valueColumnName), - logicalPlan - ) + valueColumnName: String): DataFrame = withOrigin() { + withPlan { + Unpivot( + Some(ids.map(_.named)), + None, + None, + variableColumnName, + Seq(valueColumnName), + logicalPlan + ) + } } /** @@ -2217,8 +2244,10 @@ class Dataset[T] private[sql]( * @since 3.0.0 */ @varargs - def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { - CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) + def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withOrigin() { + withTypedPlan { + CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) + } } /** @@ -2255,8 +2284,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + def limit(n: Int): Dataset[T] = withOrigin() { + withTypedPlan { + Limit(Literal(n), logicalPlan) + } } /** @@ -2265,8 +2296,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 3.4.0 */ - def offset(n: Int): Dataset[T] = withTypedPlan { - Offset(Literal(n), logicalPlan) + def offset(n: Int): Dataset[T] = withOrigin() { + withTypedPlan { + Offset(Literal(n), logicalPlan) + } } // This breaks caching, but it's usually ok because it addresses a very specific use case: @@ -2676,20 +2709,21 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { 
- val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = + withOrigin() { + val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) + val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) - val rowFunction = - f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) - val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) + val rowFunction = + f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) + val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) - withPlan { - Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + withPlan { + Generate(generator, unrequiredChildIndex = Nil, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } } - } /** * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero @@ -2714,7 +2748,7 @@ class Dataset[T] private[sql]( */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) - : DataFrame = { + : DataFrame = withOrigin() { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil // TODO handle the metadata? @@ -2871,7 +2905,7 @@ class Dataset[T] private[sql]( * @since 3.4.0 */ @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = { + def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin() { val resolver = sparkSession.sessionState.analyzer.resolver val output: Seq[NamedExpression] = queryExecution.analyzed.output @@ -3085,9 +3119,11 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = groupColsFromDropDuplicates(colNames) - Deduplicate(groupCols, logicalPlan) + def dropDuplicates(colNames: Seq[String]): Dataset[T] = withOrigin() { + withTypedPlan { + val groupCols = groupColsFromDropDuplicates(colNames) + Deduplicate(groupCols, logicalPlan) + } } /** @@ -3163,10 +3199,12 @@ class Dataset[T] private[sql]( * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = groupColsFromDropDuplicates(colNames) - // UnsupportedOperationChecker will fail the query if this is called with batch Dataset. - DeduplicateWithinWatermark(groupCols, logicalPlan) + def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withOrigin() { + withTypedPlan { + val groupCols = groupColsFromDropDuplicates(colNames) + // UnsupportedOperationChecker will fail the query if this is called with batch Dataset. 
+ DeduplicateWithinWatermark(groupCols, logicalPlan) + } } /** @@ -3390,7 +3428,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(func: T => Boolean): Dataset[T] = { + def filter(func: T => Boolean): Dataset[T] = withOrigin() { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -3401,7 +3439,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(func: FilterFunction[T]): Dataset[T] = { + def filter(func: FilterFunction[T]): Dataset[T] = withOrigin() { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -3412,8 +3450,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + def map[U : Encoder](func: T => U): Dataset[U] = withOrigin() { + withTypedPlan { + MapElements[T, U](func, logicalPlan) + } } /** @@ -3423,7 +3463,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = withOrigin() { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) } @@ -3586,8 +3626,9 @@ class Dataset[T] private[sql]( * @group action * @since 3.0.0 */ - def tail(n: Int): Array[T] = withAction( - "tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan) + def tail(n: Int): Array[T] = withOrigin() { + withAction("tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan) + } /** * Returns the first `n` rows in the Dataset as a list. @@ -3651,8 +3692,10 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def count(): Long = withAction("count", groupBy().count().queryExecution) { plan => - plan.executeCollect().head.getLong(0) + def count(): Long = withOrigin() { + withAction("count", groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) + } } /** @@ -3661,13 +3704,15 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + def repartition(numPartitions: Int): Dataset[T] = withOrigin() { + withTypedPlan { + Repartition(numPartitions, shuffle = true, logicalPlan) + } } private def repartitionByExpression( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = { + partitionExprs: Seq[Column]): Dataset[T] = withOrigin(1) { // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. // However, we don't want to complicate the semantics of this API method. // Instead, let's give users a friendly error message, pointing them to the new method. 
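Every Dataset.scala change in the hunks above follows the same wrapping pattern: the body of a public API is evaluated inside withOrigin() so that the catalyst nodes built there record the user call site, which the ExpectedContext(code = ..., callSitePattern = ...) assertions in the test changes further below then verify. The following self-contained Scala sketch models that mechanism; Origin, CurrentOrigin, withOrigin and Node are simplified stand-ins named after the entities in this diff, not the real Spark/catalyst classes.

object CallSiteCaptureSketch {
  // Simplified origin: only the captured user call site is modelled here.
  final case class Origin(stackTrace: Option[Array[StackTraceElement]] = None)

  // Thread-local "current origin", analogous to catalyst's CurrentOrigin.
  object CurrentOrigin {
    private val value = new ThreadLocal[Origin] {
      override def initialValue(): Origin = Origin()
    }
    def get: Origin = value.get()
    def withOrigin[A](o: Origin)(f: => A): A = {
      val previous = get
      value.set(o)
      try f finally value.set(previous)
    }
  }

  // Only the outermost API call captures a stack trace; nested calls see a
  // non-empty origin and just run `f`, mirroring the helper in sql/package.scala.
  def withOrigin[T](framesToDrop: Int = 0)(f: => T): T = {
    if (CurrentOrigin.get.stackTrace.isDefined) {
      f
    } else {
      // The real helper skips Spark-internal frames via a class-name regex;
      // this sketch simply drops a fixed number of frames for illustration.
      val st = Thread.currentThread().getStackTrace
      val origin = Origin(stackTrace = Some(st.slice(framesToDrop + 2, framesToDrop + 4)))
      CurrentOrigin.withOrigin(origin)(f)
    }
  }

  // A "plan node" records whatever origin is current when it is constructed,
  // the same way catalyst TreeNodes pick up CurrentOrigin.
  final case class Node(name: String, origin: Origin = CurrentOrigin.get)

  // API in the style of Dataset.limit above: the node built inside the block
  // carries the captured user call site.
  def limit(n: Int): Node = withOrigin() { Node(s"Limit($n)") }

  def main(args: Array[String]): Unit = {
    val node = limit(10)
    // Prints the recorded user-code frame(s) for the limit(10) call.
    node.origin.stackTrace.foreach(_.foreach(println))
  }
}

The design point mirrored here is that only the outermost API entry captures a stack trace; nested API calls find CurrentOrigin already populated and skip the capture, so the recorded frames stay close to the user code.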
@@ -3712,7 +3757,7 @@ class Dataset[T] private[sql]( private def repartitionByRange( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = { + partitionExprs: Seq[Column]): Dataset[T] = withOrigin(1) { require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { case expr: SortOrder => expr @@ -3784,8 +3829,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + def coalesce(numPartitions: Int): Dataset[T] = withOrigin() { + withTypedPlan { + Repartition(numPartitions, shuffle = false, logicalPlan) + } } /** @@ -3932,8 +3979,10 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @throws[AnalysisException] - def createTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = false, global = false) + def createTempView(viewName: String): Unit = withOrigin() { + withPlan { + createTempViewCommand(viewName, replace = false, global = false) + } } @@ -3945,8 +3994,10 @@ class Dataset[T] private[sql]( * @group basic * @since 2.0.0 */ - def createOrReplaceTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = true, global = false) + def createOrReplaceTempView(viewName: String): Unit = withOrigin() { + withPlan { + createTempViewCommand(viewName, replace = true, global = false) + } } /** @@ -3964,8 +4015,10 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @throws[AnalysisException] - def createGlobalTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = false, global = true) + def createGlobalTempView(viewName: String): Unit = withOrigin() { + withPlan { + createTempViewCommand(viewName, replace = false, global = true) + } } /** @@ -4373,7 +4426,7 @@ class Dataset[T] private[sql]( plan.executeCollect().map(fromRow) } - private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = withOrigin() { val sortOrder: Seq[SortOrder] = sortExprs.map { col => col.expr match { case expr: SortOrder => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 41230c7792c5..b2a7f96b3b41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution +import org.apache.spark.QueryContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExprId, InSet, ListQuery, Literal, PlanExpression, Predicate, SupportQueryContext} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.{LeafLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -68,7 +69,7 @@ case class ScalarSubquery( override def nullable: Boolean = 
true override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields) override def withNewPlan(query: BaseSubqueryExec): ScalarSubquery = copy(plan = query) - def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + def initQueryContext(): Option[QueryContext] = Some(origin.context) override lazy val canonicalized: Expression = { ScalarSubquery(plan.canonicalized.asInstanceOf[BaseSubqueryExec], ExprId(0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2a7ed263c748..637be95da73b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -82,11 +82,13 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private def withExpr(expr: Expression): Column = Column(expr) + private def withExpr(expr: => Expression): Column = withOrigin(1) { + Column(expr) + } private def withAggregateFunction( - func: AggregateFunction, - isDistinct: Boolean = false): Column = { + func: => AggregateFunction, + isDistinct: Boolean = false): Column = withOrigin(1) { Column(func.toAggregateExpression(isDistinct)) } @@ -116,16 +118,18 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => - // This is different from `typedlit`. `typedlit` calls `Literal.create` to use - // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this method, - // `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, we can - // just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. This is - // significantly better when there are many threads calling `lit` concurrently. + def lit(literal: Any): Column = withOrigin() { + literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => + // This is different from `typedlit`. `typedlit` calls `Literal.create` to use + // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this + // method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, + // we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. + // This is significantly better when there are many threads calling `lit` concurrently. Column(Literal(literal)) + } } /** @@ -136,7 +140,9 @@ object functions { * @group normal_funcs * @since 2.2.0 */ - def typedLit[T : TypeTag](literal: T): Column = typedlit(literal) + def typedLit[T : TypeTag](literal: T): Column = withOrigin() { + typedlit(literal) + } /** * Creates a [[Column]] of literal value. 
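A detail worth calling out in the functions.scala hunks above: withExpr and withAggregateFunction now take their arguments by name (=> Expression, => AggregateFunction), so the expression tree is only constructed inside the withOrigin block and therefore picks up the captured origin. A minimal standalone sketch of that difference, using hypothetical stand-ins (withTag, Expr) rather than Spark classes:

object ByNameSketch {
  // Stands in for the thread-local origin.
  private val tag = new ThreadLocal[String] {
    override def initialValue(): String = "untagged"
  }

  // Stands in for an Expression/TreeNode that snapshots the current origin at construction.
  final case class Expr(name: String, taggedWith: String = tag.get())

  private def withTag[A](t: String)(f: => A): A = {
    val previous = tag.get()
    tag.set(t)
    try f finally tag.set(previous)
  }

  // By-value: Expr("upper") is constructed at the call site, before the tag is set.
  def withExprEager(e: Expr): Expr = withTag("user-call-site")(e)

  // By-name: construction is deferred into the tagged scope (the shape used in the diff).
  def withExprByName(e: => Expr): Expr = withTag("user-call-site")(e)

  def main(args: Array[String]): Unit = {
    println(withExprEager(Expr("upper")).taggedWith)  // prints: untagged
    println(withExprByName(Expr("upper")).taggedWith) // prints: user-call-site
  }
}

With the original by-value signatures the argument would already have been evaluated before any origin was set, and the call site would be lost.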
@@ -153,10 +159,12 @@ object functions { * @group normal_funcs * @since 3.2.0 */ - def typedlit[T : TypeTag](literal: T): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => Column(Literal.create(literal)) + def typedlit[T : TypeTag](literal: T): Column = withOrigin() { + literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => Column(Literal.create(literal)) + } } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -5946,25 +5954,31 @@ object functions { def array_except(col1: Column, col2: Column): Column = Column.fn("array_except", col1, col2) - private def createLambda(f: Column => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val function = f(Column(x)).expr - LambdaFunction(function, Seq(x)) + private def createLambda(f: Column => Column) = withOrigin(1) { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val function = f(Column(x)).expr + LambdaFunction(function, Seq(x)) + } } - private def createLambda(f: (Column, Column) => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val function = f(Column(x), Column(y)).expr - LambdaFunction(function, Seq(x, y)) + private def createLambda(f: (Column, Column) => Column) = withOrigin(1) { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val function = f(Column(x), Column(y)).expr + LambdaFunction(function, Seq(x, y)) + } } - private def createLambda(f: (Column, Column, Column) => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) - val function = f(Column(x), Column(y), Column(z)).expr - LambdaFunction(function, Seq(x, y, z)) + private def createLambda(f: (Column, Column, Column) => Column) = withOrigin(1) { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) + val function = f(Column(x), Column(y), Column(z)).expr + LambdaFunction(function, Seq(x, y, z)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 1794ac513749..906fe2ab1916 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,7 +17,10 @@ package org.apache.spark +import java.util.regex.Pattern + import org.apache.spark.annotation.{DeveloperApi, Unstable} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.execution.SparkStrategy /** @@ -73,4 +76,43 @@ package object sql { * with rebasing. 
*/ private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = "org.apache.spark.legacyINT96" + + /** + * This helper function captures the Spark API and its call site in the user code from the current + * stacktrace. + * + * As adding `withOrigin` explicitly to all Spark API definitions would be a huge change, + * `withOrigin` is used only at certain places where all API implementations surely pass through + * and the current stacktrace is filtered to the point where the first Spark API code is invoked from + * the user code. + * + * As there might be multiple nested `withOrigin` calls (e.g. any Spark API implementations can + * invoke other APIs) only the first `withOrigin` is captured because that is closer to the user + * code. + * + * @param framesToDrop the number of stack frames we can surely drop before searching for the user + * code + * @param f the function that can use the origin + * @return the result of `f` + */ + private[sql] def withOrigin[T](framesToDrop: Int = 0)(f: => T): T = { + if (CurrentOrigin.get.stackTrace.isDefined) { + f + } else { + val st = Thread.currentThread().getStackTrace + var i = framesToDrop + 3 + while (sparkCode(st(i))) i += 1 + val origin = + Origin(stackTrace = Some(Thread.currentThread().getStackTrace.slice(i - 1, i + 1))) + CurrentOrigin.withOrigin(origin)(f) + } + } + + private val sparkCodePattern = Pattern.compile("org\\.apache\\.spark\\.sql\\." + + "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions)" + + "(?:|\\..*|\\$.*)") + + private def sparkCode(ste: StackTraceElement): Boolean = { + sparkCodePattern.matcher(ste.getClassName).matches() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 0baded3323c6..ceddc89849f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -458,7 +458,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", - "sqlExpr" -> "\"(a IN (b))\"") + "sqlExpr" -> "\"(a IN (b))\""), + context = ExpectedContext(code = "isin", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -525,7 +526,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", - "sqlExpr" -> "\"(a IN (b))\"") + "sqlExpr" -> "\"(a IN (b))\""), + context = ExpectedContext( + code = "isInCollection", + callSitePattern = getCurrentClassCallSitePattern) ) } } @@ -1056,7 +1060,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"key\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = + ExpectedContext(code = "withField", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1101,7 +1107,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"a.b\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = + ExpectedContext(code = "withField", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1849,7 +1857,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"key\"", "inputType" -> "\"INT\"", - "requiredType" -> 
"\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = + ExpectedContext(code = "dropFields", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1886,7 +1896,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"a.b\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = + ExpectedContext(code = "dropFields", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1952,7 +1964,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.withColumn("a", $"a".dropFields("a", "b", "c")) }, errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", - parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\"") + parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\""), + context = + ExpectedContext(code = "dropFields", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -2224,7 +2238,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"struct_col".dropFields("a", "b")) }, errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", - parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\"") + parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\""), + context = + ExpectedContext(code = "dropFields", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( @@ -2398,7 +2414,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 1).as("a")) }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`")) + parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`"), + context = + ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( structLevel1 @@ -2451,7 +2469,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"a".withField("z", $"a.c")).as("a") }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`")) + parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`"), + context = + ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("nestedDf should generate nested DataFrames") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 77b9b3808526..4f600c816a0f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -52,7 +52,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "inputSchema" -> "\"ARRAY\"", "dataType" -> "\"ARRAY\"" - ) + ), + context = ExpectedContext(code = "from_csv", getCurrentClassCallSitePattern) ) checkError( @@ -395,7 +396,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { .select(from_csv($"csv", $"schema", options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"schema\"") + parameters = Map("inputSchema" -> "\"schema\""), + context = ExpectedContext(code = "from_csv", getCurrentClassCallSitePattern) ) checkError( @@ -403,7 +405,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { 
Seq("1").toDF("csv").select(from_csv($"csv", lit(1), options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"1\"") + parameters = Map("inputSchema" -> "\"1\""), + context = ExpectedContext(code = "from_csv", getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 80862eec41e0..30f317b15e32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -633,7 +633,9 @@ class DataFrameAggregateSuite extends QueryTest "functionName" -> "`collect_set`", "dataType" -> "\"MAP\"", "sqlExpr" -> "\"collect_set(b)\"" - ) + ), + context = + ExpectedContext(code = "collect_set", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -706,7 +708,8 @@ class DataFrameAggregateSuite extends QueryTest testData.groupBy(sum($"key")).count() }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "sum(key)") + parameters = Map("sqlExpr" -> "sum(key)"), + context = ExpectedContext(code = "sum", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1302,7 +1305,8 @@ class DataFrameAggregateSuite extends QueryTest "paramIndex" -> "2", "inputSql" -> "\"a\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"INTEGRAL\"")) + "requiredType" -> "\"INTEGRAL\""), + context = ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4020688bc319..ef9353bedf1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -171,7 +171,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"k\"", "inputType" -> "\"INT\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "map_from_arrays", callSitePattern = getCurrentClassCallSitePattern)) ) val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") @@ -758,7 +760,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("The given function only supports array input") { val df = Seq(1, 2, 3).toDF("a") - checkErrorMatchPVals( + checkError( exception = intercept[AnalysisException] { df.select(array_sort(col("a"), (x, y) => x - y)) }, @@ -769,7 +771,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + matchPVals = true, + queryContext = Array( + ExpectedContext(code = "array_sort", callSitePattern = getCurrentClassCallSitePattern)) + ) } test("sort_array/array_sort functions") { @@ -1305,7 +1311,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", "dataType" -> "(\"MAP, INT>\" or \"MAP\")", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext(code = "map_concat", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -1333,7 +1341,9 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", "dataType" -> "[\"MAP, INT>\", \"INT\"]", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext(code = "map_concat", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1402,7 +1412,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"a\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of pair \"STRUCT\"" - ) + ), + context = + ExpectedContext(code = "map_from_entries", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1439,7 +1451,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array_contains(a, NULL)\"", "functionName" -> "`array_contains`" - ) + ), + context = + ExpectedContext(code = "array_contains", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2348,7 +2362,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"ARRAY\"")) + "rightType" -> "\"ARRAY\""), + context = + ExpectedContext(code = "array_union", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { @@ -2379,7 +2395,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"VOID\"") + "rightType" -> "\"VOID\""), + context = + ExpectedContext(code = "array_union", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2410,7 +2428,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY>\"", - "rightType" -> "\"ARRAY\"") + "rightType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "array_union", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2647,7 +2667,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"arr\"", "inputType" -> "\"ARRAY\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + context = + ExpectedContext(code = "flatten", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2660,7 +2682,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2673,7 +2697,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"s\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2782,7 +2808,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"b\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + context = + ExpectedContext(code = "array_repeat", callSitePattern = 
getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2795,7 +2823,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"1\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_repeat", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3123,7 +3153,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = + ExpectedContext(code = "array_except", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3151,7 +3183,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3179,7 +3213,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"VOID\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3207,7 +3243,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3276,7 +3314,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = + ExpectedContext(code = "array_intersect", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3305,7 +3345,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_intersect", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3334,7 +3376,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext(code = "array_intersect", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3750,7 +3794,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(code = "map_filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -3933,7 +3979,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + 
queryContext = Array( + ExpectedContext(code = "filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -3945,7 +3993,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + ExpectedContext( + fragment = "filter(s, x -> x)", + start = 0, + stop = 16)) checkError( exception = intercept[AnalysisException] { @@ -3957,7 +4009,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext(code = "filter", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = @@ -4112,7 +4165,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "exists", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4124,7 +4179,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "exists(s, x -> x)", + start = 0, + stop = 16) + ) checkError( exception = intercept[AnalysisException] { @@ -4136,7 +4196,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = + ExpectedContext(code = "exists", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException](df.selectExpr("exists(a, x -> x)")), @@ -4304,7 +4366,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "forall", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4316,7 +4380,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "forall(s, x -> x)", + start = 0, + stop = 16)) checkError( exception = intercept[AnalysisException] { @@ -4328,7 +4396,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = + ExpectedContext(code = "forall", 
callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException](df.selectExpr("forall(a, x -> x)")), @@ -4343,7 +4413,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.select(forall(col("a"), x => x))), errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`")) + parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), + queryContext = Array( + ExpectedContext(code = "col", callSitePattern = getCurrentClassCallSitePattern))) } test("aggregate function - array for primitive type not containing null") { @@ -4581,7 +4653,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "aggregate", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit // scalastyle:off line.size.limit @@ -4597,7 +4671,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - )) + ), + context = ExpectedContext( + fragment = s"$agg(s, 0, (acc, x) -> x)", + start = 0, + stop = agg.length + 20)) } // scalastyle:on line.size.limit @@ -4613,7 +4691,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - )) + ), + context = + ExpectedContext(code = "aggregate", callSitePattern = getCurrentClassCallSitePattern)) // scalastyle:on line.size.limit Seq("aggregate", "reduce").foreach { agg => @@ -4719,7 +4799,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "functionName" -> "`map_zip_with`", "leftType" -> "\"INT\"", - "rightType" -> "\"MAP\"")) + "rightType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(code = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4749,7 +4831,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(i, mis, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "1", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(code = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4779,7 +4863,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, i, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "2", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(code = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -5235,7 
+5321,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"x\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext( + code = "transform_values", + callSitePattern = getCurrentClassCallSitePattern))) } testInvalidLambdaFunctions() @@ -5375,7 +5465,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(code = "zip_with", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -5631,7 +5723,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map(m, 1)\"", "keyType" -> "\"MAP\"" - ) + ), + context = + ExpectedContext(code = "map", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( df.select(map(map_entries($"m"), lit(1))), @@ -5753,7 +5847,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + context = + ExpectedContext(code = "array_compact", callSitePattern = getCurrentClassCallSitePattern)) } test("array_append -> Unit Test cases for the function ") { @@ -5772,7 +5868,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "dataType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"INT\"", - "sqlExpr" -> "\"array_append(a, b)\"") + "sqlExpr" -> "\"array_append(a, b)\""), + context = + ExpectedContext(code = "array_append", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer(df1.selectExpr("array_append(a, 3)"), Seq(Row(Seq(3, 2, 5, 1, 2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index eafd454439ca..c8c242a99d71 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -310,7 +310,8 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { .agg(sum($"sales.earnings")) }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "min(training)") + parameters = Map("sqlExpr" -> "min(training)"), + context = ExpectedContext(code = "min", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index b8300390eddf..59dfbd94a6e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -484,7 +484,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { checkError(ex, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`df1`.`timeStr`", - "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`")) + "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`"), + context = ExpectedContext(code = "$", getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 747f43fa2a74..c641e7f05ab3 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -373,7 +373,8 @@ class DataFrameSetOperationsSuite extends QueryTest errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", - "dataType" -> "\"MAP\"") + "dataType" -> "\"MAP\""), + context = ExpectedContext(code = "distinct", callSitePattern = getCurrentClassCallSitePattern) ) withTempView("v") { df.createOrReplaceTempView("v") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 47ff942e5ca1..7ae290f7c7b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -146,7 +146,9 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" - ) + ), + context = + ExpectedContext(code = "freqItems", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -156,7 +158,9 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" - ) + ), + context = + ExpectedContext(code = "approxQuantile", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c72bc9167759..10118011593b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -354,7 +354,8 @@ class DataFrameSuite extends QueryTest "paramIndex" -> "1", "inputSql"-> "\"csv\"", "inputType" -> "\"STRING\"", - "requiredType" -> "(\"ARRAY\" or \"MAP\")") + "requiredType" -> "(\"ARRAY\" or \"MAP\")"), + context = ExpectedContext(code = "explode", getCurrentClassCallSitePattern) ) val df2 = Seq(Array("1", "2"), Array("4"), Array("7", "8", "9")).toDF("csv") @@ -2947,7 +2948,8 @@ class DataFrameSuite extends QueryTest df.groupBy($"d", $"b").as[GroupByKey, Row] }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`")) + parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`"), + context = ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 2a81f7e7c2f3..d3764e7de5f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -184,7 +184,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(code = "over", callSitePattern = 
getCurrentClassCallSitePattern)) ) checkError( @@ -198,7 +200,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN -1 FOLLOWING AND UNBOUNDED FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(code = "over", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( @@ -212,7 +216,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN -1 FOLLOWING AND 1 FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(code = "over", callSitePattern = getCurrentClassCallSitePattern)) ) } @@ -240,7 +246,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING\"" - ) + ), + context = ExpectedContext(code = "over", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -255,7 +262,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN -1 FOLLOWING AND UNBOUNDED FOLLOWING\"" - ) + ), + context = ExpectedContext(code = "over", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -270,7 +278,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN -1 FOLLOWING AND 1 FOLLOWING\"" - ) + ), + context = ExpectedContext(code = "over", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -462,7 +471,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "upper" -> "\"2\"", "lowerType" -> "\"INTERVAL\"", "upperType" -> "\"BIGINT\"" - ) + ), + context = ExpectedContext(code = "over", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -481,7 +491,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"RANGE BETWEEN nonfoldableliteral() FOLLOWING AND 2 FOLLOWING\"", "location" -> "lower", - "expression" -> "\"nonfoldableliteral()\"") + "expression" -> "\"nonfoldableliteral()\""), + context = ExpectedContext(code = "over", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index df3f3eaf7efe..10ec8a7da4da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -412,7 +412,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`invalid`", - "proposal" -> "`value`, `key`")) + "proposal" -> "`value`, `key`"), + context = ExpectedContext(code = "count", callSitePattern = getCurrentClassCallSitePattern)) } test("numerical aggregate functions on string column") { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 324695349787..d026a0d41e00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -605,7 +605,8 @@ class DatasetSuite extends QueryTest
         }
       },
       errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX",
-      parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"))
+      parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"),
+      context = ExpectedContext(code = "$", getCurrentClassCallSitePattern))
   }
 
   test("groupBy function, flatMapSorted") {
@@ -633,7 +634,8 @@ class DatasetSuite extends QueryTest
         }
       },
       errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX",
-      parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"))
+      parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"),
+      context = ExpectedContext(code = "$", getCurrentClassCallSitePattern))
   }
 
   test("groupBy, flatMapSorted desc") {
@@ -2261,7 +2263,8 @@ class DatasetSuite extends QueryTest
           sqlState = None,
           parameters = Map(
             "objectName" -> s"`${colName.replace(".", "`.`")}`",
-            "proposal" -> "`field.1`, `field 2`"))
+            "proposal" -> "`field.1`, `field 2`"),
+          context = ExpectedContext(code = "select", getCurrentClassCallSitePattern))
       }
     }
   }
@@ -2275,7 +2278,8 @@ class DatasetSuite extends QueryTest
       sqlState = None,
       parameters = Map(
         "objectName" -> "`the`.`id`",
-        "proposal" -> "`the.id`"))
+        "proposal" -> "`the.id`"),
+      context = ExpectedContext(code = "select", getCurrentClassCallSitePattern))
   }
 
   test("SPARK-39783: backticks in error message for map candidate key with dots") {
@@ -2289,7 +2293,8 @@ class DatasetSuite extends QueryTest
       sqlState = None,
       parameters = Map(
         "objectName" -> "`nonexisting`",
-        "proposal" -> "`map`, `other.column`"))
+        "proposal" -> "`map`, `other.column`"),
+      context = ExpectedContext(code = "$", getCurrentClassCallSitePattern))
   }
 
   test("groupBy.as") {
@@ -2597,6 +2602,23 @@ class DatasetSuite extends QueryTest
         parameters = Map("cls" -> classOf[Array[Int]].getName))
     }
   }
+
+  test("SPARK-45022: exact DatasetQueryContext call site") {
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+      val df = Seq(1).toDS
+      var callSitePattern: String = null
+      checkError(
+        exception = intercept[AnalysisException] {
+          callSitePattern = getNextLineCallSitePattern()
+          val c = col("a")
+          df.select(c)
+        },
+        errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+        sqlState = "42703",
+        parameters = Map("objectName" -> "`a`", "proposal" -> "`value`"),
+        context = ExpectedContext(code = "col", callSitePattern = callSitePattern))
+    }
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala
index 4117ea63bdd8..d794dd3d0d3d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala
@@ -373,7 +373,8 @@ class DatasetUnpivotSuite extends QueryTest
       errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION",
       parameters = Map(
         "objectName" -> "`1`",
-        "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`"))
+        "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`"),
+      context = ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern))
 
     // unpivoting where value column does not exist
     val e2 = intercept[AnalysisException] {
@@ -389,7 +390,8 @@ class 
DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`does`", - "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`")) + "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`"), + context = ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern)) // unpivoting without values where potential value columns are of incompatible types val e3 = intercept[AnalysisException] { @@ -506,7 +508,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`an`.`id`", - "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`")) + "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`"), + context = ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("unpivot with struct fields") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 68f63feb5c51..baaf967ed5fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -293,7 +293,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"array()\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"ARRAY\"") + "requiredType" -> "\"ARRAY\""), + context = ExpectedContext(code = "inline", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -331,7 +332,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(b))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext(code = "array", callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct('a), struct('b.alias("a"))))), @@ -346,7 +348,8 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(2))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext(code = "array", callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct('a), struct(lit(2).alias("a"))))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index a76e102fe913..a2fd28bdc6c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -194,7 +194,9 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"json_tuple(a, 1)\"", "funcName" -> "`json_tuple`" - ) + ), + context = + ExpectedContext(code = "json_tuple", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -647,7 +649,9 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.INVALID_JSON_MAP_KEY_TYPE", parameters = Map( "schema" -> "\"MAP, STRING>\"", - "sqlExpr" -> "\"entries\"")) + "sqlExpr" -> "\"entries\""), + context = + ExpectedContext(code = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-24709: infers schemas of json strings and pass 
them to from_json") {
@@ -957,7 +961,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {
           .select(from_json($"json", $"schema", options)).collect()
       },
       errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL",
-      parameters = Map("inputSchema" -> "\"schema\"")
+      parameters = Map("inputSchema" -> "\"schema\""),
+      context = ExpectedContext(code = "from_json", getCurrentClassCallSitePattern)
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
index 2a24f0cc3996..f84b188546f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
@@ -600,6 +600,9 @@ class ParametersSuite extends QueryTest with SharedSparkSession {
             array(str_to_map(Column(Literal("a:1,b:2,c:3")))))))
       },
       errorClass = "INVALID_SQL_ARG",
-      parameters = Map("name" -> "m"))
+      parameters = Map("name" -> "m"),
+      context =
+        ExpectedContext(code = "map_from_arrays", callSitePattern = getCurrentClassCallSitePattern)
+    )
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index c2c333a998b4..348514b6ac92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql
 
 import java.util.TimeZone
+import java.util.regex.Pattern
 
 import scala.collection.JavaConverters._
 
@@ -229,6 +230,17 @@ abstract class QueryTest extends PlanTest {
     assert(query.queryExecution.executedPlan.missingInput.isEmpty,
       s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}")
   }
+
+  protected def getCurrentClassCallSitePattern: String = {
+    val cs = Thread.currentThread().getStackTrace()(2)
+    s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)"
+  }
+
+  protected def getNextLineCallSitePattern(lines: Int = 1): String = {
+    val cs = Thread.currentThread().getStackTrace()(2)
+    Pattern.quote(
+      s"${cs.getClassName}.${cs.getMethodName}(${cs.getFileName}:${cs.getLineNumber + lines})")
+  }
 }
 
 object QueryTest extends Assertions {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 720b7953a505..59d454799b34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1920,7 +1920,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
         dfNoCols.select($"b.*")
       },
       errorClass = "CANNOT_RESOLVE_STAR_EXPAND",
-      parameters = Map("targetString" -> "`b`", "columns" -> ""))
+      parameters = Map("targetString" -> "`b`", "columns" -> ""),
+      context =
+        ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 8e9be5dcdced..5223099c024b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -879,7 +879,9 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
       parameters = Map(
         "funcName" -> s"`$funcName`",
         "paramName" -> "`format`",
-        "paramType" -> "\"STRING\""))
+        "paramType" -> "\"STRING\""),
+      context =
ExpectedContext(code = funcName, callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { df2.select(func(col("input"), lit("invalid_format"))).collect() @@ -888,7 +890,9 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "parameter" -> "`format`", "functionName" -> s"`$funcName`", - "invalidFormat" -> "'invalid_format'")) + "invalidFormat" -> "'invalid_format'"), + context = + ExpectedContext(code = funcName, callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { sql(s"select $funcName('a', 'b', 'c')") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 3f665d637748..f28aa81e3dbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -696,7 +696,9 @@ class QueryCompilationErrorsSuite Seq("""{"a":1}""").toDF("a").select(from_json($"a", IntegerType)).collect() }, errorClass = "DATATYPE_MISMATCH.INVALID_JSON_SCHEMA", - parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\"")) + parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\""), + context = + ExpectedContext(code = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("WRONG_NUM_ARGS.WITHOUT_SUGGESTION: wrong args of CAST(parameter types contains DataType)") { @@ -767,7 +769,8 @@ class QueryCompilationErrorsSuite }, errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", - parameters = Map("field" -> "`firstname`", "count" -> "2") + parameters = Map("field" -> "`firstname`", "count" -> "2"), + context = ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -780,7 +783,9 @@ class QueryCompilationErrorsSuite }, errorClass = "INVALID_EXTRACT_BASE_FIELD_TYPE", sqlState = "42000", - parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\"")) + parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\""), + context = ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern) + ) } test("INVALID_EXTRACT_FIELD_TYPE: extract not string literal field") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index ee28a90aed9a..68b9ec49ced8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.errors import org.apache.spark._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime} +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.ByteType @@ -53,6 +55,24 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest sqlState = "22012", parameters = Map("config" -> ansiConf), context = ExpectedContext(fragment = "6/0", start = 7, stop = 9)) + + checkError( + exception = 
intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5) / lit(0)).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext(code = "div", callSitePattern = getCurrentClassCallSitePattern)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5).divide(lit(0))).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext(code = "divide", callSitePattern = getCurrentClassCallSitePattern)) } test("INTERVAL_DIVIDED_BY_ZERO: interval divided by zero") { @@ -92,6 +112,19 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('66666666666666.666' AS DECIMAL(8, 1))", start = 7, stop = 49)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit("66666666666666.666").cast("DECIMAL(8, 1)")).collect() + }, + errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", + sqlState = "22003", + parameters = Map( + "value" -> "66666666666666.666", + "precision" -> "8", + "scale" -> "1", + "config" -> ansiConf), + context = ExpectedContext(code = "cast", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX: get element from array") { @@ -102,6 +135,14 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest errorClass = "INVALID_ARRAY_INDEX", parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext(fragment = "array(1, 2, 3, 4, 5)[8]", start = 7, stop = 29)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(lit(Array(1, 2, 3, 4, 5))(8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = ExpectedContext(code = "apply", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX_IN_ELEMENT_AT: element_at from array") { @@ -115,6 +156,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "element_at(array(1, 2, 3, 4, 5), 8)", start = 7, stop = 41)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = + ExpectedContext(code = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_INDEX_OF_ZERO: element_at from array by index zero") { @@ -129,6 +179,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest start = 7, stop = 41) ) + + checkError( + exception = intercept[SparkRuntimeException]( + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 0)).collect() + ), + errorClass = "INVALID_INDEX_OF_ZERO", + parameters = Map.empty, + context = + ExpectedContext(code = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("CAST_INVALID_INPUT: cast string to double") { @@ -146,6 +205,18 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('111111111111xe23' AS DOUBLE)", start = 7, stop = 40)) + + checkError( + exception = intercept[SparkNumberFormatException] { + OneRowRelation().select(lit("111111111111xe23").cast("DOUBLE")).collect() + }, + errorClass = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> "'111111111111xe23'", + "sourceType" -> 
"\"STRING\"", + "targetType" -> "\"DOUBLE\"", + "ansiConfig" -> ansiConf), + context = ExpectedContext(code = "cast", callSitePattern = getCurrentClassCallSitePattern)) } test("CANNOT_PARSE_TIMESTAMP: parse string to timestamp") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 75e5d4d452e1..a7cab381c7f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -1206,7 +1206,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v5", fragment = "1/0", @@ -1225,7 +1225,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v1", fragment = "1/0", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index c782104f4f9b..33968a42bc20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -244,7 +244,9 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { df.select("name", METADATA_FILE_NAME).collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), + context = + ExpectedContext(code = "select", callSitePattern = getCurrentClassCallSitePattern)) } metadataColumnsTest("SPARK-42683: df metadataColumn - schema conflict", @@ -522,14 +524,18 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { df.select("name", "_metadata.file_name").collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), + context = + ExpectedContext(code = "select", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { df.select("name", "_METADATA.file_NAME").collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`"), + context = + ExpectedContext(code = "select", callSitePattern = getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 3bd45ca0dcdb..be47fc1f66d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -2687,7 +2687,9 @@ abstract class CSVSuite 
readback.filter($"AAA" === 2 && $"bbb" === 3).collect() }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`")) + parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), + context = + ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 11779286ec25..b6d25f6a8188 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3061,7 +3061,9 @@ abstract class JsonSuite readback.filter($"AAA" === 0 && $"bbb" === 1).collect() }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`")) + parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), + context = + ExpectedContext(code = "$", callSitePattern = getCurrentClassCallSitePattern)) // Schema inferring val readback2 = spark.read.json(path.getCanonicalPath) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala index c10e1799702d..6f26e5eaa4cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala @@ -133,7 +133,8 @@ class ParquetFileMetadataStructRowIndexSuite extends QueryTest with SharedSparkS parameters = Map( "fieldName" -> "`row_index`", "fields" -> ("`file_path`, `file_name`, `file_size`, " + - "`file_block_start`, `file_block_length`, `file_modification_time`"))) + "`file_block_start`, `file_block_length`, `file_modification_time`")), + context = ExpectedContext(code = "select", getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index cf1f4d4d4f28..ac44357501df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -1917,7 +1917,9 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { Seq("xyz").toDF.select("value", "default").write.insertInto("t") }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`default`", "proposal" -> "`value`")) + parameters = Map("objectName" -> "`default`", "proposal" -> "`value`"), + context = + ExpectedContext(code = "select", callSitePattern = getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index c97979a57a55..c2e6846b0edc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -711,7 +711,9 @@ class StreamSuite extends StreamTest { "columnName" -> "`rn_col`", "windowSpec" -> ("(PARTITION BY COL1 ORDER BY COL2 ASC 
NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING " + - "AND CURRENT ROW)"))) + "AND CURRENT ROW)")), + queryContext = Array( + ExpectedContext(code = "withColumn", callSitePattern = getCurrentClassCallSitePattern))) }
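
For reference, the two call-site helpers added to QueryTest.scala in this patch build two different kinds of patterns: a class-wide regex (any method of the calling class, any line of its file) and an exact, regex-quoted frame for one specific line. The standalone sketch below mirrors only those pattern shapes; the class, method, and line number are illustrative and not taken from the patch.

import java.util.regex.Pattern

object CallSitePatternSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative call site (not from the patch).
    val className = "org.apache.spark.sql.DataFrameFunctionsSuite"
    val fileName = "DataFrameFunctionsSuite.scala"

    // Shape built by getCurrentClassCallSitePattern: any method of the calling
    // class, at any line of that class's source file.
    val classPattern = s"$className\\..*\\($fileName:\\d+\\)"
    // Shape built by getNextLineCallSitePattern(): a single exact frame, quoted literally.
    val exactPattern = Pattern.quote(s"$className.testFun($fileName:123)")

    val frame = s"$className.testFun($fileName:123)"
    println(frame.matches(classPattern)) // true
    println(frame.matches(exactPattern)) // true
  }
}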
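The SQL-style contexts in these tests (fragment/start/stop) use 0-based, inclusive offsets into the query text, so stop is always the fragment length minus one; the parameterized aggregate/reduce case spells that as agg.length + 20. A small standalone check of that arithmetic:

object ContextOffsetSketch {
  def main(args: Array[String]): Unit = {
    // Single-function fragments are spelled out directly in the tests.
    assert("filter(s, x -> x)".length - 1 == 16)
    assert("exists(s, x -> x)".length - 1 == 16)
    // "(s, 0, (acc, x) -> x)" is 21 characters, hence stop = agg.length + 20.
    for (agg <- Seq("aggregate", "reduce")) {
      val fragment = s"$agg(s, 0, (acc, x) -> x)"
      assert(fragment.length - 1 == agg.length + 20)
    }
    println("offsets check out")
  }
}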
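Several expectations above use code = "$" (DatasetSuite, DatasetUnpivotSuite, SQLQuerySuite, CSVSuite, JsonSuite): the failing column is built with the $"..." interpolator from spark.implicits, so "$" is the API name the Dataset context is expected to record. A rough, hedged illustration of that usage pattern; the data and column names below are made up, and the printed message is only expected to carry a Dataset context once this patch is applied.

import org.apache.spark.sql.{AnalysisException, SparkSession}

object DollarCallSiteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._
    val df = Seq((1, "a")).toDF("id", "name")
    try {
      // $"missing" goes through the "$" string-interpolator method supplied by
      // spark.implicits, so that is the API call whose call site should be reported.
      df.select($"missing").collect()
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }
    spark.stop()
  }
}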