diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f615757a837a..3eb6bca6ec97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -73,6 +73,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) CombineLimits, CombineUnions, // Constant folding and strength reduction + TransposeWindow, NullPropagation, ConstantPropagation, FoldablePropagation, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 9a5d5a9966ab..9277dc685924 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import org.scalatest.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} +import org.apache.spark.sql.catalyst.optimizer.TransposeWindow +import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -668,18 +670,30 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { ("S2", "P2", 300) ).toDF("sno", "pno", "qty") - val w1 = Window.partitionBy("sno") - val w2 = Window.partitionBy("sno", "pno") - - checkAnswer( - df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) - .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")), - Seq( - Row("S1", "P1", 100, 800, 800), - Row("S1", "P1", 700, 800, 800), - Row("S2", "P1", 200, 200, 500), - Row("S2", "P2", 300, 300, 500))) - + Seq(true, false).foreach { transposeWindowEnabled => + val excludedRules = if (transposeWindowEnabled) "" else TransposeWindow.ruleName + withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) { + val w1 = Window.partitionBy("sno") + val w2 = Window.partitionBy("sno", "pno") + + val select = df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2")) + .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")) + + val expectedNumExchanges = if (transposeWindowEnabled) 1 else 2 + val actualNumExchanges = select.queryExecution.executedPlan.collect { + case e: Exchange => e + }.length + assert(actualNumExchanges == expectedNumExchanges) + + checkAnswer( + select, + Seq( + Row("S1", "P1", 100, 800, 800), + Row("S1", "P1", 700, 800, 800), + Row("S2", "P1", 200, 200, 500), + Row("S2", "P2", 300, 300, 500))) + } + } } test("NaN and -0.0 in window partition keys") {