diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 86c46e072c88..9e1bc9786a57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -811,9 +811,12 @@ object CollapseRepartition extends Rule[LogicalPlan] { */ object OptimizeWindowFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _), spec) - if spec.orderSpec.nonEmpty && - spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame].frameType == RowFrame => + case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _), + WindowSpecDefinition(_, orderSpec, frameSpecification: SpecifiedWindowFrame)) + if orderSpec.nonEmpty && frameSpecification.frameType == RowFrame && + frameSpecification.lower == UnboundedPreceding && + (frameSpecification.upper == UnboundedFollowing || + frameSpecification.upper == CurrentRow) => we.copy(windowFunction = NthValue(first.child, Literal(1), first.ignoreNulls)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala index 389aaeafe655..cf850bbe21ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala @@ -36,7 +36,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest { val b = testRelation.output(1) val c = testRelation.output(2) - test("replace first(col) by nth_value(col, 1)") { + test("replace first by nth_value if frame is UNBOUNDED PRECEDING AND CURRENT ROW") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), @@ -52,7 +52,34 @@ class OptimizeWindowFunctionsSuite extends PlanTest { assert(optimized == correctAnswer) } - test("can't replace first(col) by nth_value(col, 1) if the window frame type is range") { + test("replace first by nth_value if frame is UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING") { + val inputPlan = testRelation.select( + WindowExpression( + First(a, false).toAggregateExpression(), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)))) + val correctAnswer = testRelation.select( + WindowExpression( + NthValue(a, Literal(1), false), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)))) + + val optimized = Optimize.execute(inputPlan) + assert(optimized == correctAnswer) + } + + test("can't replace first by nth_value if frame is not suitable") { + val inputPlan = testRelation.select( + WindowExpression( + First(a, false).toAggregateExpression(), + WindowSpecDefinition(b :: Nil, c.asc :: Nil, + SpecifiedWindowFrame(RowFrame, Literal(1), CurrentRow)))) + + val optimized = Optimize.execute(inputPlan) + assert(optimized == inputPlan) + } + + test("can't replace first by nth_value if the window frame type is range") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(), @@ -63,7 +90,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest { assert(optimized == inputPlan) } - test("can't replace first(col) by nth_value(col, 1) if the window frame isn't ordered") { + test("can't replace first by nth_value if the window frame isn't ordered") { val inputPlan = testRelation.select( WindowExpression( First(a, false).toAggregateExpression(),