Skip to content
Prev Previous commit
Next Next commit
rebase with upstream to revert stddev as alias of stddev_samp
  • Loading branch information
JihongMA committed Nov 4, 2015
commit 57eeeed67b5e5aa86023f043dbc1b0615b16f4d7
2 changes: 1 addition & 1 deletion R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,7 @@ test_that("describe() and summarize() on a DataFrame", {
stats <- describe(df, "age")
expect_equal(collect(stats)[1, "summary"], "count")
expect_equal(collect(stats)[2, "age"], "24.5")
expect_equal(collect(stats)[3, "age"], "5.5")
expect_equal(collect(stats)[3, "age"], "7.7781745930520225")
stats <- describe(df)
expect_equal(collect(stats)[4, "name"], "Andy")
expect_equal(collect(stats)[5, "age"], "30")
Expand Down
36 changes: 18 additions & 18 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,25 +661,25 @@ def describe(self, *cols):
guarantee about the backward compatibility of the schema of the resulting DataFrame.

>>> df.describe().show()
+-------+---+
|summary|age|
+-------+---+
| count| 2|
| mean|3.5|
| stddev|1.5|
| min| 2|
| max| 5|
+-------+---+
+-------+------------------+
|summary| age|
+-------+------------------+
| count| 2|
| mean| 3.5|
| stddev|2.1213203435596424|
| min| 2|
| max| 5|
+-------+------------------+
>>> df.describe(['age', 'name']).show()
+-------+---+-----+
|summary|age| name|
+-------+---+-----+
| count| 2| 2|
| mean|3.5| null|
| stddev|1.5| null|
| min| 2|Alice|
| max| 5| Bob|
+-------+---+-----+
+-------+------------------+-----+
|summary| age| name|
+-------+------------------+-----+
| count| 2| 2|
| mean| 3.5| null|
| stddev|2.1213203435596424| null|
| min| 2|Alice|
| max| 5| Bob|
+-------+------------------+-----+
"""
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1155,29 +1155,6 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
}
}

case class Stddev(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def prettyName: String = "stddev"

override protected val momentOrder = 2

override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")

if (n == 0.0) null else math.sqrt(moments(2) / n)
}
}


case class StddevPop(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
Expand Down Expand Up @@ -1222,28 +1199,6 @@ case class StddevSamp(child: Expression,
}
}

case class Variance(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def prettyName: String = "variance"

override protected val momentOrder = 2

override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Any = {
require(moments.length == momentOrder + 1,
s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")

if (n == 0.0) null else moments(2) / n
}
}

case class VarianceSamp(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
Expand Down
3 changes: 1 addition & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,7 @@ object functions {
def skewness(e: Column): Column = Skewness(e.expr)

/**
* Aggregate function: returns the population standard deviation of
* the expression in a group.
* Aggregate function: alias for [[stddev_samp]].
*
* @group agg_funcs
* @since 1.6.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}

test("stddev") {
val testData2ADev = math.sqrt(4 / 6.0)
val testData2ADev = math.sqrt(4.0 / 5.0)

checkAnswer(
testData2.agg(stddev('a)),
Row(testData2ADev))

checkAnswer(
testData2.agg(stddev_pop('a)),
Row(testData2ADev))
Row(math.sqrt(4.0 / 6.0)))

checkAnswer(
testData2.agg(stddev_samp('a)),
Row(math.sqrt(4 / 5.0)))
Row(testData2ADev))
}

test("zero stddev") {
Expand Down Expand Up @@ -246,7 +246,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {

checkAnswer(
emptyTableData.agg(variance('a)),
Row(Double.NaN))
Row(null))

checkAnswer(
emptyTableData.agg(var_samp('a)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val describeResult = Seq(
Row("count", "4", "4"),
Row("mean", "33.0", "178.0"),
Row("stddev", "16.583123951777", "10.0"),
Row("stddev", "19.148542155126762", "11.547005383792516"),
Row("min", "16", "164"),
Row("max", "60", "192"))

Expand Down
10 changes: 5 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(
sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
"AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, math.sqrt(2.0 / 3.0), 6, 3)
Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3)
)
}

Expand Down Expand Up @@ -722,7 +722,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("stddev") {
checkAnswer(
sql("SELECT STDDEV(a) FROM testData2"),
Row(math.sqrt(4.0 / 6.0))
Row(math.sqrt(4.0 / 5.0))
)
}

Expand All @@ -736,7 +736,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("stddev_samp") {
checkAnswer(
sql("SELECT STDDEV_SAMP(a) FROM testData2"),
Row(math.sqrt(4 / 5.0))
Row(math.sqrt(4.0 / 5.0))
)
}

Expand Down Expand Up @@ -777,8 +777,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {

test("stddev agg") {
checkAnswer(
sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
(1 to 3).map(i => Row(i, math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
(1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
}

test("variance agg") {
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.