diff --git a/common/utils/src/main/java/org/apache/spark/QueryContext.java b/common/utils/src/main/java/org/apache/spark/QueryContext.java
index de5b29d02951..45e38c8cfe0f 100644
--- a/common/utils/src/main/java/org/apache/spark/QueryContext.java
+++ b/common/utils/src/main/java/org/apache/spark/QueryContext.java
@@ -27,6 +27,9 @@
  */
 @Evolving
 public interface QueryContext {
+  // The type of this query context.
+  QueryContextType contextType();
+
   // The object type of the query which throws the exception.
   // If the exception is directly from the main query, it should be an empty string.
   // Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
@@ -45,4 +48,10 @@ public interface QueryContext {

   // The corresponding fragment of the query which throws the exception.
   String fragment();
+
+  // The user code (call site of the API) that caused throwing the exception.
+  String callSite();
+
+  // Summary of the exception cause.
+  String summary();
 }
diff --git a/common/utils/src/main/java/org/apache/spark/QueryContextType.java b/common/utils/src/main/java/org/apache/spark/QueryContextType.java
new file mode 100644
index 000000000000..171833162baf
--- /dev/null
+++ b/common/utils/src/main/java/org/apache/spark/QueryContextType.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * The type of {@link QueryContext}.
+ *
+ * @since 4.0.0
+ */
+@Evolving
+public enum QueryContextType {
+  SQL,
+  DataFrame
+}
diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
index b312a1a7e227..a44d36ff85b5 100644
--- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
+++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala
@@ -114,13 +114,19 @@
             g.writeArrayFieldStart("queryContext")
             e.getQueryContext.foreach { c =>
               g.writeStartObject()
-              g.writeStringField("objectType", c.objectType())
-              g.writeStringField("objectName", c.objectName())
-              val startIndex = c.startIndex() + 1
-              if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
-              val stopIndex = c.stopIndex() + 1
-              if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
-              g.writeStringField("fragment", c.fragment())
+              c.contextType() match {
+                case QueryContextType.SQL =>
+                  g.writeStringField("objectType", c.objectType())
+                  g.writeStringField("objectName", c.objectName())
+                  val startIndex = c.startIndex() + 1
+                  if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
+                  val stopIndex = c.stopIndex() + 1
+                  if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
+                  g.writeStringField("fragment", c.fragment())
+                case QueryContextType.DataFrame =>
+                  g.writeStringField("fragment", c.fragment())
+                  g.writeStringField("callSite", c.callSite())
+              }
               g.writeEndObject()
             }
             g.writeEndArray()
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 5b94c6d663cc..27f51551ba92 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -823,6 +823,13 @@ message FetchErrorDetailsResponse {
   // QueryContext defines the schema for the query context of a SparkThrowable.
   // It helps users understand where the error occurs while executing queries.
   message QueryContext {
+    // The type of this query context.
+    enum ContextType {
+      SQL = 0;
+      DATAFRAME = 1;
+    }
+    ContextType context_type = 10;
+
     // The object type of the query which throws the exception.
     // If the exception is directly from the main query, it should be an empty string.
     // Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
@@ -841,6 +848,12 @@ message FetchErrorDetailsResponse {

     // The corresponding fragment of the query which throws the exception.
     string fragment = 5;
+
+    // The user code (call site of the API) that caused throwing the exception.
+    string callSite = 6;
+
+    // Summary of the exception cause.
+    string summary = 7;
   }

   // SparkThrowable defines the schema for SparkThrowable exceptions.
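Editor's note: the `QueryContext` interface above now exposes `contextType()`, `callSite()`, and `summary()` alongside the existing SQL-oriented members, and `SparkThrowableHelper` serializes a context differently depending on its type. A minimal sketch of what an implementation of the extended interface could look like for a DataFrame-style context is shown below; the class name, constructor shape, and placeholder values are illustrative assumptions, not code from this patch:

```scala
import org.apache.spark.{QueryContext, QueryContextType}

// Illustrative sketch only, not the PR's implementation: a DataFrame-flavored
// QueryContext that reports the API fragment and the user call site instead of
// SQL text offsets. Scala vals satisfy the parameterless Java interface methods,
// the same way the TestQueryContext classes in SparkThrowableSuite do.
class ExampleDataFrameContext(
    override val fragment: String,   // e.g. "div"
    override val callSite: String,   // e.g. "SimpleApp$.main(SimpleApp.scala:9)"
    override val summary: String = "") extends QueryContext {
  override val contextType: QueryContextType = QueryContextType.DataFrame
  // The SQL-only members carry no useful information for a DataFrame context.
  override val objectType: String = ""
  override val objectName: String = ""
  override val startIndex: Int = -1
  override val stopIndex: Int = -1
}
```

For a context like this, the patched `SparkThrowableHelper.getMessage` above emits only the `fragment` and `callSite` fields in the MINIMAL and STANDARD JSON formats.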
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index b2782442f4a5..3e53722caeb0 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -28,7 +28,7 @@ import io.grpc.protobuf.StatusProto import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods -import org.apache.spark.{QueryContext, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{QueryContext, QueryContextType, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext} import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub import org.apache.spark.internal.Logging @@ -324,15 +324,18 @@ private[client] object GrpcExceptionConverter { val queryContext = error.getSparkThrowable.getQueryContextsList.asScala.map { queryCtx => new QueryContext { + override def contextType(): QueryContextType = queryCtx.getContextType match { + case FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME => + QueryContextType.DataFrame + case _ => QueryContextType.SQL + } override def objectType(): String = queryCtx.getObjectType - override def objectName(): String = queryCtx.getObjectName - override def startIndex(): Int = queryCtx.getStartIndex - override def stopIndex(): Int = queryCtx.getStopIndex - override def fragment(): String = queryCtx.getFragment + override def callSite(): String = queryCtx.getCallSite + override def summary(): String = queryCtx.getSummary } }.toArray diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index e3792eb0d237..518c0592488f 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -342,7 +342,7 @@ abstract class SparkFunSuite sqlState: Option[String] = None, parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, - queryContext: Array[QueryContext] = Array.empty): Unit = { + queryContext: Array[ExpectedContext] = Array.empty): Unit = { assert(exception.getErrorClass === errorClass) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala @@ -364,16 +364,25 @@ abstract class SparkFunSuite val actualQueryContext = exception.getQueryContext() assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context") actualQueryContext.zip(queryContext).foreach { case (actual, expected) => - assert(actual.objectType() === expected.objectType(), - "Invalid objectType of a query context Actual:" + actual.toString) - assert(actual.objectName() === expected.objectName(), - "Invalid objectName of a query context. 
Actual:" + actual.toString) - assert(actual.startIndex() === expected.startIndex(), - "Invalid startIndex of a query context. Actual:" + actual.toString) - assert(actual.stopIndex() === expected.stopIndex(), - "Invalid stopIndex of a query context. Actual:" + actual.toString) - assert(actual.fragment() === expected.fragment(), - "Invalid fragment of a query context. Actual:" + actual.toString) + assert(actual.contextType() === expected.contextType, + "Invalid contextType of a query context Actual:" + actual.toString) + if (actual.contextType() == QueryContextType.SQL) { + assert(actual.objectType() === expected.objectType, + "Invalid objectType of a query context Actual:" + actual.toString) + assert(actual.objectName() === expected.objectName, + "Invalid objectName of a query context. Actual:" + actual.toString) + assert(actual.startIndex() === expected.startIndex, + "Invalid startIndex of a query context. Actual:" + actual.toString) + assert(actual.stopIndex() === expected.stopIndex, + "Invalid stopIndex of a query context. Actual:" + actual.toString) + assert(actual.fragment() === expected.fragment, + "Invalid fragment of a query context. Actual:" + actual.toString) + } else if (actual.contextType() == QueryContextType.DataFrame) { + assert(actual.fragment() === expected.fragment, + "Invalid code fragment of a query context. Actual:" + actual.toString) + assert(actual.callSite().matches(expected.callSitePattern), + "Invalid callSite of a query context. Actual:" + actual.toString) + } } } @@ -389,21 +398,21 @@ abstract class SparkFunSuite errorClass: String, sqlState: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, sqlState: String, - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, Map.empty, false, Array(context)) protected def checkError( @@ -411,7 +420,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, false, Array(context)) @@ -426,7 +435,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, matchPVals = true, Array(context)) @@ -453,16 +462,33 @@ abstract class SparkFunSuite parameters = Map("relationName" -> tableName)) case class ExpectedContext( + contextType: QueryContextType, objectType: String, objectName: String, startIndex: Int, stopIndex: Int, - fragment: String) extends QueryContext + fragment: String, + callSitePattern: String + ) object ExpectedContext { def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { ExpectedContext("", "", start, stop, fragment) } + + def apply( + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String): ExpectedContext = { + new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, 
stopIndex, + fragment, "") + } + + def apply(fragment: String, callSitePattern: String): ExpectedContext = { + new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern) + } } class LogAppender(msg: String = "", maxEvents: Int = 1000) diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 9f32d81f1ae3..0206205c353a 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -503,11 +503,14 @@ class SparkThrowableSuite extends SparkFunSuite { test("Get message in the specified format") { import ErrorMessageFormat._ class TestQueryContext extends QueryContext { + override val contextType = QueryContextType.SQL override val objectName = "v1" override val objectType = "VIEW" override val startIndex = 2 override val stopIndex = -1 override val fragment = "1 / 0" + override def callSite: String = throw new UnsupportedOperationException + override val summary = "" } val e = new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", @@ -577,6 +580,54 @@ class SparkThrowableSuite extends SparkFunSuite { | "message" : "Test message" | } |}""".stripMargin) + + class TestQueryContext2 extends QueryContext { + override val contextType = QueryContextType.DataFrame + override def objectName: String = throw new UnsupportedOperationException + override def objectType: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + override val fragment: String = "div" + override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)" + override val summary = "" + } + val e4 = new SparkArithmeticException( + errorClass = "DIVIDE_BY_ZERO", + messageParameters = Map("config" -> "CONFIG"), + context = Array(new TestQueryContext2), + summary = "Query summary") + + assert(SparkThrowableHelper.getMessage(e4, PRETTY) === + "[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " + + "and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." + + " SQLSTATE: 22012\nQuery summary") + // scalastyle:off line.size.limit + assert(SparkThrowableHelper.getMessage(e4, MINIMAL) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "fragment" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + assert(SparkThrowableHelper.getMessage(e4, STANDARD) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. 
If necessary set to \"false\" to bypass this error.", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "fragment" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + // scalastyle:on line.size.limit } test("overwrite error classes") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 10864390e3fc..c0275e162722 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,14 @@ object MimaExcludes { // [SPARK-45427][CORE] Add RPC SSL settings to SSLOptions and SparkTransportConf ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"), // [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$"), + // [SPARK-45022][SQL] Provide context for dataset API errors + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.contextType"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.code"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.callSite"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.summary"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI") ) // Default exclude rules diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 05040d813501..0ea02525f78f 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t 
\x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 
\x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n 
\x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 
\x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xfb\t\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xac\x01\n\x0cQueryContext\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 
\x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xd1\x06\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b 
\x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t 
\x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e 
\x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 
\x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xbe\x0b\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 
\x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xd1\x06\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -202,17 +202,19 @@ _FETCHERRORDETAILSREQUEST._serialized_start = 11378 _FETCHERRORDETAILSREQUEST._serialized_end = 11579 _FETCHERRORDETAILSRESPONSE._serialized_start = 11582 - _FETCHERRORDETAILSRESPONSE._serialized_end = 12857 + _FETCHERRORDETAILSRESPONSE._serialized_end = 13052 _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727 _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901 _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12076 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12079 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12488 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12390 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12458 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12491 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 12838 - _SPARKCONNECTSERVICE._serialized_start = 12860 - _SPARKCONNECTSERVICE._serialized_end = 13709 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12271 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12234 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12271 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12274 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12683 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12585 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12653 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12686 + 
_FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13033 + _SPARKCONNECTSERVICE._serialized_start = 13055 + _SPARKCONNECTSERVICE._serialized_end = 13904 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 5d2ebeb57399..c29feb4164cf 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -2885,11 +2885,35 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class _ContextType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _ContextTypeEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ + FetchErrorDetailsResponse.QueryContext._ContextType.ValueType + ], + builtins.type, + ): # noqa: F821 + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + SQL: FetchErrorDetailsResponse.QueryContext._ContextType.ValueType # 0 + DATAFRAME: FetchErrorDetailsResponse.QueryContext._ContextType.ValueType # 1 + + class ContextType(_ContextType, metaclass=_ContextTypeEnumTypeWrapper): + """The type of this query context.""" + + SQL: FetchErrorDetailsResponse.QueryContext.ContextType.ValueType # 0 + DATAFRAME: FetchErrorDetailsResponse.QueryContext.ContextType.ValueType # 1 + + CONTEXT_TYPE_FIELD_NUMBER: builtins.int OBJECT_TYPE_FIELD_NUMBER: builtins.int OBJECT_NAME_FIELD_NUMBER: builtins.int START_INDEX_FIELD_NUMBER: builtins.int STOP_INDEX_FIELD_NUMBER: builtins.int FRAGMENT_FIELD_NUMBER: builtins.int + CALLSITE_FIELD_NUMBER: builtins.int + SUMMARY_FIELD_NUMBER: builtins.int + context_type: global___FetchErrorDetailsResponse.QueryContext.ContextType.ValueType object_type: builtins.str """The object type of the query which throws the exception. If the exception is directly from the main query, it should be an empty string. @@ -2906,18 +2930,29 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): """The stopping index in the query which throws the exception. The index starts from 0.""" fragment: builtins.str """The corresponding fragment of the query which throws the exception.""" + callSite: builtins.str + """The user code (call site of the API) that caused throwing the exception.""" + summary: builtins.str + """Summary of the exception cause.""" def __init__( self, *, + context_type: global___FetchErrorDetailsResponse.QueryContext.ContextType.ValueType = ..., object_type: builtins.str = ..., object_name: builtins.str = ..., start_index: builtins.int = ..., stop_index: builtins.int = ..., fragment: builtins.str = ..., + callSite: builtins.str = ..., + summary: builtins.str = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ + "callSite", + b"callSite", + "context_type", + b"context_type", "fragment", b"fragment", "object_name", @@ -2928,6 +2963,8 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): b"start_index", "stop_index", b"stop_index", + "summary", + b"summary", ], ) -> None: ... 
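A minimal sketch (not part of this patch) of how a Spark Connect client could populate the extended QueryContext message once the stubs above are regenerated. All literal values below — the fragment, the call site, and the index range — are hypothetical placeholders chosen for illustration; the summary string follows the "== DataFrame ==" format produced by DataFrameQueryContext further down in this patch.

from pyspark.sql.connect.proto import base_pb2

QueryContext = base_pb2.FetchErrorDetailsResponse.QueryContext

# DataFrame-flavoured context: only fragment, callSite and summary carry data.
df_ctx = QueryContext(
    context_type=QueryContext.ContextType.DATAFRAME,
    fragment="divide",
    callSite="App$.main(App.scala:21)",
    summary='== DataFrame ==\n"divide" was called from App$.main(App.scala:21)\n',
)

# SQL-flavoured context: the positional fields are populated instead of callSite.
sql_ctx = QueryContext(
    context_type=QueryContext.ContextType.SQL,
    object_type="VIEW",
    object_name="v1",
    start_index=0,
    stop_index=14,
    fragment="1 / 0",
)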
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index 51d2b4beab22..22e6c67090b4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ -26,7 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin} import org.apache.spark.sql.catalyst.util.SparkParserUtils import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf @@ -229,7 +229,7 @@ class ParseException( val builder = new StringBuilder builder ++= "\n" ++= message start match { - case Origin(Some(l), Some(p), _, _, _, _, _) => + case Origin(Some(l), Some(p), _, _, _, _, _, _) => builder ++= s" (line $l, pos $p)\n" command.foreach { cmd => val (above, below) = cmd.split("\n").splitAt(l) @@ -262,8 +262,7 @@ class ParseException( object ParseException { def getQueryContext(): Array[QueryContext] = { - val context = CurrentOrigin.get.context - if (context.isValid) Array(context) else Array.empty + Some(CurrentOrigin.get.context).collect { case b: SQLQueryContext if b.isValid => b }.toArray } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala similarity index 78% rename from sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 5b29cb3dde74..b8288b24535e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.trees -import org.apache.spark.QueryContext +import org.apache.spark.{QueryContext, QueryContextType} /** The class represents error context of a SQL query. */ case class SQLQueryContext( @@ -28,6 +28,7 @@ case class SQLQueryContext( sqlText: Option[String], originObjectType: Option[String], originObjectName: Option[String]) extends QueryContext { + override val contextType = QueryContextType.SQL override val objectType = originObjectType.getOrElse("") override val objectName = originObjectName.getOrElse("") @@ -40,7 +41,7 @@ case class SQLQueryContext( * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i * ^^^^^^^^^^^^^^^ */ - lazy val summary: String = { + override lazy val summary: String = { // If the query context is missing or incorrect, simply return an empty string. if (!isValid) { "" @@ -116,7 +117,7 @@ case class SQLQueryContext( } /** Gets the textual fragment of a SQL query. 
*/ - override lazy val fragment: String = { + lazy val fragment: String = { if (!isValid) { "" } else { @@ -128,6 +129,45 @@ case class SQLQueryContext( sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined && originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length && originStartIndex.get <= originStopIndex.get + } + + override def callSite: String = throw new UnsupportedOperationException +} + +case class DataFrameQueryContext( + override val fragment: String, + override val callSite: String) extends QueryContext { + override val contextType = QueryContextType.DataFrame + + override def objectType: String = throw new UnsupportedOperationException + override def objectName: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + + override lazy val summary: String = { + val builder = new StringBuilder + builder ++= "== DataFrame ==\n" + builder ++= "\"" + + builder ++= fragment + builder ++= "\"" + builder ++= " was called from " + builder ++= callSite + builder += '\n' + builder.result() + } +} + +object DataFrameQueryContext { + def apply(elements: Array[StackTraceElement]): DataFrameQueryContext = { + val methodName = elements(0).getMethodName + val code = if (methodName.length > 1 && methodName(0) == '$') { + methodName.substring(1) + } else { + methodName + } + val callSite = elements(1).toString + DataFrameQueryContext(code, callSite) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index ec3e627ac958..dd24dae16ba8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -30,15 +30,21 @@ case class Origin( stopIndex: Option[Int] = None, sqlText: Option[String] = None, objectType: Option[String] = None, - objectName: Option[String] = None) { + objectName: Option[String] = None, + stackTrace: Option[Array[StackTraceElement]] = None) { - lazy val context: SQLQueryContext = SQLQueryContext( - line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) - - def getQueryContext: Array[QueryContext] = if (context.isValid) { - Array(context) + lazy val context: QueryContext = if (stackTrace.isDefined) { + DataFrameQueryContext(stackTrace.get) } else { - Array.empty + SQLQueryContext( + line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) + } + + def getQueryContext: Array[QueryContext] = { + Some(context).filter { + case s: SQLQueryContext => s.isValid + case _ => true + }.toArray } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 7c1b37e9e581..99caef978bb4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.ExecutionErrors /** @@ -27,37 +27,37 @@ object MathUtils { def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b)) - def addExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def addExact(a: Int, b: Int, 
context: QueryContext): Int = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b)) - def addExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def addExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def subtractExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def subtractExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def multiplyExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def multiplyExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } @@ -78,7 +78,7 @@ object MathUtils { def withOverflow[A]( f: => A, hint: String = "", - context: SQLQueryContext = null): A = { + context: QueryContext = null): A = { try { f } catch { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala index 698e7b37a9ef..f8a9274a5646 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import sun.util.calendar.ZoneInfo -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, rebaseJulianToGregorianDays, rebaseJulianToGregorianMicros} import org.apache.spark.sql.errors.ExecutionErrors @@ -355,7 +355,7 @@ trait SparkDateTimeUtils { def stringToDateAnsi( s: UTF8String, - context: SQLQueryContext = null): Int = { + context: QueryContext = null): Int = { stringToDate(s).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, DateType, context) } @@ -567,7 +567,7 @@ trait SparkDateTimeUtils { def stringToTimestampAnsi( s: UTF8String, timeZoneId: ZoneId, - context: SQLQueryContext = null): Long = { + context: QueryContext = null): Long = { stringToTimestamp(s, timeZoneId).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampType, context) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index 5e52e283338d..b30f7b7a00e9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.sql.errors -import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLSchema import org.apache.spark.sql.types.{DataType, Decimal, StringType} @@ -191,7 +191,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -199,7 +199,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -207,7 +207,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -222,7 +222,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { val convertedValueStr = "'" + s.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala index 911d900053cf..d1d9dd806b3b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.errors import java.util.Locale import org.apache.spark.QueryContext -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils} import org.apache.spark.sql.types.{AbstractDataType, DataType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -89,11 +88,11 @@ private[sql] trait DataTypeErrorsBase { "\"" + elem + "\"" } - def getSummary(sqlContext: SQLQueryContext): String = { + def getSummary(sqlContext: QueryContext): String = { if (sqlContext == null) "" else sqlContext.summary } - def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = { - if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext]) + def getQueryContext(context: QueryContext): Array[QueryContext] = { + if (context == 
null) Array.empty else Array(context) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index c8321e81027b..394e56062071 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -21,9 +21,8 @@ import java.time.temporal.ChronoField import org.apache.arrow.vector.types.pojo.ArrowType -import org.apache.spark.{SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.sql.catalyst.WalkedTypePath -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{DataType, DoubleType, StringType, UserDefinedType} import org.apache.spark.unsafe.types.UTF8String @@ -83,14 +82,14 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def invalidInputInCastToDatetimeError( value: UTF8String, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), StringType, to, context) } def invalidInputInCastToDatetimeError( value: Double, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), DoubleType, to, context) } @@ -98,7 +97,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { sqlValue: String, from: DataType, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { new SparkDateTimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -113,7 +112,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def arithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." 
} else "" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index afe73635a682..c1661038025c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -21,8 +21,8 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin import scala.util.Try +import org.apache.spark.QueryContext import org.apache.spark.annotation.Unstable -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.unsafe.types.UTF8String @@ -341,7 +341,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { scale: Int, roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, nullOnOverflow: Boolean = true, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { val copy = clone() if (copy.changePrecision(precision, scale, roundMode)) { copy @@ -617,7 +617,7 @@ object Decimal { def fromStringANSI( str: UTF8String, to: DecimalType = DecimalType.USER_DEFAULT, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { try { val bigDecimal = stringToJavaBigDecimal(str) // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 99117d81b34a..62295fe26053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,13 +21,13 @@ import java.time.{ZoneId, ZoneOffset} import java.util.Locale import java.util.concurrent.TimeUnit._ -import org.apache.spark.SparkArithmeticException +import org.apache.spark.{QueryContext, SparkArithmeticException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.{PhysicalFractionalType, PhysicalIntegralType, PhysicalNumericType} import org.apache.spark.sql.catalyst.util._ @@ -527,7 +527,7 @@ case class Cast( } } - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -945,7 +945,7 @@ case class Cast( private[this] def toPrecision( value: Decimal, decimalType: DecimalType, - context: SQLQueryContext): Decimal = + context: QueryContext): Decimal = value.toPrecision( decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index bd7369e57b05..3870e6d39a34 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.SparkException +import org.apache.spark.{QueryContext, SparkException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.truncatedString @@ -613,11 +613,11 @@ abstract class UnaryExpression extends Expression with UnaryLike[Expression] { * to executors. It will also be kept after rule transforms. */ trait SupportQueryContext extends Expression with Serializable { - protected var queryContext: Option[SQLQueryContext] = initQueryContext() + protected var queryContext: Option[QueryContext] = initQueryContext() - def initQueryContext(): Option[SQLQueryContext] + def initQueryContext(): Option[QueryContext] - def getContextOrNull(): SQLQueryContext = queryContext.orNull + def getContextOrNull(): QueryContext = queryContext.orNull def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = { if (withErrorContext && queryContext.isDefined) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index fd6131f18560..fe30e2ea6f3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -134,7 +135,7 @@ case class Average( override protected def withNewChildInternal(newChild: Expression): Average = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index e3881520e490..dfd41ad12a28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{EvalMode, _} -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -186,7 +187,7 @@ case class Sum( // The flag `evalMode` won't be shown in the `toString` or `toAggString` methods override def flatArguments: Iterator[Any] = Iterator(child) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a556ac9f1294..e3c5184c5acc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions import scala.math.{max, min} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern, UNARY_POSITIVE} import org.apache.spark.sql.catalyst.types.{PhysicalDecimalType, PhysicalFractionalType, PhysicalIntegerType, PhysicalIntegralType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.{IntervalMathUtils, IntervalUtils, MathUtils, TypeUtils} @@ -266,7 +266,7 @@ abstract class BinaryArithmetic extends BinaryOperator final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e45f0d72a5c2..7e2cf08a069b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -22,13 +22,14 @@ import java.util.Comparator import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType} import org.apache.spark.sql.catalyst.util._ @@ -2526,7 +2527,7 @@ case class ElementAt( override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError && left.resolved && left.dataType.isInstanceOf[ArrayType]) { Some(origin.context) } else { @@ -5046,7 +5047,7 @@ case class ArrayInsert( newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert = copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr) - override def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + override def initQueryContext(): Option[QueryContext] = Some(origin.context) } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3885a5b9f5b3..a801d0367080 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -316,7 +316,7 @@ case class GetArrayItem( newLeft: Expression, newRight: Expression): GetArrayItem = copy(child = newLeft, ordinal = newRight) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 
378920856eb1..5f13d397d1bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.types.PhysicalDecimalType import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors @@ -146,7 +146,7 @@ case class CheckOverflow( override protected def withNewChildInternal(newChild: Expression): CheckOverflow = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (!nullOnOverflow) { + override def initQueryContext(): Option[QueryContext] = if (!nullOnOverflow) { Some(origin.context) } else { None @@ -158,7 +158,7 @@ case class CheckOverflowInSum( child: Expression, dataType: DecimalType, nullOnOverflow: Boolean, - context: SQLQueryContext) extends UnaryExpression with SupportQueryContext { + context: QueryContext) extends UnaryExpression with SupportQueryContext { override def nullable: Boolean = true @@ -210,7 +210,7 @@ case class CheckOverflowInSum( override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) } /** @@ -256,12 +256,12 @@ case class DecimalDivideWithOverflowCheck( left: Expression, right: Expression, override val dataType: DecimalType, - context: SQLQueryContext, + context: QueryContext, nullOnOverflow: Boolean) extends BinaryExpression with ExpectsInputTypes with SupportQueryContext { override def nullable: Boolean = nullOnOverflow override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) def decimalMethod: String = "$div" override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 3750a9271cff..aa1f6159def8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ @@ -200,9 +200,11 @@ trait 
HigherOrderFunction extends Expression with ExpectsInputTypes { */ final def bind( f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction = { - val res = bindInternal(f) - res.copyTagsFrom(this) - res + CurrentOrigin.withOrigin(origin) { + val res = bindInternal(f) + res.copyTagsFrom(this) + res + } } protected def bindInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 5378639e6838..13676733a9ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -22,8 +22,8 @@ import java.util.Locale import com.google.common.math.{DoubleMath, IntMath, LongMath} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ @@ -604,7 +604,7 @@ trait IntervalDivide { minValue: Any, num: Expression, numValue: Any, - context: SQLQueryContext): Unit = { + context: QueryContext): Unit = { if (value == minValue && num.dataType.isInstanceOf[IntegralType]) { if (numValue.asInstanceOf[Number].longValue() == -1) { throw QueryExecutionErrors.intervalArithmeticOverflowError( @@ -616,7 +616,7 @@ trait IntervalDivide { def divideByZeroCheck( dataType: DataType, num: Any, - context: SQLQueryContext): Unit = dataType match { + context: QueryContext): Unit = dataType match { case _: DecimalType => if (num.asInstanceOf[Decimal].isZero) { throw QueryExecutionErrors.intervalDividedByZeroError(context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c320d98d9fd1..0c09e9be12e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} import java.util.Locale +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -480,7 +480,7 @@ case class Conv( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird) - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def 
initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -1523,7 +1523,7 @@ abstract class RoundBase(child: Expression, scale: Expression, private lazy val scaleV: Any = scale.eval(EmptyRow) protected lazy val _scale: Int = scaleV.asInstanceOf[Int] - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (ansiEnabled) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 563223bb33ba..5a123a178e8f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import scala.collection.mutable.ArrayBuffer +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -411,7 +412,7 @@ case class Elt( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt = copy(children = newChildren) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 23bbc91c16d5..8fabb4487620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit._ import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{Decimal, DoubleExactNumeric, TimestampNTZType, TimestampType} @@ -70,7 +70,7 @@ object DateTimeUtils extends SparkDateTimeUtils { // the "GMT" string. For example, it returns 2000-01-01T00:00+01:00 for 2000-01-01T00:00GMT+01:00. 
def cleanLegacyTimestampStr(s: UTF8String): UTF8String = s.replace(gmtUtf8, UTF8String.EMPTY_UTF8) - def doubleToTimestampAnsi(d: Double, context: SQLQueryContext): Long = { + def doubleToTimestampAnsi(d: Double, context: QueryContext): Long = { if (d.isNaN || d.isInfinite) { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(d, TimestampType, context) } else { @@ -91,7 +91,7 @@ object DateTimeUtils extends SparkDateTimeUtils { def stringToTimestampWithoutTimeZoneAnsi( s: UTF8String, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { stringToTimestampWithoutTimeZone(s, true).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampNTZType, context) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 59765cde1f92..2730ab8f4b89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.unsafe.types.UTF8String @@ -54,7 +54,7 @@ object NumberConverter { fromPos: Int, value: Array[Byte], ansiEnabled: Boolean, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { var v: Long = 0L // bound will always be positive since radix >= 2 // Note that: -1 is equivalent to 11111111...1111 which is the largest unsigned long value @@ -134,7 +134,7 @@ object NumberConverter { fromBase: Int, toBase: Int, ansiEnabled: Boolean, - context: SQLQueryContext): UTF8String = { + context: QueryContext): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index f7800469c352..1c3a5075dab2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, ShortType} import org.apache.spark.unsafe.types.UTF8String @@ -27,21 +27,21 @@ import org.apache.spark.unsafe.types.UTF8String */ object UTF8StringUtils { - def toLongExact(s: UTF8String, context: SQLQueryContext): Long = + def toLongExact(s: UTF8String, context: QueryContext): Long = withException(s.toLongExact, context, LongType, s) - def toIntExact(s: UTF8String, context: SQLQueryContext): Int = + def toIntExact(s: UTF8String, context: QueryContext): Int = withException(s.toIntExact, context, IntegerType, s) - def toShortExact(s: UTF8String, context: SQLQueryContext): Short = + def toShortExact(s: UTF8String, context: QueryContext): Short = withException(s.toShortExact, context, ShortType, s) - def toByteExact(s: UTF8String, context: SQLQueryContext): Byte = + def toByteExact(s: 
UTF8String, context: QueryContext): Byte = withException(s.toByteExact, context, ByteType, s) private def withException[A]( f: => A, - context: SQLQueryContext, + context: QueryContext, to: DataType, s: UTF8String): A = { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index afc244509c41..30dfe8eebe6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode} +import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode, MapData} import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -104,7 +104,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -118,7 +118,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputSyntaxForBooleanError( s: UTF8String, - context: SQLQueryContext): SparkRuntimeException = { + context: QueryContext): SparkRuntimeException = { new SparkRuntimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -133,7 +133,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -194,15 +194,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = e) } - def divideByZeroError(context: SQLQueryContext): ArithmeticException = { + def divideByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", messageParameters = Map("config" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = getQueryContext(context), + context = Array(context), summary = getSummary(context)) } - def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = { + def intervalDividedByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "INTERVAL_DIVIDED_BY_ZERO", messageParameters = Map.empty, @@ -213,7 +213,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidArrayIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX", 
messageParameters = Map( @@ -227,7 +227,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidElementAtIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", messageParameters = Map( @@ -292,15 +292,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE ansiIllegalArgumentError(e.getMessage) } - def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = { + def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in sum of decimals", context = context) } - def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = { + def overflowInIntegralDivideError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in integral divide", "try_divide", context) } - def overflowInConvError(context: SQLQueryContext): ArithmeticException = { + def overflowInConvError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in function conv()", context = context) } @@ -625,7 +625,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def intervalArithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." } else "" @@ -1391,7 +1391,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "functionName" -> toSQLId(prettyName))) } - def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = { + def invalidIndexOfZeroError(context: QueryContext): RuntimeException = { new SparkRuntimeException( errorClass = "INVALID_INDEX_OF_ZERO", cause = null, @@ -2556,7 +2556,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def multipleRowScalarSubqueryError(context: SQLQueryContext): Throwable = { + def multipleRowScalarSubqueryError(context: QueryContext): Throwable = { new SparkException( errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", messageParameters = Map.empty, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 997308c6ef44..ba4e7b279f51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -35,8 +35,6 @@ import org.apache.spark.sql.types.StructType trait AnalysisTest extends PlanTest { - import org.apache.spark.QueryContext - protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil protected def createTempView( @@ -177,7 +175,7 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val analyzer = getAnalyzer diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index d91a080d8fe8..3fd0c1ee5de4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal} @@ -159,7 +158,7 @@ abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) { super.assertAnalysisErrorClass( @@ -196,7 +195,7 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) { super.assertAnalysisErrorClass( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 9bb35a8b0b3d..0ca55ef67fd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -70,8 +70,10 @@ private[sql] object Column { name: String, isDistinct: Boolean, ignoreNulls: Boolean, - inputs: Column*): Column = Column { - UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + inputs: Column*): Column = withOrigin { + Column { + UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + } } } @@ -148,12 +150,14 @@ class TypedColumn[-T, U]( @Stable class Column(val expr: Expression) extends Logging { - def this(name: String) = this(name match { - case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => - val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) - UnresolvedStar(Some(parts)) - case _ => UnresolvedAttribute.quotedString(name) + def this(name: String) = this(withOrigin { + name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => + val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) + UnresolvedStar(Some(parts)) + case _ => UnresolvedAttribute.quotedString(name) + } }) private def fn(name: String): Column = { @@ -180,7 +184,9 @@ class Column(val expr: Expression) extends Logging { } /** Creates a column based on the given expression. 
*/ - private def withExpr(newExpr: Expression): Column = new Column(newExpr) + private def withExpr(newExpr: => Expression): Column = withOrigin { + new Column(newExpr) + } /** * Returns the expression for this column either with an existing or auto assigned name. @@ -1370,7 +1376,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + def over(window: expressions.WindowSpec): Column = withOrigin { + window.withAggregate(this) + } /** * Defines an empty analytic clause. In this case the analytic function is applied diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index f3690773f6dd..a8c4d4f8d2ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -72,7 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def approxQuantile( col: String, probabilities: Array[Double], - relativeError: Double): Array[Double] = { + relativeError: Double): Array[Double] = withOrigin { approxQuantile(Array(col), probabilities, relativeError).head } @@ -97,7 +97,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def approxQuantile( cols: Array[String], probabilities: Array[Double], - relativeError: Double): Array[Array[Double]] = { + relativeError: Double): Array[Array[Double]] = withOrigin { StatFunctions.multipleApproxQuantiles( df.select(cols.map(col): _*), cols, @@ -132,7 +132,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def cov(col1: String, col2: String): Double = { + def cov(col1: String, col2: String): Double = withOrigin { StatFunctions.calculateCov(df, Seq(col1, col2)) } @@ -154,7 +154,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def corr(col1: String, col2: String, method: String): Double = { + def corr(col1: String, col2: String, method: String): Double = withOrigin { require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + "coefficient is supported.") StatFunctions.pearsonCorrelation(df, Seq(col1, col2)) @@ -210,7 +210,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def crosstab(col1: String, col2: String): DataFrame = { + def crosstab(col1: String, col2: String): DataFrame = withOrigin { StatFunctions.crossTabulate(df, col1, col2) } @@ -257,7 +257,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Array[String], support: Double): DataFrame = { + def freqItems(cols: Array[String], support: Double): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, support) } @@ -276,7 +276,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Array[String]): DataFrame = { + def freqItems(cols: Array[String]): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, 0.01) } @@ -320,7 +320,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Seq[String], support: Double): DataFrame = { + def freqItems(cols: Seq[String], support: Double): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, support) } @@ -339,7 +339,7 @@ final class 
DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Seq[String]): DataFrame = { + def freqItems(cols: Seq[String]): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, 0.01) } @@ -415,7 +415,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = withOrigin { require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in [0, 1], but got $fractions.") import org.apache.spark.sql.functions.{rand, udf} @@ -497,7 +497,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return a `CountMinSketch` over column `colName` * @since 2.0.0 */ - def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + def countMinSketch( + col: Column, + eps: Double, + confidence: Double, + seed: Int): CountMinSketch = withOrigin { val countMinSketchAgg = new CountMinSketchAgg( col.expr, Literal(eps, DoubleType), @@ -555,7 +559,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param numBits expected number of bits of the filter. * @since 2.0.0 */ - def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin { val bloomFilterAgg = new BloomFilterAggregate( col.expr, Literal(expectedNumItems, LongType), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4f07133bb761..ba5eb790cea9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -508,9 +508,11 @@ class Dataset[T] private[sql]( * @group basic * @since 3.4.0 */ - def to(schema: StructType): DataFrame = withPlan { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf) + def to(schema: StructType): DataFrame = withOrigin { + withPlan { + val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] + Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf) + } } /** @@ -770,12 +772,14 @@ class Dataset[T] private[sql]( */ // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. 
- def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { - val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) - require(!IntervalUtils.isNegative(parsedDelay), - s"delay threshold ($delayThreshold) should not be negative.") - EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withOrigin { + withTypedPlan { + val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) + require(!IntervalUtils.isNegative(parsedDelay), + s"delay threshold ($delayThreshold) should not be negative.") + EliminateEventTimeWatermark( + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + } } /** @@ -947,8 +951,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + def join(right: Dataset[_]): DataFrame = withOrigin { + withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + } } /** @@ -1081,22 +1087,23 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { - // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right - // by creating a new instance for one of the branch. - val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) - .analyzed.asInstanceOf[Join] - - withPlan { - Join( - joined.left, - joined.right, - UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), - None, - JoinHint.NONE) + def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = + withOrigin { + // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right + // by creating a new instance for one of the branch. + val joined = sparkSession.sessionState.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) + .analyzed.asInstanceOf[Join] + + withPlan { + Join( + joined.left, + joined.right, + UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), + None, + JoinHint.NONE) + } } - } /** * Inner join with another `DataFrame`, using the given join expression. @@ -1177,7 +1184,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { + def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = withOrigin { withPlan { resolveSelfJoinCondition(right, Some(joinExprs), joinType) } @@ -1193,8 +1200,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.1.0 */ - def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) + def crossJoin(right: Dataset[_]): DataFrame = withOrigin { + withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) + } } /** @@ -1218,27 +1227,28 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, - // etc. 
- val joined = sparkSession.sessionState.executePlan( - Join( - this.logicalPlan, - other.logicalPlan, - JoinType(joinType), - Some(condition.expr), - JoinHint.NONE)).analyzed.asInstanceOf[Join] - - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) - - withTypedPlan(JoinWith.typedJoinWith( - joined, - sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity, - sparkSession.sessionState.analyzer.resolver, - this.exprEnc.isSerializedAsStructForTopLevel, - other.exprEnc.isSerializedAsStructForTopLevel)) - } + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = + withOrigin { + // Creates a Join node and resolve it first, to get join condition resolved, self-join + // resolved, etc. + val joined = sparkSession.sessionState.executePlan( + Join( + this.logicalPlan, + other.logicalPlan, + JoinType(joinType), + Some(condition.expr), + JoinHint.NONE)).analyzed.asInstanceOf[Join] + + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) + + withTypedPlan(JoinWith.typedJoinWith( + joined, + sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity, + sparkSession.sessionState.analyzer.resolver, + this.exprEnc.isSerializedAsStructForTopLevel, + other.exprEnc.isSerializedAsStructForTopLevel)) + } /** * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair @@ -1416,14 +1426,16 @@ class Dataset[T] private[sql]( * @since 2.2.0 */ @scala.annotation.varargs - def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan { - val exprs = parameters.map { - case c: Column => c.expr - case s: Symbol => Column(s.name).expr - case e: Expression => e - case literal => Literal(literal) - }.toSeq - UnresolvedHint(name, exprs, logicalPlan) + def hint(name: String, parameters: Any*): Dataset[T] = withOrigin { + withTypedPlan { + val exprs = parameters.map { + case c: Column => c.expr + case s: Symbol => Column(s.name).expr + case e: Expression => e + case literal => Literal(literal) + }.toSeq + UnresolvedHint(name, exprs, logicalPlan) + } } /** @@ -1499,8 +1511,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + def as(alias: String): Dataset[T] = withOrigin { + withTypedPlan { + SubqueryAlias(alias, logicalPlan) + } } /** @@ -1537,25 +1551,28 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def select(cols: Column*): DataFrame = withPlan { - val untypedCols = cols.map { - case typedCol: TypedColumn[_, _] => - // Checks if a `TypedColumn` has been inserted with - // specific input type and schema by `withInputType`. - val needInputType = typedCol.expr.exists { - case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true - case _ => false - } + def select(cols: Column*): DataFrame = withOrigin { + withPlan { + val untypedCols = cols.map { + case typedCol: TypedColumn[_, _] => + // Checks if a `TypedColumn` has been inserted with + // specific input type and schema by `withInputType`. 
+ val needInputType = typedCol.expr.exists { + case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true + case _ => false + } - if (!needInputType) { - typedCol - } else { - throw QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString) - } + if (!needInputType) { + typedCol + } else { + throw + QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString) + } - case other => other + case other => other + } + Project(untypedCols.map(_.named), logicalPlan) } - Project(untypedCols.map(_.named), logicalPlan) } /** @@ -1572,7 +1589,9 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) + def select(col: String, cols: String*): DataFrame = withOrigin { + select((col +: cols).map(Column(_)) : _*) + } /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -1588,10 +1607,12 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def selectExpr(exprs: String*): DataFrame = sparkSession.withActive { - select(exprs.map { expr => - Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) - }: _*) + def selectExpr(exprs: String*): DataFrame = withOrigin { + sparkSession.withActive { + select(exprs.map { expr => + Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) + }: _*) + } } /** @@ -1605,7 +1626,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { + def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = withOrigin { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) @@ -1689,8 +1710,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + def filter(condition: Column): Dataset[T] = withOrigin { + withTypedPlan { + Filter(condition.expr, logicalPlan) + } } /** @@ -2049,15 +2072,17 @@ class Dataset[T] private[sql]( ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): DataFrame = withPlan { - Unpivot( - Some(ids.map(_.named)), - Some(values.map(v => Seq(v.named))), - None, - variableColumnName, - Seq(valueColumnName), - logicalPlan - ) + valueColumnName: String): DataFrame = withOrigin { + withPlan { + Unpivot( + Some(ids.map(_.named)), + Some(values.map(v => Seq(v.named))), + None, + variableColumnName, + Seq(valueColumnName), + logicalPlan + ) + } } /** @@ -2080,15 +2105,17 @@ class Dataset[T] private[sql]( def unpivot( ids: Array[Column], variableColumnName: String, - valueColumnName: String): DataFrame = withPlan { - Unpivot( - Some(ids.map(_.named)), - None, - None, - variableColumnName, - Seq(valueColumnName), - logicalPlan - ) + valueColumnName: String): DataFrame = withOrigin { + withPlan { + Unpivot( + Some(ids.map(_.named)), + None, + None, + variableColumnName, + Seq(valueColumnName), + logicalPlan + ) + } } /** @@ -2205,8 +2232,10 @@ class Dataset[T] private[sql]( * @since 3.0.0 */ @varargs - def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { - CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) + def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withOrigin { + withTypedPlan { + CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) + 
} } /** @@ -2243,8 +2272,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + def limit(n: Int): Dataset[T] = withOrigin { + withTypedPlan { + Limit(Literal(n), logicalPlan) + } } /** @@ -2253,8 +2284,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 3.4.0 */ - def offset(n: Int): Dataset[T] = withTypedPlan { - Offset(Literal(n), logicalPlan) + def offset(n: Int): Dataset[T] = withOrigin { + withTypedPlan { + Offset(Literal(n), logicalPlan) + } } // This breaks caching, but it's usually ok because it addresses a very specific use case: @@ -2664,20 +2697,20 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = { - val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - - val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) - - val rowFunction = - f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) - val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) - - withPlan { - Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = + withOrigin { + val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) + + val rowFunction = + f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) + val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) + + withPlan { + Generate(generator, unrequiredChildIndex = Nil, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } } - } /** * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero @@ -2702,7 +2735,7 @@ class Dataset[T] private[sql]( */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => IterableOnce[B]) - : DataFrame = { + : DataFrame = withOrigin { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil // TODO handle the metadata? 
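The Dataset methods above (and below) are only wrapped in `withOrigin`; the helper itself is added to sql/core/src/main/scala/org/apache/spark/sql/package.scala later in this patch. As a minimal, self-contained sketch of the idea, not Spark code (`CallSiteSketch`, `divide` and `callSite` are invented names for illustration), the wrapper takes its body by name, captures the first stack frame outside the library exactly once, and lets nested calls fall through so that the recorded origin stays closest to the user code:

import scala.util.DynamicVariable

object CallSiteSketch {
  // Holds the captured user-code frame for the duration of one API call.
  private val currentOrigin = new DynamicVariable[Option[StackTraceElement]](None)

  // `body` is by-name, so the origin is installed before anything inside it is constructed.
  // Only the outermost call captures a frame; nested calls reuse it.
  def withOrigin[T](body: => T): T =
    if (currentOrigin.value.isDefined) body
    else {
      val st = Thread.currentThread().getStackTrace
      // Skip Thread.getStackTrace and this library's own frames; keep the first user frame.
      val userFrame = st.find { e =>
        !e.getClassName.startsWith("java.lang.Thread") &&
          !e.getClassName.contains("CallSiteSketch")
      }
      currentOrigin.withValue(userFrame)(body)
    }

  def callSite(): String = currentOrigin.value.map(_.toString).getOrElse("<unknown>")

  // A toy "API" entry point: errors raised inside it can report where the user called it from.
  def divide(a: Int, b: Int): Int = withOrigin {
    if (b == 0) throw new ArithmeticException(s"divide by zero, called from ${callSite()}")
    a / b
  }
}

Calling CallSiteSketch.divide(1, 0) from user code raises an error whose message points at the caller's file and line, which is roughly what the callSitePattern assertions in the updated test suites below check against the real implementation.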
@@ -2859,7 +2892,7 @@ class Dataset[T] private[sql]( * @since 3.4.0 */ @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = { + def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin { val resolver = sparkSession.sessionState.analyzer.resolver val output: Seq[NamedExpression] = queryExecution.analyzed.output @@ -3073,9 +3106,11 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = groupColsFromDropDuplicates(colNames) - Deduplicate(groupCols, logicalPlan) + def dropDuplicates(colNames: Seq[String]): Dataset[T] = withOrigin { + withTypedPlan { + val groupCols = groupColsFromDropDuplicates(colNames) + Deduplicate(groupCols, logicalPlan) + } } /** @@ -3151,10 +3186,12 @@ class Dataset[T] private[sql]( * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = groupColsFromDropDuplicates(colNames) - // UnsupportedOperationChecker will fail the query if this is called with batch Dataset. - DeduplicateWithinWatermark(groupCols, logicalPlan) + def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withOrigin { + withTypedPlan { + val groupCols = groupColsFromDropDuplicates(colNames) + // UnsupportedOperationChecker will fail the query if this is called with batch Dataset. + DeduplicateWithinWatermark(groupCols, logicalPlan) + } } /** @@ -3378,7 +3415,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(func: T => Boolean): Dataset[T] = { + def filter(func: T => Boolean): Dataset[T] = withOrigin { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -3389,7 +3426,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(func: FilterFunction[T]): Dataset[T] = { + def filter(func: FilterFunction[T]): Dataset[T] = withOrigin { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -3400,8 +3437,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + def map[U : Encoder](func: T => U): Dataset[U] = withOrigin { + withTypedPlan { + MapElements[T, U](func, logicalPlan) + } } /** @@ -3411,7 +3450,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = withOrigin { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) } @@ -3574,8 +3613,9 @@ class Dataset[T] private[sql]( * @group action * @since 3.0.0 */ - def tail(n: Int): Array[T] = withAction( - "tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan) + def tail(n: Int): Array[T] = withOrigin { + withAction("tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan) + } /** * Returns the first `n` rows in the Dataset as a list. 
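Together with the error constructors earlier in this patch now accepting a plain QueryContext, the user-visible effect of these wrappers is that runtime failures raised from DataFrame code carry a DataFrame-typed query context: fragment() names the API that built the failing expression and callSite() points into the caller's code, which is what the ExpectedContext(fragment = ..., callSitePattern = ...) assertions in the test suites below verify. A hypothetical usage sketch, assuming an active SparkSession named `spark` with ANSI mode enabled (the failing cast is only an example of an expression that errors at execution time):

import org.apache.spark.{QueryContextType, SparkThrowable}
import org.apache.spark.sql.functions.lit

// Assumes `spark: SparkSession` is already in scope and spark.sql.ansi.enabled is true.
try {
  spark.range(1).select(lit("abc").cast("int")).collect()  // fails with CAST_INVALID_INPUT
} catch {
  case e: SparkThrowable =>
    e.getQueryContext.foreach { c =>
      c.contextType() match {
        case QueryContextType.DataFrame =>
          // fragment() is the Dataset/Column API name (e.g. "cast");
          // callSite() is the user line that invoked it.
          println(s"failed in ${c.fragment()} at ${c.callSite()}")
        case QueryContextType.SQL =>
          println(c.summary())
      }
    }
}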
@@ -3639,8 +3679,10 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def count(): Long = withAction("count", groupBy().count().queryExecution) { plan => - plan.executeCollect().head.getLong(0) + def count(): Long = withOrigin { + withAction("count", groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) + } } /** @@ -3649,13 +3691,15 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + def repartition(numPartitions: Int): Dataset[T] = withOrigin { + withTypedPlan { + Repartition(numPartitions, shuffle = true, logicalPlan) + } } private def repartitionByExpression( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = { + partitionExprs: Seq[Column]): Dataset[T] = withOrigin { // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. // However, we don't want to complicate the semantics of this API method. // Instead, let's give users a friendly error message, pointing them to the new method. @@ -3700,7 +3744,7 @@ class Dataset[T] private[sql]( private def repartitionByRange( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = { + partitionExprs: Seq[Column]): Dataset[T] = withOrigin { require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { case expr: SortOrder => expr @@ -3772,8 +3816,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + def coalesce(numPartitions: Int): Dataset[T] = withOrigin { + withTypedPlan { + Repartition(numPartitions, shuffle = false, logicalPlan) + } } /** @@ -3917,8 +3963,10 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @throws[AnalysisException] - def createTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = false, global = false) + def createTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = false, global = false) + } } @@ -3930,8 +3978,10 @@ class Dataset[T] private[sql]( * @group basic * @since 2.0.0 */ - def createOrReplaceTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = true, global = false) + def createOrReplaceTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = true, global = false) + } } /** @@ -3949,8 +3999,10 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @throws[AnalysisException] - def createGlobalTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = false, global = true) + def createGlobalTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = false, global = true) + } } /** @@ -4358,7 +4410,7 @@ class Dataset[T] private[sql]( plan.executeCollect().map(fromRow) } - private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = withOrigin { val sortOrder: Seq[SortOrder] = sortExprs.map { col => col.expr match { case expr: SortOrder => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 58f720154df5..771c743f7062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution +import org.apache.spark.QueryContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExprId, InSet, ListQuery, Literal, PlanExpression, Predicate, SupportQueryContext} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.{LeafLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -68,7 +69,7 @@ case class ScalarSubquery( override def nullable: Boolean = true override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields) override def withNewPlan(query: BaseSubqueryExec): ScalarSubquery = copy(plan = query) - def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + def initQueryContext(): Option[QueryContext] = Some(origin.context) override lazy val canonicalized: Expression = { ScalarSubquery(plan.canonicalized.asInstanceOf[BaseSubqueryExec], ExprId(0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b5e40fe35cfe..a42df5bbcc29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -93,11 +93,13 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private def withExpr(expr: Expression): Column = Column(expr) + private def withExpr(expr: => Expression): Column = withOrigin { + Column(expr) + } private def withAggregateFunction( - func: AggregateFunction, - isDistinct: Boolean = false): Column = { + func: => AggregateFunction, + isDistinct: Boolean = false): Column = withOrigin { Column(func.toAggregateExpression(isDistinct)) } @@ -127,16 +129,18 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => - // This is different from `typedlit`. `typedlit` calls `Literal.create` to use - // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this method, - // `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, we can - // just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. This is - // significantly better when there are many threads calling `lit` concurrently. + def lit(literal: Any): Column = withOrigin { + literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => + // This is different from `typedlit`. `typedlit` calls `Literal.create` to use + // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this + // method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. 
Hence, + // we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. + // This is significantly better when there are many threads calling `lit` concurrently. Column(Literal(literal)) + } } /** @@ -147,7 +151,9 @@ object functions { * @group normal_funcs * @since 2.2.0 */ - def typedLit[T : TypeTag](literal: T): Column = typedlit(literal) + def typedLit[T : TypeTag](literal: T): Column = withOrigin { + typedlit(literal) + } /** * Creates a [[Column]] of literal value. @@ -164,10 +170,12 @@ object functions { * @group normal_funcs * @since 3.2.0 */ - def typedlit[T : TypeTag](literal: T): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => Column(Literal.create(literal)) + def typedlit[T : TypeTag](literal: T): Column = withOrigin { + literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => Column(Literal.create(literal)) + } } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -5965,25 +5973,31 @@ object functions { def array_except(col1: Column, col2: Column): Column = Column.fn("array_except", col1, col2) - private def createLambda(f: Column => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val function = f(Column(x)).expr - LambdaFunction(function, Seq(x)) + private def createLambda(f: Column => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val function = f(Column(x)).expr + LambdaFunction(function, Seq(x)) + } } - private def createLambda(f: (Column, Column) => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val function = f(Column(x), Column(y)).expr - LambdaFunction(function, Seq(x, y)) + private def createLambda(f: (Column, Column) => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val function = f(Column(x), Column(y)).expr + LambdaFunction(function, Seq(x, y)) + } } - private def createLambda(f: (Column, Column, Column) => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) - val function = f(Column(x), Column(y), Column(z)).expr - LambdaFunction(function, Seq(x, y, z)) + private def createLambda(f: (Column, Column, Column) => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) + val function = f(Column(x), Column(y), Column(z)).expr + LambdaFunction(function, Seq(x, y, z)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 1794ac513749..7f00f6d6317c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,7 +17,10 @@ package org.apache.spark + import java.util.regex.Pattern + import org.apache.spark.annotation.{DeveloperApi, Unstable} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} import org.apache.spark.sql.execution.SparkStrategy /** @@ -73,4 +76,43 @@ package object sql { * with rebasing. */ private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = "org.apache.spark.legacyINT96" + + /** + * This helper function captures the Spark API and its call site in the user code from the current + * stacktrace. + * + * As adding `withOrigin` explicitly to all Spark API definitions would be a huge change, + * `withOrigin` is used only at certain places where all API implementations surely pass through + * and the current stacktrace is filtered to the point where the first Spark API code is invoked from + * the user code. + * + * As there might be multiple nested `withOrigin` calls (e.g. any Spark API implementation can + * invoke other APIs), only the first `withOrigin` is captured because that is closer to the user + * code. + * + * @param f The function that can use the origin. + * @return The result of `f`. + */ + private[sql] def withOrigin[T](f: => T): T = { + if (CurrentOrigin.get.stackTrace.isDefined) { + f + } else { + val st = Thread.currentThread().getStackTrace + var i = 3 + // Advance past Spark's own API frames so that st(i) is the first user-code frame. + while (i < st.length && sparkCode(st(i))) i += 1 + // Keep the last Spark API frame together with the user frame that invoked it. + val origin = + Origin(stackTrace = Some(st.slice(i - 1, i + 1))) + CurrentOrigin.withOrigin(origin)(f) + } + } + + private val sparkCodePattern = Pattern.compile("org\\.apache\\.spark\\.sql\\." + + "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions)" + + "(?:|\\..*|\\$.*)") + + private def sparkCode(ste: StackTraceElement): Boolean = { + sparkCodePattern.matcher(ste.getClassName).matches() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 8c9ad2180faa..140daced3223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -458,7 +458,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", - "sqlExpr" -> "\"(a IN (b))\"") + "sqlExpr" -> "\"(a IN (b))\""), + context = ExpectedContext(fragment = "isin", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -525,7 +526,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", - "sqlExpr" -> "\"(a IN (b))\"") + "sqlExpr" -> "\"(a IN (b))\""), + context = ExpectedContext( + fragment = "isInCollection", + callSitePattern = getCurrentClassCallSitePattern) ) } } @@ -1056,7 +1060,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"key\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "withField", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1101,7 +1108,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"a.b\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext(
+ fragment = "withField", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1849,7 +1859,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"key\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1886,7 +1899,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"a.b\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1952,7 +1968,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.withColumn("a", $"a".dropFields("a", "b", "c")) }, errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", - parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\"") + parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -2224,7 +2243,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"struct_col".dropFields("a", "b")) }, errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", - parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\"") + parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( @@ -2398,7 +2420,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 1).as("a")) }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`")) + parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`"), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( structLevel1 @@ -2451,7 +2476,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"a".withField("z", $"a.c")).as("a") }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`")) + parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`"), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) } test("nestedDf should generate nested DataFrames") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index c40ecb88257f..e7c1f0414b61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -52,7 +52,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "inputSchema" -> "\"ARRAY\"", "dataType" -> "\"ARRAY\"" - ) + ), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) checkError( @@ -395,7 +396,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { .select(from_csv($"csv", $"schema", options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - 
parameters = Map("inputSchema" -> "\"schema\"") + parameters = Map("inputSchema" -> "\"schema\""), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) checkError( @@ -403,7 +405,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { Seq("1").toDF("csv").select(from_csv($"csv", lit(1), options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"1\"") + parameters = Map("inputSchema" -> "\"1\""), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f34d7cf36807..c8eea985c106 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -633,7 +633,10 @@ class DataFrameAggregateSuite extends QueryTest "functionName" -> "`collect_set`", "dataType" -> "\"MAP\"", "sqlExpr" -> "\"collect_set(b)\"" - ) + ), + context = ExpectedContext( + fragment = "collect_set", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -706,7 +709,8 @@ class DataFrameAggregateSuite extends QueryTest testData.groupBy(sum($"key")).count() }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "sum(key)") + parameters = Map("sqlExpr" -> "sum(key)"), + context = ExpectedContext(fragment = "sum", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1302,7 +1306,8 @@ class DataFrameAggregateSuite extends QueryTest "paramIndex" -> "2", "inputSql" -> "\"a\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"INTEGRAL\"")) + "requiredType" -> "\"INTEGRAL\""), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 789196583c60..135ce834bfe5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -171,7 +171,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"k\"", "inputType" -> "\"INT\"" - ) + ), + queryContext = Array( + ExpectedContext( + fragment = "map_from_arrays", + callSitePattern = getCurrentClassCallSitePattern)) ) val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") @@ -758,7 +762,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("The given function only supports array input") { val df = Seq(1, 2, 3).toDF("a") - checkErrorMatchPVals( + checkError( exception = intercept[AnalysisException] { df.select(array_sort(col("a"), (x, y) => x - y)) }, @@ -769,7 +773,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + matchPVals = true, + queryContext = Array( + ExpectedContext( + fragment = "array_sort", + callSitePattern = getCurrentClassCallSitePattern)) + ) } test("sort_array/array_sort functions") { @@ -1305,7 +1315,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", 
"dataType" -> "(\"MAP, INT>\" or \"MAP\")", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext( + fragment = "map_concat", + callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -1333,7 +1347,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", "dataType" -> "[\"MAP, INT>\", \"INT\"]", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext( + fragment = "map_concat", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1402,7 +1420,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"a\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of pair \"STRUCT\"" - ) + ), + context = + ExpectedContext( + fragment = "map_from_entries", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1439,7 +1461,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array_contains(a, NULL)\"", "functionName" -> "`array_contains`" - ) + ), + context = + ExpectedContext( + fragment = "array_contains", + callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2348,7 +2374,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"ARRAY\"")) + "rightType" -> "\"ARRAY\""), + context = + ExpectedContext( + fragment = "array_union", + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { @@ -2379,7 +2409,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"VOID\"") + "rightType" -> "\"VOID\""), + context = ExpectedContext( + fragment = "array_union", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2410,7 +2442,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY>\"", - "rightType" -> "\"ARRAY\"") + "rightType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext( + fragment = "array_union", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2647,7 +2681,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"arr\"", "inputType" -> "\"ARRAY\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + context = ExpectedContext( + fragment = "flatten", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2660,7 +2696,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2673,7 +2711,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"s\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment 
= "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2782,7 +2822,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"b\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + context = ExpectedContext( + fragment = "array_repeat", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2795,7 +2837,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"1\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_repeat", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3123,7 +3167,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3151,7 +3197,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3179,7 +3227,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"VOID\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3207,7 +3257,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3276,7 +3328,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = ExpectedContext( + fragment = "array_intersect", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3305,7 +3359,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_intersect", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3334,7 +3390,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext( + fragment = "array_intersect", + callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3750,7 +3810,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> 
"\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -3933,7 +3995,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -3945,7 +4009,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + ExpectedContext( + fragment = "filter(s, x -> x)", + start = 0, + stop = 16)) checkError( exception = intercept[AnalysisException] { @@ -3957,7 +4025,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "filter", + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = @@ -4112,7 +4183,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "exists", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4124,7 +4197,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "exists(s, x -> x)", + start = 0, + stop = 16) + ) checkError( exception = intercept[AnalysisException] { @@ -4136,7 +4214,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = + ExpectedContext(fragment = "exists", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException](df.selectExpr("exists(a, x -> x)")), @@ -4304,7 +4384,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "forall", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4316,7 +4398,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = 
"forall(s, x -> x)", + start = 0, + stop = 16)) checkError( exception = intercept[AnalysisException] { @@ -4328,7 +4414,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = + ExpectedContext(fragment = "forall", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException](df.selectExpr("forall(a, x -> x)")), @@ -4343,7 +4431,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.select(forall(col("a"), x => x))), errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`")) + parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), + queryContext = Array( + ExpectedContext(fragment = "col", callSitePattern = getCurrentClassCallSitePattern))) } test("aggregate function - array for primitive type not containing null") { @@ -4581,7 +4671,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "aggregate", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit // scalastyle:off line.size.limit @@ -4597,7 +4689,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - )) + ), + context = ExpectedContext( + fragment = s"$agg(s, 0, (acc, x) -> x)", + start = 0, + stop = agg.length + 20)) } // scalastyle:on line.size.limit @@ -4613,7 +4709,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - )) + ), + context = + ExpectedContext(fragment = "aggregate", callSitePattern = getCurrentClassCallSitePattern)) // scalastyle:on line.size.limit Seq("aggregate", "reduce").foreach { agg => @@ -4719,7 +4817,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "functionName" -> "`map_zip_with`", "leftType" -> "\"INT\"", - "rightType" -> "\"MAP\"")) + "rightType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4749,7 +4849,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(i, mis, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "1", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4779,7 +4881,9 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, i, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "2", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -5235,7 +5339,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"x\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext( + fragment = "transform_values", + callSitePattern = getCurrentClassCallSitePattern))) } testInvalidLambdaFunctions() @@ -5375,7 +5483,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "zip_with", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -5631,7 +5741,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map(m, 1)\"", "keyType" -> "\"MAP\"" - ) + ), + context = + ExpectedContext(fragment = "map", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( df.select(map(map_entries($"m"), lit(1))), @@ -5753,7 +5865,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + context = ExpectedContext( + fragment = "array_compact", + callSitePattern = getCurrentClassCallSitePattern)) } test("array_append -> Unit Test cases for the function ") { @@ -5772,7 +5887,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "dataType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"INT\"", - "sqlExpr" -> "\"array_append(a, b)\"") + "sqlExpr" -> "\"array_append(a, b)\""), + context = + ExpectedContext(fragment = "array_append", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer(df1.selectExpr("array_append(a, 3)"), Seq(Row(Seq(3, 2, 5, 1, 2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 237915fb63fa..b3bf9405a99f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -310,7 +310,8 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { .agg(sum($"sales.earnings")) }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "min(training)") + parameters = Map("sqlExpr" -> "min(training)"), + context = ExpectedContext(fragment = "min", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 88ef5936264d..c777d2207584 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -484,7 +484,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { checkError(ex, 
errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`df1`.`timeStr`", - "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`")) + "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index ab8aab0713a4..e7c1d2c772c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -373,7 +373,10 @@ class DataFrameSetOperationsSuite extends QueryTest errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", - "dataType" -> "\"MAP\"") + "dataType" -> "\"MAP\""), + context = ExpectedContext( + fragment = "distinct", + callSitePattern = getCurrentClassCallSitePattern) ) withTempView("v") { df.createOrReplaceTempView("v") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 430e36221025..20ac2a9e9461 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -146,7 +146,9 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" - ) + ), + context = + ExpectedContext(fragment = "freqItems", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -156,7 +158,9 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" - ) + ), + context = ExpectedContext( + fragment = "approxQuantile", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b7450e564872..b0a0b189cb7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -354,7 +354,8 @@ class DataFrameSuite extends QueryTest "paramIndex" -> "1", "inputSql"-> "\"csv\"", "inputType" -> "\"STRING\"", - "requiredType" -> "(\"ARRAY\" or \"MAP\")") + "requiredType" -> "(\"ARRAY\" or \"MAP\")"), + context = ExpectedContext(fragment = "explode", getCurrentClassCallSitePattern) ) val df2 = Seq(Array("1", "2"), Array("4"), Array("7", "8", "9")).toDF("csv") @@ -2947,7 +2948,8 @@ class DataFrameSuite extends QueryTest df.groupBy($"d", $"b").as[GroupByKey, Row] }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`")) + parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 2a81f7e7c2f3..bb744cfd8ab4 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -184,7 +184,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( @@ -198,7 +200,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN -1 FOLLOWING AND UNBOUNDED FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( @@ -212,7 +216,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN -1 FOLLOWING AND 1 FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) } @@ -240,7 +246,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -255,7 +262,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN -1 FOLLOWING AND UNBOUNDED FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -270,7 +278,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN -1 FOLLOWING AND 1 FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -462,7 +471,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "upper" -> "\"2\"", "lowerType" -> "\"INTERVAL\"", "upperType" -> "\"BIGINT\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -481,7 +491,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"RANGE BETWEEN nonfoldableliteral() FOLLOWING AND 2 FOLLOWING\"", "location" -> "lower", - "expression" -> "\"nonfoldableliteral()\"") + "expression" -> "\"nonfoldableliteral()\""), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 84133eb485f0..6969c4303e01 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -412,7 +412,10 @@ class DataFrameWindowFunctionsSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`invalid`", - "proposal" -> "`value`, `key`")) + "proposal" -> "`value`, `key`"), + context = ExpectedContext( + fragment = "count", + callSitePattern = getCurrentClassCallSitePattern)) } test("numerical aggregate functions on string column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index bf78e6e11fe9..66105d2ac429 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -606,7 +606,8 @@ class DatasetSuite extends QueryTest } }, errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", - parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups")) + parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy function, flatMapSorted") { @@ -634,7 +635,8 @@ class DatasetSuite extends QueryTest } }, errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", - parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups")) + parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy, flatMapSorted desc") { @@ -2290,7 +2292,8 @@ class DatasetSuite extends QueryTest sqlState = None, parameters = Map( "objectName" -> s"`${colName.replace(".", "`.`")}`", - "proposal" -> "`field.1`, `field 2`")) + "proposal" -> "`field.1`, `field 2`"), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } } } @@ -2304,7 +2307,8 @@ class DatasetSuite extends QueryTest sqlState = None, parameters = Map( "objectName" -> "`the`.`id`", - "proposal" -> "`the.id`")) + "proposal" -> "`the.id`"), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } test("SPARK-39783: backticks in error message for map candidate key with dots") { @@ -2318,7 +2322,8 @@ class DatasetSuite extends QueryTest sqlState = None, parameters = Map( "objectName" -> "`nonexisting`", - "proposal" -> "`map`, `other.column`")) + "proposal" -> "`map`, `other.column`"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy.as") { @@ -2659,6 +2664,22 @@ class DatasetSuite extends QueryTest assert(join.count() == 1000000) } + test("SPARK-45022: exact DatasetQueryContext call site") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df = Seq(1).toDS() + var callSitePattern: String = null + checkError( + exception = intercept[AnalysisException] { + callSitePattern = getNextLineCallSitePattern() + val c = col("a") + df.select(c) + }, + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`a`", "proposal" -> "`value`"), + context = ExpectedContext(fragment = "col", callSitePattern = callSitePattern)) + } + } } class DatasetLargeResultCollectingSuite extends QueryTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala index 4117ea63bdd8..49811d8ac61b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala @@ -373,7 +373,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`1`", - "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`")) + "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // unpivoting where value column does not exist val e2 = intercept[AnalysisException] { @@ -389,7 +390,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`does`", - "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`")) + "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // unpivoting without values where potential value columns are of incompatible types val e3 = intercept[AnalysisException] { @@ -506,7 +508,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`an`.`id`", - "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`")) + "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("unpivot with struct fields") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 9b4ad7688186..2ab651237206 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -293,7 +293,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"array()\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"ARRAY\"") + "requiredType" -> "\"ARRAY\""), + context = ExpectedContext( + fragment = "inline", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -331,7 +334,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(b))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext( + fragment = "array", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct(Symbol("a")), struct(Symbol("b").alias("a"))))), @@ -346,7 +352,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(2))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext( + fragment = "array", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct(Symbol("a")), struct(lit(2).alias("a"))))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 5effa2edf585..933f362db663 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -195,7 +195,9 
@@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"json_tuple(a, 1)\"", "funcName" -> "`json_tuple`" - ) + ), + context = + ExpectedContext(fragment = "json_tuple", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -648,7 +650,9 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.INVALID_JSON_MAP_KEY_TYPE", parameters = Map( "schema" -> "\"MAP, STRING>\"", - "sqlExpr" -> "\"entries\"")) + "sqlExpr" -> "\"entries\""), + context = + ExpectedContext(fragment = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-24709: infers schemas of json strings and pass them to from_json") { @@ -958,7 +962,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { .select(from_json($"json", $"schema", options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"schema\"") + parameters = Map("inputSchema" -> "\"schema\""), + context = ExpectedContext(fragment = "from_json", getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index 2a24f0cc3996..afbe9cdac636 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -600,6 +600,10 @@ class ParametersSuite extends QueryTest with SharedSparkSession { array(str_to_map(Column(Literal("a:1,b:2,c:3"))))))) }, errorClass = "INVALID_SQL_ARG", - parameters = Map("name" -> "m")) + parameters = Map("name" -> "m"), + context = ExpectedContext( + fragment = "map_from_arrays", + callSitePattern = getCurrentClassCallSitePattern) + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b5ae9c7f3520..8668d6131740 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.util.TimeZone +import java.util.regex.Pattern import scala.jdk.CollectionConverters._ @@ -229,6 +230,17 @@ abstract class QueryTest extends PlanTest { assert(query.queryExecution.executedPlan.missingInput.isEmpty, s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}") } + + protected def getCurrentClassCallSitePattern: String = { + val cs = Thread.currentThread().getStackTrace()(2) + s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)" + } + + protected def getNextLineCallSitePattern(lines: Int = 1): String = { + val cs = Thread.currentThread().getStackTrace()(2) + Pattern.quote( + s"${cs.getClassName}.${cs.getMethodName}(${cs.getFileName}:${cs.getLineNumber + lines})") + } } object QueryTest extends Assertions { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3612f4a7eda8..b7201c2d96d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1920,7 +1920,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark dfNoCols.select($"b.*") }, errorClass = "CANNOT_RESOLVE_STAR_EXPAND", - parameters = Map("targetString" -> "`b`", "columns" -> "")) + parameters = Map("targetString" -> "`b`", "columns" -> 
""), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 179f40742c28..38a6b9a50272 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -879,7 +879,10 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "funcName" -> s"`$funcName`", "paramName" -> "`format`", - "paramType" -> "\"STRING\"")) + "paramType" -> "\"STRING\""), + context = ExpectedContext( + fragment = funcName, + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { df2.select(func(col("input"), lit("invalid_format"))).collect() @@ -888,7 +891,10 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "parameter" -> "`format`", "functionName" -> s"`$funcName`", - "invalidFormat" -> "'invalid_format'")) + "invalidFormat" -> "'invalid_format'"), + context = ExpectedContext( + fragment = funcName, + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { sql(s"select $funcName('a', 'b', 'c')") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 5c12ba307806..30a5bf709066 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -696,7 +696,9 @@ class QueryCompilationErrorsSuite Seq("""{"a":1}""").toDF("a").select(from_json($"a", IntegerType)).collect() }, errorClass = "DATATYPE_MISMATCH.INVALID_JSON_SCHEMA", - parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\"")) + parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\""), + context = + ExpectedContext(fragment = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("WRONG_NUM_ARGS.WITHOUT_SUGGESTION: wrong args of CAST(parameter types contains DataType)") { @@ -767,7 +769,8 @@ class QueryCompilationErrorsSuite }, errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", - parameters = Map("field" -> "`firstname`", "count" -> "2") + parameters = Map("field" -> "`firstname`", "count" -> "2"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -780,7 +783,9 @@ class QueryCompilationErrorsSuite }, errorClass = "INVALID_EXTRACT_BASE_FIELD_TYPE", sqlState = "42000", - parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\"")) + parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\""), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern) + ) } test("INVALID_EXTRACT_FIELD_TYPE: extract not string literal field") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index ee28a90aed9a..eafa89e8e007 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ -19,6 +19,8 @@ 
package org.apache.spark.sql.errors import org.apache.spark._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime} +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.ByteType @@ -53,6 +55,26 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest sqlState = "22012", parameters = Map("config" -> ansiConf), context = ExpectedContext(fragment = "6/0", start = 7, stop = 9)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5) / lit(0)).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext(fragment = "div", callSitePattern = getCurrentClassCallSitePattern)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5).divide(lit(0))).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext( + fragment = "divide", + callSitePattern = getCurrentClassCallSitePattern)) } test("INTERVAL_DIVIDED_BY_ZERO: interval divided by zero") { @@ -92,6 +114,21 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('66666666666666.666' AS DECIMAL(8, 1))", start = 7, stop = 49)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit("66666666666666.666").cast("DECIMAL(8, 1)")).collect() + }, + errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", + sqlState = "22003", + parameters = Map( + "value" -> "66666666666666.666", + "precision" -> "8", + "scale" -> "1", + "config" -> ansiConf), + context = ExpectedContext( + fragment = "cast", + callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX: get element from array") { @@ -102,6 +139,16 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest errorClass = "INVALID_ARRAY_INDEX", parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext(fragment = "array(1, 2, 3, 4, 5)[8]", start = 7, stop = 29)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(lit(Array(1, 2, 3, 4, 5))(8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = ExpectedContext( + fragment = "apply", + callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX_IN_ELEMENT_AT: element_at from array") { @@ -115,6 +162,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "element_at(array(1, 2, 3, 4, 5), 8)", start = 7, stop = 41)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = + ExpectedContext(fragment = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_INDEX_OF_ZERO: element_at from array by index zero") { @@ -129,6 +185,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest start = 7, stop = 41) ) + + checkError( + exception = 
intercept[SparkRuntimeException]( + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 0)).collect() + ), + errorClass = "INVALID_INDEX_OF_ZERO", + parameters = Map.empty, + context = + ExpectedContext(fragment = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("CAST_INVALID_INPUT: cast string to double") { @@ -146,6 +211,20 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('111111111111xe23' AS DOUBLE)", start = 7, stop = 40)) + + checkError( + exception = intercept[SparkNumberFormatException] { + OneRowRelation().select(lit("111111111111xe23").cast("DOUBLE")).collect() + }, + errorClass = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> "'111111111111xe23'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"DOUBLE\"", + "ansiConfig" -> ansiConf), + context = ExpectedContext( + fragment = "cast", + callSitePattern = getCurrentClassCallSitePattern)) } test("CANNOT_PARSE_TIMESTAMP: parse string to timestamp") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 75e5d4d452e1..a7cab381c7f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -1206,7 +1206,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v5", fragment = "1/0", @@ -1225,7 +1225,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v1", fragment = "1/0", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index 9f2d20229955..0e4985bac994 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -244,7 +244,9 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { df.select("name", METADATA_FILE_NAME).collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), + context = + ExpectedContext(fragment = "select", callSitePattern = getCurrentClassCallSitePattern)) } metadataColumnsTest("SPARK-42683: df metadataColumn - schema conflict", @@ -522,14 +524,20 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { df.select("name", "_metadata.file_name").collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), + context = ExpectedContext( + fragment = "select", + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { df.select("name", 
"_METADATA.file_NAME").collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`"), + context = ExpectedContext( + fragment = "select", + callSitePattern = getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 111e88d57c78..a84aea278682 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -2697,7 +2697,9 @@ abstract class CSVSuite readback.filter($"AAA" === 2 && $"bbb" === 3).collect() }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`")) + parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), + context = + ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f0561a30727b..2f8b0a323dc8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3105,7 +3105,9 @@ abstract class JsonSuite readback.filter($"AAA" === 0 && $"bbb" === 1).collect() }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`")) + parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), + context = + ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // Schema inferring val readback2 = spark.read.json(path.getCanonicalPath) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala index 2465dee230de..d3e9819b9a05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala @@ -133,7 +133,8 @@ class ParquetFileMetadataStructRowIndexSuite extends QueryTest with SharedSparkS parameters = Map( "fieldName" -> "`row_index`", "fields" -> ("`file_path`, `file_name`, `file_size`, " + - "`file_block_start`, `file_block_length`, `file_modification_time`"))) + "`file_block_start`, `file_block_length`, `file_modification_time`")), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index baffc5088eb2..94535bc84a4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -1918,7 +1918,9 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { Seq("xyz").toDF().select("value", 
"default").write.insertInto("t") }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`default`", "proposal" -> "`value`")) + parameters = Map("objectName" -> "`default`", "proposal" -> "`value`"), + context = + ExpectedContext(fragment = "select", callSitePattern = getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 2174e91cb443..66d37e996a6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -713,7 +713,9 @@ class StreamSuite extends StreamTest { "columnName" -> "`rn_col`", "windowSpec" -> ("(PARTITION BY COL1 ORDER BY COL2 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING " + - "AND CURRENT ROW)"))) + "AND CURRENT ROW)")), + queryContext = Array( + ExpectedContext(fragment = "withColumn", callSitePattern = getCurrentClassCallSitePattern))) }