sql/core/src/test/resources/sql-tests/inputs/window.sql (10 changes: 9 additions & 1 deletion)
@@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile,
row_number() OVER w AS row_number,
var_pop(val) OVER w AS var_pop,
var_samp(val) OVER w AS var_samp,
-approx_count_distinct(val) OVER w AS approx_count_distinct
+approx_count_distinct(val) OVER w AS approx_count_distinct,
+covar_pop(val, val_long) OVER w AS covar_pop,
+corr(val, val_long) OVER w AS corr,
+stddev_samp(val) OVER w AS stddev_samp,
+stddev_pop(val) OVER w AS stddev_pop,
+collect_list(val) OVER w AS collect_list,
+collect_set(val) OVER w AS collect_set,
+skewness(val_double) OVER w AS skewness,
+kurtosis(val_double) OVER w AS kurtosis
FROM testData
WINDOW w AS (PARTITION BY cate ORDER BY val)
ORDER BY cate, val;
sql/core/src/test/resources/sql-tests/results/window.sql.out (30 changes: 19 additions & 11 deletions)
@@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile,
row_number() OVER w AS row_number,
var_pop(val) OVER w AS var_pop,
var_samp(val) OVER w AS var_samp,
-approx_count_distinct(val) OVER w AS approx_count_distinct
+approx_count_distinct(val) OVER w AS approx_count_distinct,
+covar_pop(val, val_long) OVER w AS covar_pop,
+corr(val, val_long) OVER w AS corr,
+stddev_samp(val) OVER w AS stddev_samp,
+stddev_pop(val) OVER w AS stddev_pop,
+collect_list(val) OVER w AS collect_list,
+collect_set(val) OVER w AS collect_set,
+skewness(val_double) OVER w AS skewness,
+kurtosis(val_double) OVER w AS kurtosis
FROM testData
WINDOW w AS (PARTITION BY cate ORDER BY val)
ORDER BY cate, val
-- !query 17 schema
-struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint>
+struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint,covar_pop:double,corr:double,stddev_samp:double,stddev_pop:double,collect_list:array<int>,collect_set:array<int>,skewness:double,kurtosis:double>
-- !query 17 output
-NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0
-3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1
-NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0
-1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1
-1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1
-2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2
-1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1
-2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2
-3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3
+NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL
+3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN
+NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN
+1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5
+1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5
+2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235
+1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN
+2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013
+3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984


-- !query 18
sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql

import java.sql.{Date, Timestamp}

+import scala.collection.mutable

import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
@@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
assert(e.message.contains("requires window to be ordered"))
}

test("corr, covar_pop, stddev_pop functions in specific window") {
val df = Seq(
("a", "p1", 10.0, 20.0),
("b", "p1", 20.0, 10.0),
("c", "p2", 20.0, 20.0),
("d", "p2", 20.0, 20.0),
("e", "p3", 0.0, 0.0),
("f", "p3", 6.0, 12.0),
("g", "p3", 6.0, 12.0),
("h", "p3", 8.0, 16.0),
("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
checkAnswer(
df.select(
$"key",
corr("value1", "value2").over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
covar_pop("value1", "value2")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
var_pop("value1")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
stddev_pop("value1")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
var_pop("value2")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
stddev_pop("value2")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))),

// As stddev_pop(expr) = sqrt(var_pop(expr)),
// the "stddev_pop" column can be calculated from the "var_pop" column.
//
// As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)),
// the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns.
Seq(
Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0),
Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0),
Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0)))
}

test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") {
val df = Seq(
("a", "p1", 10.0, 20.0),
("b", "p1", 20.0, 10.0),
("c", "p2", 20.0, 20.0),
("d", "p2", 20.0, 20.0),
("e", "p3", 0.0, 0.0),
("f", "p3", 6.0, 12.0),
("g", "p3", 6.0, 12.0),
("h", "p3", 8.0, 16.0),
("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
checkAnswer(
df.select(
$"key",
covar_samp("value1", "value2").over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
var_samp("value1").over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
variance("value1").over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
stddev_samp("value1").over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
stddev("value1").over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
),
Seq(
Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ),
Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ),
Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ),
Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)))
}

test("collect_list in ascending ordered window") {
val df = Seq(
("a", "p1", "1"),
("b", "p1", "2"),
("c", "p1", "2"),
("d", "p1", null),
("e", "p1", "3"),
("f", "p2", "10"),
("g", "p2", "11"),
("h", "p3", "20"),
("i", "p4", null)).toDF("key", "partition", "value")
checkAnswer(
df.select(
$"key",
sort_array(
collect_list("value").over(Window.partitionBy($"partition").orderBy($"value")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
Seq(
Row("a", Array("1", "2", "2", "3")),
Row("b", Array("1", "2", "2", "3")),
Row("c", Array("1", "2", "2", "3")),
Row("d", Array("1", "2", "2", "3")),
Row("e", Array("1", "2", "2", "3")),
Row("f", Array("10", "11")),
Row("g", Array("10", "11")),
Row("h", Array("20")),
Row("i", Array())))
}

test("collect_list in descending ordered window") {
val df = Seq(
("a", "p1", "1"),
("b", "p1", "2"),
("c", "p1", "2"),
("d", "p1", null),
("e", "p1", "3"),
("f", "p2", "10"),
("g", "p2", "11"),
("h", "p3", "20"),
("i", "p4", null)).toDF("key", "partition", "value")
checkAnswer(
df.select(
$"key",
sort_array(
collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc)
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
Seq(
Row("a", Array("1", "2", "2", "3")),
Row("b", Array("1", "2", "2", "3")),
Row("c", Array("1", "2", "2", "3")),
Row("d", Array("1", "2", "2", "3")),
Row("e", Array("1", "2", "2", "3")),
Row("f", Array("10", "11")),
Row("g", Array("10", "11")),
Row("h", Array("20")),
Row("i", Array())))
}

test("collect_set in window") {
val df = Seq(
("a", "p1", "1"),
("b", "p1", "2"),
("c", "p1", "2"),
("d", "p1", "3"),
("e", "p1", "3"),
("f", "p2", "10"),
("g", "p2", "11"),
("h", "p3", "20")).toDF("key", "partition", "value")
checkAnswer(
df.select(
$"key",
sort_array(
collect_set("value").over(Window.partitionBy($"partition").orderBy($"value")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
Seq(
Row("a", Array("1", "2", "3")),
Row("b", Array("1", "2", "3")),
Row("c", Array("1", "2", "3")),
Row("d", Array("1", "2", "3")),
Row("e", Array("1", "2", "3")),
Row("f", Array("10", "11")),
Row("g", Array("10", "11")),
Row("h", Array("20"))))
}

test("skewness and kurtosis functions in window") {
val df = Seq(
("a", "p1", 1.0),
("b", "p1", 1.0),
("c", "p1", 2.0),
("d", "p1", 2.0),
("e", "p1", 3.0),
("f", "p1", 3.0),
("g", "p1", 3.0),
("h", "p2", 1.0),
("i", "p2", 2.0),
("j", "p2", 5.0)).toDF("key", "partition", "value")
checkAnswer(
df.select(
$"key",
skewness("value").over(Window.partitionBy("partition").orderBy($"key")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
kurtosis("value").over(Window.partitionBy("partition").orderBy($"key")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
// results are checked by scipy.stats.skew() and scipy.stats.kurtosis()
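// (e.g., assuming scipy's default population estimators, bias=True and fisher=True:
// scipy.stats.skew([1.0, 2.0, 5.0]) ~= 0.5280049792181881 and
// scipy.stats.kurtosis([1.0, 2.0, 5.0]) ~= -1.5)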
Seq(
Row("a", -0.27238010581457267, -1.506920415224914),
Row("b", -0.27238010581457267, -1.506920415224914),
Row("c", -0.27238010581457267, -1.506920415224914),
Row("d", -0.27238010581457267, -1.506920415224914),
Row("e", -0.27238010581457267, -1.506920415224914),
Row("f", -0.27238010581457267, -1.506920415224914),
Row("g", -0.27238010581457267, -1.506920415224914),
Row("h", 0.5280049792181881, -1.5000000000000013),
Row("i", 0.5280049792181881, -1.5000000000000013),
Row("j", 0.5280049792181881, -1.5000000000000013)))
}

test("aggregation function on invalid column") {
val df = Seq((1, "1")).toDF("key", "value")
val e = intercept[AnalysisException](
df.select($"key", count("invalid").over()))
assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]"))
}

test("numerical aggregate functions on string column") {
val df = Seq((1, "a", "b")).toDF("key", "value1", "value2")
checkAnswer(
df.select($"key",
var_pop("value1").over(),
variance("value1").over(),
stddev_pop("value1").over(),
stddev("value1").over(),
sum("value1").over(),
mean("value1").over(),
avg("value1").over(),
corr("value1", "value2").over(),
covar_pop("value1", "value2").over(),
covar_samp("value1", "value2").over(),
skewness("value1").over(),
kurtosis("value1").over()),
Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null)))
}

test("statistical functions") {
val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)).
toDF("key", "value")
@@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Row("b", 2, null, null, null, null, null, null)))
}

test("last/first on descending ordered window") {
val nullStr: String = null
val df = Seq(
("a", 0, nullStr),
("a", 1, "x"),
("a", 2, "y"),
("a", 3, "z"),
("a", 4, "v"),
("b", 1, "k"),
("b", 2, "l"),
("b", 3, nullStr)).
toDF("key", "order", "value")
val window = Window.partitionBy($"key").orderBy($"order".desc)
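// With an ORDER BY but no explicit frame, the window defaults to
// RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so under the descending
// order first() always sees the highest `order` row, while last() sees the
// current row's value (or, with ignoreNulls, the last non-null value up to it).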
checkAnswer(
df.select(
$"key",
$"order",
first($"value").over(window),
first($"value", ignoreNulls = false).over(window),
first($"value", ignoreNulls = true).over(window),
last($"value").over(window),
last($"value", ignoreNulls = false).over(window),
last($"value", ignoreNulls = true).over(window)),
Seq(
Row("a", 0, "v", "v", "v", null, null, "x"),
Row("a", 1, "v", "v", "v", "x", "x", "x"),
Row("a", 2, "v", "v", "v", "y", "y", "y"),
Row("a", 3, "v", "v", "v", "z", "z", "z"),
Row("a", 4, "v", "v", "v", "v", "v", "v"),
Row("b", 1, null, null, "l", "k", "k", "k"),
Row("b", 2, null, null, "l", "l", "l", "l"),
Row("b", 3, null, null, null, null, null, null)))
}

test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
val src = Seq((0, 3, 5)).toDF("a", "b", "c")
.withColumn("Data", struct("a", "b"))