Patch summary (text before the first "diff --git" header is commentary):
  * catalyst dsl: orderBy/sortBy now pass each SortOrder through replaceOrdinalsInSortOrder,
    which swaps a top-level Int literal child for UnresolvedOrdinal when conf.orderByOrdinal
    is set; groupBy builds its UnresolvedOrdinal under the literal's origin.
  * SparkConnectPlanner.replaceIntegerLiteralWithOrdinal and Dataset grouping: same
    origin-preserving construction via CurrentOrigin.withOrigin(literal.origin).
  * Dataset.sortInternal: replaces Int literals (bare or already wrapped in SortOrder) with
    UnresolvedOrdinal when orderByOrdinal is enabled.
  * Adds ReplaceIntegerLiteralsWithOrdinalsDataframeSuite and
    ReplaceIntegerLiteralsWithOrdinalsSqlSuite.
NOTE(review): line breaks and indentation were restored from a whitespace-collapsed copy.
Hunk line counts were re-checked against the @@ headers, but context-line indentation is
inferred; if a hunk is rejected, retry with `git apply --ignore-whitespace`.

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 00bde9f8c1f7..b4168820ffe5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -24,13 +24,13 @@ import scala.language.implicitConversions
 
 import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.types._
@@ -442,14 +442,35 @@ package object dsl extends SQLConfHelper {
           otherPlan)
       }
 
-      def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan)
+      def orderBy(sortExprs: SortOrder*): LogicalPlan = {
+        val sortExpressionsWithOrdinals = sortExprs.map(replaceOrdinalsInSortOrder)
+        Sort(sortExpressionsWithOrdinals, true, logicalPlan)
+      }
+
+      def sortBy(sortExprs: SortOrder*): LogicalPlan = {
+        val sortExpressionsWithOrdinals = sortExprs.map(replaceOrdinalsInSortOrder)
+        Sort(sortExpressionsWithOrdinals, false, logicalPlan)
+      }
 
-      def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan)
+      /**
+       * Replaces top-level integer literals from [[SortOrder]] with [[UnresolvedOrdinal]], if
+       * `orderByOrdinal` is enabled.
+       */
+      private def replaceOrdinalsInSortOrder(sortOrder: SortOrder): SortOrder = sortOrder match {
+        case sortOrderByOrdinal @ SortOrder(literal @ Literal(value: Int, IntegerType), _, _, _)
+            if conf.orderByOrdinal =>
+          val ordinal = CurrentOrigin.withOrigin(literal.origin) { UnresolvedOrdinal(value) }
+          sortOrderByOrdinal
+            .withNewChildren(newChildren = Seq(ordinal))
+            .asInstanceOf[SortOrder]
+        case other => other
+      }
 
       def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = {
         // Replace top-level integer literals with ordinals, if `groupByOrdinal` is enabled.
         val groupingExpressionsWithOrdinals = groupingExprs.map {
-          case Literal(value: Int, IntegerType) if conf.groupByOrdinal => UnresolvedOrdinal(value)
+          case literal @ Literal(value: Int, IntegerType) if conf.groupByOrdinal =>
+            CurrentOrigin.withOrigin(literal.origin) { UnresolvedOrdinal(value) }
           case other => other
         }
         val aliasedExprs = aggregateExprs.map {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 911d79ecdb12..e9057e389367 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -4108,7 +4108,8 @@ class SparkConnectPlanner(
    */
   private def replaceIntegerLiteralWithOrdinal(groupingExpression: Expression) =
     groupingExpression match {
-      case Literal(value: Int, IntegerType) => UnresolvedOrdinal(value)
+      case literal @ Literal(value: Int, IntegerType) =>
+        CurrentOrigin.withOrigin(literal.origin) { UnresolvedOrdinal(value) }
       case other => other
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 5c3ebb32b36a..8327d8181619 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNodeTag, TreePattern}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
 import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
@@ -932,8 +932,9 @@ class Dataset[T] private[sql](
     // Replace top-level integer literals in grouping expressions with ordinals, if
     // `groupByOrdinal` is enabled.
     val groupingExpressionsWithOrdinals = cols.map { col => col.expr match {
-      case Literal(value: Int, IntegerType) if sparkSession.sessionState.conf.groupByOrdinal =>
-        UnresolvedOrdinal(value)
+      case literal @ Literal(value: Int, IntegerType)
+        if sparkSession.sessionState.conf.groupByOrdinal =>
+          CurrentOrigin.withOrigin(literal.origin) { UnresolvedOrdinal(value) }
       case other => other
     }}
     RelationalGroupedDataset(
@@ -2246,8 +2247,20 @@ class Dataset[T] private[sql](
   protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
     val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
       col.expr match {
+        case sortOrderWithOrdinal @ SortOrder(literal @ Literal(value: Int, IntegerType), _, _, _)
+          if sparkSession.sessionState.conf.orderByOrdinal =>
+          // Replace top-level integer literals with UnresolvedOrdinal, if `orderByOrdinal` is
+          // enabled.
+          val ordinal = CurrentOrigin.withOrigin(literal.origin) { UnresolvedOrdinal(value) }
+          sortOrderWithOrdinal.withNewChildren(newChildren = Seq(ordinal)).asInstanceOf[SortOrder]
         case expr: SortOrder =>
           expr
+        case literal @ Literal(value: Int, IntegerType)
+          if sparkSession.sessionState.conf.orderByOrdinal =>
+          // Replace top-level integer literals with UnresolvedOrdinal, if `orderByOrdinal` is
+          // enabled.
+          val ordinal = CurrentOrigin.withOrigin(literal.origin) { UnresolvedOrdinal(value) }
+          SortOrder(ordinal, Ascending)
         case expr: Expression =>
           SortOrder(expr, Ascending)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceIntegerLiteralsWithOrdinalsDataframeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceIntegerLiteralsWithOrdinalsDataframeSuite.scala
new file mode 100644
index 000000000000..1cfb53052656
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceIntegerLiteralsWithOrdinalsDataframeSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedOrdinal
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class ReplaceIntegerLiteralsWithOrdinalsDataframeSuite extends QueryTest with SharedSparkSession {
+
+  test("Group by ordinal - Dataframe") {
+    val query = "SELECT * FROM VALUES(1,2),(1,3),(2,4)"
+
+    withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "true") {
+      val groupedDataset = sql(query).groupBy(lit(1))
+
+      assert(groupedDataset.groupingExprs.collect {
+        case ordinal @ UnresolvedOrdinal(1) => ordinal
+      }.nonEmpty)
+
+      checkError(
+        exception = intercept[AnalysisException](sql(query).groupBy(lit(-1)).count()),
+        condition = "GROUP_BY_POS_OUT_OF_RANGE",
+        parameters = Map("index" -> "-1", "size" -> "2"),
+        context = ExpectedContext(fragment = "lit", getCurrentClassCallSitePattern)
+      )
+    }
+
+    withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
+      val groupedDataset = sql(query).groupBy(lit(1))
+
+      assert(groupedDataset.groupingExprs.collect {
+        case ordinal @ UnresolvedOrdinal(1) => ordinal
+      }.isEmpty)
+    }
+  }
+
+  test("Order/Sort by ordinal - Dataframe") {
+    val sqlText = "SELECT * FROM VALUES(2,1),(1,2)"
+
+    withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "true") {
+      val queries = Seq(
+        sql(sqlText).orderBy(lit(1)),
+        sql(sqlText).sort(lit(1))
+      )
+
+      for (query <- queries) {
+        val unresolvedPlan = query.queryExecution.logical
+        val resolvedPlan = query.queryExecution.analyzed
+
+        assert(unresolvedPlan.expressions.collect {
+          case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
+        }.nonEmpty)
+
+        assert(resolvedPlan.expressions.collect {
+          case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
+        }.isEmpty)
+
+        checkAnswer(query, Row(1, 2) :: Row(2, 1) :: Nil)
+      }
+
+      checkError(
+        exception = intercept[AnalysisException](sql(sqlText).orderBy(lit(-1))),
+        condition = "ORDER_BY_POS_OUT_OF_RANGE",
+        parameters = Map("index" -> "-1", "size" -> "2")
+      )
+
+      checkError(
+        exception = intercept[AnalysisException](sql(sqlText).sort(lit(-1))),
+        condition = "ORDER_BY_POS_OUT_OF_RANGE",
+        parameters = Map("index" -> "-1", "size" -> "2")
+      )
+    }
+
+    withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") {
+      val queries = Seq(
+        sql(sqlText).orderBy(lit(1)),
+        sql(sqlText).sort(lit(1))
+      )
+
+      for (query <- queries) {
+        val unresolvedPlan = query.queryExecution.logical
+        val resolvedPlan = query.queryExecution.analyzed
+
+        assert(unresolvedPlan.expressions.collect {
+          case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
+        }.isEmpty)
+
+        assert(resolvedPlan.expressions.collect {
+          case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
+        }.isEmpty)
+
+        checkAnswer(query, Row(2, 1) :: Row(1, 2) :: Nil)
+      }
+
+      checkAnswer(sql(sqlText).orderBy(lit(-1)), Row(2, 1) :: Row(1, 2) :: Nil)
+
+      checkAnswer(sql(sqlText).sort(lit(-1)), Row(2, 1) :: Row(1, 2) :: Nil)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceIntegerLiteralsWithOrdinalsSqlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceIntegerLiteralsWithOrdinalsSqlSuite.scala
new file mode 100644
index 000000000000..ee8a07bd3bcc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceIntegerLiteralsWithOrdinalsSqlSuite.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedOrdinal
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class ReplaceIntegerLiteralsWithOrdinalsSqlSuite extends QueryTest with SharedSparkSession {
+
+  test("Group by ordinal - SQL") {
+    val correctSqlText = "SELECT col1, max(col2) FROM VALUES(1,2),(1,3),(2,4) GROUP BY 1"
+    val groupByPosOutOfRangeSqlText =
+      "SELECT col1, max(col2) FROM VALUES(1,2),(1,3),(2,4) GROUP BY -1"
+
+    withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "true") {
+      val query = sql(correctSqlText)
+      val parsedPlan = query.queryExecution.logical
+      val analyzedPlan = query.queryExecution.analyzed
+
+      assert(parsedPlan.expressions.collect {
+        case ordinal @ UnresolvedOrdinal(1) => ordinal
+      }.nonEmpty)
+
+      assert(analyzedPlan.expressions.collect {
+        case ordinal @ UnresolvedOrdinal(1) => ordinal
+      }.isEmpty)
+
+      checkAnswer(query, Row(1, 3) :: Row(2, 4) :: Nil)
+
+      checkError(
+        exception = intercept[AnalysisException](sql(groupByPosOutOfRangeSqlText)),
+        condition = "GROUP_BY_POS_OUT_OF_RANGE",
+        parameters = Map("index" -> "-1", "size" -> "2"),
+        queryContext = Array(ExpectedContext(fragment = "-1", start = 61, stop = 62))
+      )
+    }
+
+    withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
+      val parsedPlan = spark.sessionState.sqlParser.parsePlan(correctSqlText)
+
+      assert(parsedPlan.expressions.collect {
+        case ordinal @ UnresolvedOrdinal(1) => ordinal
+      }.isEmpty)
+
+      checkError(
+        exception = intercept[AnalysisException](sql(groupByPosOutOfRangeSqlText)),
+        condition = "MISSING_AGGREGATION",
+        parameters = Map("expression" -> "\"col1\"", "expressionAnyValue" -> "\"any_value(col1)\"")
+      )
+    }
+  }
+
+  test("Order by ordinal - SQL") {
+    val correctSqlText = "SELECT col1 FROM VALUES(2,1),(1,2) ORDER BY 1"
+    val orderByPosOutOfRangeSqlText = "SELECT col1 FROM VALUES(2,1),(1,2) ORDER BY -1"
+
+    withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "true") {
+      val query = sql(correctSqlText)
+      val parsedPlan = query.queryExecution.logical
+      val analyzedPlan = query.queryExecution.analyzed
+
+      assert(parsedPlan.expressions.collect {
+        case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
+      }.nonEmpty)
+
+      assert(analyzedPlan.expressions.collect {
+        case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
+      }.isEmpty)
+
+      checkAnswer(query, Row(1) :: Row(2) :: Nil)
+
+      checkError(
+        exception = intercept[AnalysisException](sql(orderByPosOutOfRangeSqlText)),
+        condition = "ORDER_BY_POS_OUT_OF_RANGE",
+        parameters = Map("index" -> "-1", "size" -> "1"),
+        queryContext = Array(ExpectedContext(fragment = "-1", start = 44, stop = 45))
+      )
+    }
+
+    withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") {
+      val query = sql(correctSqlText)
+      val parsedPlan = query.queryExecution.logical
+      val analyzedPlan = query.queryExecution.analyzed
+
+      assert(parsedPlan.expressions.collect {
+        case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
+      }.isEmpty)
+
+      assert(analyzedPlan.expressions.collect {
+        case ordinal @ SortOrder(UnresolvedOrdinal(1), _, _, _) => ordinal
+      }.isEmpty)
+
+      checkAnswer(query, Row(2) :: Row(1) :: Nil)
+
+      checkAnswer(sql(orderByPosOutOfRangeSqlText), Row(2) :: Row(1) :: Nil)
+    }
+  }
+}