Skip to content

Commit f6888f7

Browse files
committed
[SPARK-20636] Add the rule TransposeWindow to the optimization batch
## What changes were proposed in this pull request? This PR is a follow-up of the PR apache#17899. It is to add the rule TransposeWindow to the optimizer batch. ## How was this patch tested? The existing tests. Closes apache#23222 from gatorsmile/followupSPARK-20636. Authored-by: gatorsmile <[email protected]> Signed-off-by: gatorsmile <[email protected]>
1 parent 5960a82 commit f6888f7

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
7373
CombineLimits,
7474
CombineUnions,
7575
// Constant folding and strength reduction
76+
TransposeWindow,
7677
NullPropagation,
7778
ConstantPropagation,
7879
FoldablePropagation,

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ package org.apache.spark.sql
2020
import org.scalatest.Matchers.the
2121

2222
import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
23+
import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
24+
import org.apache.spark.sql.execution.exchange.Exchange
2325
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
2426
import org.apache.spark.sql.functions._
2527
import org.apache.spark.sql.internal.SQLConf
@@ -668,18 +670,30 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
668670
("S2", "P2", 300)
669671
).toDF("sno", "pno", "qty")
670672

671-
val w1 = Window.partitionBy("sno")
672-
val w2 = Window.partitionBy("sno", "pno")
673-
674-
checkAnswer(
675-
df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2"))
676-
.select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")),
677-
Seq(
678-
Row("S1", "P1", 100, 800, 800),
679-
Row("S1", "P1", 700, 800, 800),
680-
Row("S2", "P1", 200, 200, 500),
681-
Row("S2", "P2", 300, 300, 500)))
682-
673+
Seq(true, false).foreach { transposeWindowEnabled =>
674+
val excludedRules = if (transposeWindowEnabled) "" else TransposeWindow.ruleName
675+
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) {
676+
val w1 = Window.partitionBy("sno")
677+
val w2 = Window.partitionBy("sno", "pno")
678+
679+
val select = df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2"))
680+
.select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1"))
681+
682+
val expectedNumExchanges = if (transposeWindowEnabled) 1 else 2
683+
val actualNumExchanges = select.queryExecution.executedPlan.collect {
684+
case e: Exchange => e
685+
}.length
686+
assert(actualNumExchanges == expectedNumExchanges)
687+
688+
checkAnswer(
689+
select,
690+
Seq(
691+
Row("S1", "P1", 100, 800, 800),
692+
Row("S1", "P1", 700, 800, 800),
693+
Row("S2", "P1", 200, 200, 500),
694+
Row("S2", "P2", 300, 300, 500)))
695+
}
696+
}
683697
}
684698

685699
test("NaN and -0.0 in window partition keys") {

0 commit comments

Comments
 (0)