From 359c547d23d31f3633510e7854cc53bb6793fadd Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Sun, 7 Apr 2024 17:09:40 +0800
Subject: [PATCH] fix

---
 .../connect/planner/SparkConnectPlanner.scala | 23 +++------
 python/pyspark/sql/tests/test_group.py        |  5 ++
 .../spark/sql/RelationalGroupedDataset.scala  | 47 ++++++++++---------
 3 files changed, 36 insertions(+), 39 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 40dc7f88255e..0813b0a57671 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2452,25 +2452,13 @@ class SparkConnectPlanner(
     }
 
     val pivotExpr = transformExpression(rel.getPivot.getCol)
-
-    var valueExprs = rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
-    if (valueExprs.isEmpty) {
-      // This is to prevent unintended OOM errors when the number of distinct values is large
-      val maxValues = session.sessionState.conf.dataFramePivotMaxValues
-      // Get the distinct values of the column and sort them so its consistent
-      val pivotCol = Column(pivotExpr)
-      valueExprs = Dataset
-        .ofRows(session, input)
-        .select(pivotCol)
-        .distinct()
-        .limit(maxValues + 1)
-        .sort(pivotCol) // ensure that the output columns are in a consistent logical order
-        .collect()
-        .map(_.get(0))
-        .toImmutableArraySeq
+    val valueExprs = if (rel.getPivot.getValuesCount > 0) {
+      rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
+    } else {
+      RelationalGroupedDataset
+        .collectPivotValues(Dataset.ofRows(session, input), Column(pivotExpr))
         .map(expressions.Literal.apply)
     }
-
     logical.Pivot(
       groupByExprsOpt = Some(groupingExprs.map(toNamedExpression)),
       pivotColumn = pivotExpr,
@@ -2489,6 +2477,7 @@
               userGivenGroupByExprs = groupingExprs)),
           aggregateExpressions = aliasedAgg,
           child = input)
+
       case other => throw InvalidPlanInput(s"Unknown Group Type $other")
     }
   }
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index 1a9b7d9d836c..958fc4e65dac 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -18,6 +18,7 @@
 
 from pyspark.sql import Row
 from pyspark.sql import functions as sf
+from pyspark.errors import AnalysisException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -185,6 +186,10 @@ def test_order_by_ordinal(self):
         with self.assertRaises(IndexError):
             df.orderBy(-3)
 
+    def test_pivot_exceed_max_values(self):
+        with self.assertRaises(AnalysisException):
+            self.spark.range(100001).groupBy(sf.lit(1)).pivot("id").count().show()
+
 
 class GroupTests(GroupTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index f697eea4827e..0d66632a1c3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -423,28 +423,7 @@ class RelationalGroupedDataset protected[sql](
    * @since 2.4.0
    */
   def pivot(pivotColumn: Column): RelationalGroupedDataset = {
-    if (df.isStreaming) {
-      throw new AnalysisException(
-        errorClass = "_LEGACY_ERROR_TEMP_3063",
-        messageParameters = Map.empty)
-    }
-    // This is to prevent unintended OOM errors when the number of distinct values is large
-    val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
-    // Get the distinct values of the column and sort them so its consistent
-    val values = df.select(pivotColumn)
-      .distinct()
-      .limit(maxValues + 1)
-      .sort(pivotColumn) // ensure that the output columns are in a consistent logical order
-      .collect()
-      .map(_.get(0))
-      .toImmutableArraySeq
-
-    if (values.length > maxValues) {
-      throw QueryCompilationErrors.aggregationFunctionAppliedOnNonNumericColumnError(
-        pivotColumn.toString, maxValues)
-    }
-
-    pivot(pivotColumn, values)
+    pivot(pivotColumn, collectPivotValues(df, pivotColumn))
   }
 
   /**
@@ -798,6 +777,30 @@
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }
 
+  private[sql] def collectPivotValues(df: DataFrame, pivotColumn: Column): Seq[Any] = {
+    if (df.isStreaming) {
+      throw new AnalysisException(
+        errorClass = "_LEGACY_ERROR_TEMP_3063",
+        messageParameters = Map.empty)
+    }
+    // This is to prevent unintended OOM errors when the number of distinct values is large
+    val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
+    // Get the distinct values of the column and sort them so it's consistent
+    val values = df.select(pivotColumn)
+      .distinct()
+      .limit(maxValues + 1)
+      .sort(pivotColumn) // ensure that the output columns are in a consistent logical order
+      .collect()
+      .map(_.get(0))
+      .toImmutableArraySeq
+
+    if (values.length > maxValues) {
+      throw QueryCompilationErrors.aggregationFunctionAppliedOnNonNumericColumnError(
+        pivotColumn.toString, maxValues)
+    }
+    values
+  }
+
   /**
    * The Grouping Type
    */
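
Usage note (illustration only, not part of the patch): both entry points now funnel through RelationalGroupedDataset.collectPivotValues, so the streaming check and the spark.sql.pivotMaxValues guard behave the same whether pivot values are collected in classic Spark or on behalf of a Spark Connect plan. A minimal PySpark sketch of the user-visible behavior the new test covers, assuming a local SparkSession and the default spark.sql.pivotMaxValues of 10000:

    from pyspark.errors import AnalysisException
    from pyspark.sql import SparkSession, functions as sf

    spark = SparkSession.builder.master("local[2]").getOrCreate()

    # Implicit pivot values: Spark collects the distinct values of "id" itself.
    # range(100001) has more than the 10000 permitted distinct values, so the
    # query is rejected with an AnalysisException instead of risking an OOM.
    try:
        spark.range(100001).groupBy(sf.lit(1)).pivot("id").count().show()
    except AnalysisException as exc:
        print("pivot rejected:", exc)

    # Explicit pivot values: no distinct-value collection (and no limit check)
    # happens, so the same data pivots fine when values are supplied up front.
    spark.range(100001).groupBy(sf.lit(1)).pivot("id", [0, 1, 2]).count().show()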