@@ -1744,11 +1744,14 @@ class Analyzer(
   * it into the plan tree.
   */
  object ExtractWindowExpressions extends Rule[LogicalPlan] {
-    private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
-      projectList.exists(hasWindowFunction)
+    private def hasWindowFunction(exprs: Seq[Expression]): Boolean =
+      exprs.exists(hasWindowFunction)

-    private def hasWindowFunction(expr: NamedExpression): Boolean = {
+    private def hasWindowFunction(expr: Expression): Boolean = {
      expr.find {
+        case AggregateExpression(aggFunc, _, _, _) if hasWindowFunction(aggFunc.children) =>

Contributor Author:
I have some doubts that this is the best place for this check. The StackOverflowError happens in extract. We could also define a separate method and call it inside extract, but that method would share the same structure as hasWindowFunction.
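
For context, one of the queries exercised by the new tests below is enough to reproduce the problem; before this patch, analyzing it blew the stack inside extract instead of failing cleanly:

// Reproducer lifted from the new DataFrameAggregateSuite test below: a window
// function nested inside an aggregate function. Prior to this change, forcing
// analysis recursed until a StackOverflowError in extract.
testData2.select(min(avg('b).over(Window.partitionBy('a)))).queryExecution.analyzed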

@cloud-fan (Contributor), Jun 1, 2018:
It's weird to throw an exception inside a boolean function hasXXX.

Can we do this check in https://github.com/apache/spark/pull/21473/files#diff-57b3d87be744b7d79a9beacf8e5e5eb2R1810 instead? i.e. adding a new case:

case agg: AggregateExpression if hasWindowFunction(agg) => fail
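
One way to read that suggestion: keep hasWindowFunction a side-effect-free predicate and fail where the expressions are actually walked. A rough sketch only, with hypothetical names and placement:

// Hypothetical helper: detection stays a pure predicate, reusing the patch's
// hasWindowFunction(Seq[Expression]) overload.
private def windowInsideAggregate(expr: Expression): Boolean =
  expr.find {
    case AggregateExpression(aggFunc, _, _, _) => hasWindowFunction(aggFunc.children)
    case _ => false
  }.isDefined

// ...and the failure lives at the call site (e.g. near the start of extract);
// `expressions` is a placeholder for whatever sequence is being checked there.
if (expressions.exists(windowInsideAggregate)) {
  failAnalysis("It is not allowed to use a window function inside an aggregate function. " +
    "Please use the inner window function in a sub-query.")
}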

+          failAnalysis("It is not allowed to use a window function inside an aggregate function. " +
+            "Please use the inner window function in a sub-query.")
        case window: WindowExpression => true
        case _ => false
      }.isDefined

@@ -19,8 +19,8 @@ package org.apache.spark.sql

import scala.util.Random

-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.scalatest.Matchers.the

import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -687,4 +687,29 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}
}
}

test("SPARK-21896: Window functions inside aggregate functions") {
def checkWindowError(df: => DataFrame): Unit = {
val thrownException = the [AnalysisException] thrownBy {
df.queryExecution.analyzed
}
assert(thrownException.message.contains("not allowed to use a window function"))
}

checkWindowError(testData2.select(min(avg('b).over(Window.partitionBy('a)))))
checkWindowError(testData2.agg(sum('b), max(rank().over(Window.orderBy('a)))))
checkWindowError(testData2.groupBy('a).agg(sum('b), max(rank().over(Window.orderBy('b)))))
checkWindowError(testData2.groupBy('a).agg(max(sum(sum('b)).over(Window.orderBy('b)))))

checkWindowError(
sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a HAVING SUM(b) = 3"))
checkWindowError(
sql("SELECT MAX(RANK() OVER(ORDER BY a)) FROM testData2"))
checkWindowError(
sql("SELECT MAX(RANK() OVER(ORDER BY b)) FROM testData2 GROUP BY a"))
checkAnswer(
sql("SELECT a, MAX(b), RANK() OVER(ORDER BY a) FROM testData2 GROUP BY a HAVING SUM(b) = 3"),

Contributor:
I think the dataset version should be

df.groupBy('a).agg(max('b), sum('b).as("sumb"), rank().over(window)).where('sumb === 5)

(see the sketch after the diff below)

      Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil)
  }

}
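
Following up on the reviewer's comment about the last checkAnswer above, a fuller form of the suggested dataset-side assertion might look like the sketch below. The window spec, the column aliases, and the filter value of 3 are assumptions (every group in testData2 sums b to 3, matching the SQL query's HAVING SUM(b) = 3); this is an illustration, not the code that was merged.

// Hypothetical Dataset counterpart of the SQL checkAnswer above; assumes the
// testData2 fixture and the imports already used in this suite.
val w = Window.orderBy('a)
checkAnswer(
  testData2.groupBy('a)
    .agg(max('b).as("maxb"), sum('b).as("sumb"), rank().over(w).as("rnk"))
    .where('sumb === 3)  // plays the role of HAVING SUM(b) = 3
    .select('a, 'maxb, 'rnk),
  Row(1, 2, 1) :: Row(2, 2, 2) :: Row(3, 2, 3) :: Nil)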