Skip to content
Prev Previous commit
Next Next commit
Resolve conflicts
  • Loading branch information
ptkool committed Sep 6, 2018
commit 94e8115d64db4942012869df129d23eda3faa06b
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,11 @@ object CollapseRepartition extends Rule[LogicalPlan] {
/**
 * Collapses a [[Window]] operator sitting directly on top of another [[Window]] operator
 * into a single operator that evaluates the window expressions of both.
 *
 * The two operators are merged only when all of the following hold:
 *  - their second and third fields (`ps*` / `os*`, presumably the partition and order
 *    specs -- confirm against the `Window` definition) are equal,
 *  - the parent references none of the attributes in the child's `windowOutputSet`,
 *  - both operators carry at least one window expression, and
 *  - the first window expression of each operator has the same `WindowFunctionType`.
 */
object CollapseWindow extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
        if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty &&
          // `head` throws on an empty sequence, so leave the plan untouched when either
          // operator has no window expressions instead of failing the optimizer.
          we1.nonEmpty && we2.nonEmpty &&
          // This assumes Window contains the same type of window expressions. This is ensured
          // by ExtractWindowFunctions.
          WindowFunctionType.functionType(we1.head) == WindowFunctionType.functionType(we2.head) =>
      // The child's expressions go first in the merged operator.
      w1.copy(windowExpressions = we2 ++ we1, child = grandChild)
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
* Window function testing for DataFrame API.
*/
class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {

import testImplicits._

test("reuse window partitionBy") {
Expand Down Expand Up @@ -72,9 +73,9 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
cume_dist().over(Window.partitionBy("value").orderBy("key")),
percent_rank().over(Window.partitionBy("value").orderBy("key"))),
Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) ::
Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) ::
Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) ::
Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil)
Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) ::
Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) ::
Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil)
}

test("window function should fail if order by clause is not specified") {
Expand Down Expand Up @@ -162,12 +163,12 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(
Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ),
Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ),
Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("c", 0.0, 0.0, 0.0, 0.0, 0.0),
Row("d", 0.0, 0.0, 0.0, 0.0, 0.0),
Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)))
}

Expand Down Expand Up @@ -326,7 +327,7 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
var_samp($"value").over(window),
approx_count_distinct($"value").over(window)),
Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2))
++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3)))
++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3)))
}

test("window function with aggregates") {
Expand Down Expand Up @@ -622,6 +623,43 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
}
}

test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
  // Message fragment every rejected query below is expected to report.
  val expectedFragment = "window functions inside WHERE and HAVING clauses"

  // `query` is by-name so that it is first evaluated inside the `thrownBy` block; an
  // AnalysisException raised while the DataFrame itself is being built is captured too.
  def assertRejected(query: => DataFrame): Unit = {
    val failure = the[AnalysisException] thrownBy {
      query.queryExecution.analyzed
    }
    assert(failure.message.contains(expectedFragment))
  }

  // Window-function columns shared by the DataFrame API cases.
  val rankByA = rank().over(Window.orderBy($"a"))
  val rankByB = rank().over(Window.orderBy($"b"))

  // DataFrame API: window function used in a filter on a plain relation.
  assertRejected(testData2.select($"a").where(rankByB === 1))
  assertRejected(testData2.where($"b" === 2 && rankByB === 1))

  // DataFrame API: window function used in a filter on top of an aggregate.
  assertRejected(
    testData2.groupBy($"a")
      .agg(avg($"b").as("avgb"))
      .where($"a" > $"avgb" && rankByA === 1))
  assertRejected(
    testData2.groupBy($"a")
      .agg(max($"b").as("maxb"), sum($"b").as("sumb"))
      .where(rankByA === 1))
  assertRejected(
    testData2.groupBy($"a")
      .agg(max($"b").as("maxb"), sum($"b").as("sumb"))
      .where($"sumb" === 5 && rankByA === 1))

  // SQL: the same shapes expressed through WHERE and HAVING.
  assertRejected(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"))
  assertRejected(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"))
  assertRejected(
    sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"))
  assertRejected(
    sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"))
  assertRejected(
    sql(
      """SELECT a, MAX(b)
        |FROM testData2
        |GROUP BY a
        |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
}

test("window functions in multiple selects") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this test? It does not really add anything new.

Copy link
Contributor Author

@ptkool ptkool May 18, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. I'll remove it.

val df = Seq(
("S1", "P1", 100),
Expand All @@ -641,5 +679,6 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Row("S1", "P1", 700, 800, 800),
Row("S2", "P1", 200, 200, 500),
Row("S2", "P2", 300, 300, 500)))

}
}
You are viewing a condensed version of this merge commit. You can view the full changes here.