Skip to content
2 changes: 1 addition & 1 deletion R/pkg/inst/tests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1271,7 +1271,7 @@ test_that("describe() and summarize() on a DataFrame", {
stats <- describe(df, "age")
expect_equal(collect(stats)[1, "summary"], "count")
expect_equal(collect(stats)[2, "age"], "24.5")
expect_equal(collect(stats)[3, "age"], "7.7781745930520225")
expect_equal(collect(stats)[3, "age"], "5.5")
stats <- describe(df)
expect_equal(collect(stats)[4, "name"], "Andy")
expect_equal(collect(stats)[5, "age"], "30")
Expand Down
36 changes: 18 additions & 18 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,25 +661,25 @@ def describe(self, *cols):
guarantee about the backward compatibility of the schema of the resulting DataFrame.

>>> df.describe().show()
+-------+------------------+
|summary| age|
+-------+------------------+
| count| 2|
| mean| 3.5|
| stddev|2.1213203435596424|
| min| 2|
| max| 5|
+-------+------------------+
+-------+---+
|summary|age|
+-------+---+
| count| 2|
| mean|3.5|
| stddev|1.5|
| min| 2|
| max| 5|
+-------+---+
>>> df.describe(['age', 'name']).show()
+-------+------------------+-----+
|summary| age| name|
+-------+------------------+-----+
| count| 2| 2|
| mean| 3.5| null|
| stddev|2.1213203435596424| null|
| min| 2|Alice|
| max| 5| Bob|
+-------+------------------+-----+
+-------+---+-----+
|summary|age| name|
+-------+---+-----+
| count| 2| 2|
| mean|3.5| null|
| stddev|1.5| null|
| min| 2|Alice|
| max| 5| Bob|
+-------+---+-----+
"""
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,149 +327,6 @@ case class Min(child: Expression) extends DeclarativeAggregate {
override val evaluateExpression = min
}

/**
 * Standard deviation aggregate whose `prettyName` is "stddev".
 *
 * Declares `isSample = true`, so the evaluation inherited from [[StddevAgg]]
 * divides by `currentCount - 1` (the sample form).
 */
case class Stddev(child: Expression) extends StddevAgg(child) {

  override def prettyName: String = "stddev"

  override def isSample: Boolean = true
}

/**
 * Standard deviation aggregate whose `prettyName` is "stddev_pop".
 *
 * Declares `isSample = false`, so the evaluation inherited from [[StddevAgg]]
 * divides by `currentCount` (the population form).
 */
case class StddevPop(child: Expression) extends StddevAgg(child) {

  override def prettyName: String = "stddev_pop"

  override def isSample: Boolean = false
}

/**
 * Standard deviation aggregate whose `prettyName` is "stddev_samp".
 *
 * Declares `isSample = true`, so the evaluation inherited from [[StddevAgg]]
 * divides by `currentCount - 1` (the sample form).
 */
case class StddevSamp(child: Expression) extends StddevAgg(child) {

  override def prettyName: String = "stddev_samp"

  override def isSample: Boolean = true
}

/**
 * Base class of the standard deviation aggregates, written as a DeclarativeAggregate.
 *
 * Computes the standard deviation with the online algorithm described at
 * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
 *
 * The aggregation buffer holds five Double attributes: `preCount`, `currentCount`,
 * `preAvg`, `currentAvg` and `currentMk`. Subclasses select the sample or the
 * population form of the final division through `isSample`.
 */
abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {

override def children: Seq[Expression] = child :: Nil

// The result may be null: evaluateExpression yields null when currentCount is 0.
override def nullable: Boolean = true

// true  => final variance is currentMk / (currentCount - 1)
// false => final variance is currentMk / currentCount
def isSample: Boolean

// Return data type (always the Double result type declared below).
override def dataType: DataType = resultType

// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the
// new version at planning time (after analysis phase). For now, NullType is added at here
// to make it resolved when we have cases like `select stddev(null)`.
// We can use our analyzer to cast NullType to the default data type of the NumericType once
// we remove the old aggregate functions. Then, we will not need NullType at here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))

// Every buffer slot and the final result share this type.
private val resultType = DoubleType

// Buffer attributes. The "pre" attributes receive the value that the matching
// "current" attribute held before the latest non-null input (see updateExpressions).
private val preCount = AttributeReference("preCount", resultType)()
private val currentCount = AttributeReference("currentCount", resultType)()
private val preAvg = AttributeReference("preAvg", resultType)()
private val currentAvg = AttributeReference("currentAvg", resultType)()
private val currentMk = AttributeReference("currentMk", resultType)()

// Order matters: initialValues, updateExpressions and mergeExpressions below list
// their entries in this same order.
override val aggBufferAttributes = preCount :: currentCount :: preAvg ::
currentAvg :: currentMk :: Nil

// Every buffer slot starts at 0 cast to Double.
override val initialValues = Seq(
/* preCount = */ Cast(Literal(0), resultType),
/* currentCount = */ Cast(Literal(0), resultType),
/* preAvg = */ Cast(Literal(0), resultType),
/* currentAvg = */ Cast(Literal(0), resultType),
/* currentMk = */ Cast(Literal(0), resultType)
)

// Per-row update. A null child leaves every buffer slot unchanged.
// NOTE(review): whether currentCount / currentAvg referenced inside avgAdd and mkAdd
// denote the pre-update or the post-update buffer values depends on how the framework
// evaluates updateExpressions, which is not visible here -- confirm.
override val updateExpressions = {

// update average
// avg = avg + (value - avg)/count
def avgAdd: Expression = {
currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
}

// update sum of square of difference from mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
def mkAdd: Expression = {
val delta1 = Cast(child, resultType) - preAvg
val delta2 = Cast(child, resultType) - currentAvg
currentMk + (delta1 * delta2)
}

Seq(
/* preCount = */ If(IsNull(child), preCount, currentCount),
/* currentCount = */ If(IsNull(child), currentCount,
Add(currentCount, Cast(Literal(1), resultType))),
/* preAvg = */ If(IsNull(child), preAvg, currentAvg),
/* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
/* currentMk = */ If(IsNull(child), currentMk, mkAdd)
)
}

// Combines two partial buffers, addressed through the `.left` / `.right` views of the
// buffer attributes.
// NOTE(review): preCount and preAvg are referenced below without `.left` / `.right`,
// unlike the other buffer attributes -- confirm that this is intended.
override val mergeExpressions = {

// count merge: sum of the two counts
def countMerge: Expression = {
currentCount.left + currentCount.right
}

// average merge: mean of the two averages weighted by preCount and currentCount.right
def avgMerge: Expression = {
((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) /
(preCount + currentCount.right)
}

// update sum of square differences:
// Mk.left + Mk.right + avgDelta^2 * preCount * count.right / (preCount + count.right)
def mkMerge: Expression = {
val avgDelta = currentAvg.right - preAvg
val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
(preCount + currentCount.right)

currentMk.left + currentMk.right + mkDelta
}

// For each merged slot a null on one side falls back to the other side
// (or to 0 for the "pre" slots).
Seq(
/* preCount = */ If(IsNull(currentCount.left),
Cast(Literal(0), resultType), currentCount.left),
/* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
If(IsNull(currentCount.right), currentCount.left, countMerge)),
/* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left),
/* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
/* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
If(IsNull(currentMk.right), currentMk.left, mkMerge))
)
}

// Final result computed from the buffer.
override val evaluateExpression = {
// when currentCount == 0, return null
// when currentCount == 1, return 0
// when currentCount > 1
// stddev_samp = sqrt (currentMk/(currentCount -1))
// stddev_pop = sqrt (currentMk/currentCount)
val varCol = {
if (isSample) {
currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType)
}
else {
currentMk / currentCount
}
}

// The two special cases are tested first, so Sqrt(varCol) is only the result when
// currentCount is neither 0 nor 1.
If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
Cast(Sqrt(varCol), resultType)))
}
}

case class Sum(child: Expression) extends DeclarativeAggregate {

override def children: Seq[Expression] = child :: Nil
Expand Down Expand Up @@ -1135,7 +992,76 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
moments(4) = buffer.getDouble(fourthMomentOffset)
}

getStatistic(n, mean, moments)
if (n == 0.0) null
else if (n == 1.0) 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe we want this behavior, since these edge cases should be handled in the getStatistic implementation. As established in the previous PR, Skewness and Kurtosis should yield Double.NaN when n == 1.0, whereas other functions such as VariancePop should yield 0.0.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

else getStatistic(n, mean, moments)
}
}

/**
 * Aggregate whose `prettyName` is "stddev", built on [[CentralMomentAgg]] with a
 * moment order of 2.
 *
 * The statistic is `sqrt(moments(2) / n)`: the divisor is `n`, not `n - 1`.
 *
 * @param child expression handed to the [[CentralMomentAgg]] constructor
 * @param mutableAggBufferOffset offset replaced by `withNewMutableAggBufferOffset`
 * @param inputAggBufferOffset offset replaced by `withNewInputAggBufferOffset`
 */
case class Stddev(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends CentralMomentAgg(child) {

  override def prettyName: String = "stddev"

  override protected val momentOrder = 2

  /** Returns a copy that differs only in `inputAggBufferOffset`. */
  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  /** Returns a copy that differs only in `mutableAggBufferOffset`. */
  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  /**
   * Computes `sqrt(moments(2) / n)`.
   *
   * @param n value used as the divisor; a value of zero short-circuits to `Double.NaN`
   * @param mean not used by this statistic
   * @param moments array that must hold exactly `momentOrder + 1` entries
   * @return `Double.NaN` when `n` is zero, otherwise the square root described above
   * @throws IllegalArgumentException if `moments` has an unexpected length
   */
  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
    require(moments.length == momentOrder + 1,
      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")

    val hasInput = n != 0.0
    if (hasInput) math.sqrt(moments(2) / n) else Double.NaN
  }
}


/**
 * Aggregate whose `prettyName` is "stddev_pop", built on [[CentralMomentAgg]] with a
 * moment order of 2.
 *
 * The statistic is `sqrt(moments(2) / n)`: the divisor is `n`, not `n - 1`.
 *
 * @param child expression handed to the [[CentralMomentAgg]] constructor
 * @param mutableAggBufferOffset offset replaced by `withNewMutableAggBufferOffset`
 * @param inputAggBufferOffset offset replaced by `withNewInputAggBufferOffset`
 */
case class StddevPop(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends CentralMomentAgg(child) {

  override def prettyName: String = "stddev_pop"

  override protected val momentOrder = 2

  /** Returns a copy that differs only in `inputAggBufferOffset`. */
  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  /** Returns a copy that differs only in `mutableAggBufferOffset`. */
  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  /**
   * Computes `sqrt(moments(2) / n)`.
   *
   * @param n value used as the divisor; a value of zero short-circuits to `Double.NaN`
   * @param mean not used by this statistic
   * @param moments array that must hold exactly `momentOrder + 1` entries
   * @return `Double.NaN` when `n` is zero, otherwise the square root described above
   * @throws IllegalArgumentException if `moments` has an unexpected length
   */
  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
    require(moments.length == momentOrder + 1,
      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")

    val hasInput = n != 0.0
    if (hasInput) math.sqrt(moments(2) / n) else Double.NaN
  }
}

/**
 * Aggregate whose `prettyName` is "stddev_samp", built on [[CentralMomentAgg]] with a
 * moment order of 2.
 *
 * The statistic is `sqrt(moments(2) / (n - 1.0))`: the divisor is `n - 1`.
 *
 * @param child expression handed to the [[CentralMomentAgg]] constructor
 * @param mutableAggBufferOffset offset replaced by `withNewMutableAggBufferOffset`
 * @param inputAggBufferOffset offset replaced by `withNewInputAggBufferOffset`
 */
case class StddevSamp(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends CentralMomentAgg(child) {

  override def prettyName: String = "stddev_samp"

  override protected val momentOrder = 2

  /** Returns a copy that differs only in `inputAggBufferOffset`. */
  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  /** Returns a copy that differs only in `mutableAggBufferOffset`. */
  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  /**
   * Computes `sqrt(moments(2) / (n - 1.0))`.
   *
   * @param n value whose predecessor is the divisor; zero and one both yield `Double.NaN`
   * @param mean not used by this statistic
   * @param moments array that must hold exactly `momentOrder + 1` entries
   * @return `Double.NaN` when `n` is zero or one, otherwise the square root described above
   * @throws IllegalArgumentException if `moments` has an unexpected length
   */
  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
    require(moments.length == momentOrder + 1,
      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")

    val divisorIsUsable = n != 0.0 && n != 1.0
    if (divisorIsUsable) math.sqrt(moments(2) / (n - 1.0)) else Double.NaN
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ class GroupedData protected[sql](
}

/**
* Compute the sample standard deviation for each numeric column for each group.
* Compute the population standard deviation for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
* When specified columns are given, only compute the stddev for them.
*
Expand Down Expand Up @@ -364,7 +364,7 @@ class GroupedData protected[sql](
}

/**
* Compute the sample variance for each numeric column for each group.
* Compute the population variance for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
* When specified columns are given, only compute the variance for them.
*
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ object functions {
def skewness(columnName: String): Column = skewness(Column(columnName))

/**
* Aggregate function: returns the unbiased sample standard deviation of
* Aggregate function: returns the population standard deviation of
* the expression in a group.
*
* @group agg_funcs
Expand All @@ -336,7 +336,7 @@ object functions {
def stddev(e: Column): Column = Stddev(e.expr)

/**
* Aggregate function: returns the unbiased sample standard deviation of
* Aggregate function: returns the population standard deviation of
* the expression in a group.
*
* @group agg_funcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}

test("stddev") {
val testData2ADev = math.sqrt(4/5.0)
val testData2ADev = math.sqrt(4 / 6.0)

checkAnswer(
testData2.agg(stddev('a)),
Row(testData2ADev))

checkAnswer(
testData2.agg(stddev_pop('a)),
Row(math.sqrt(4/6.0)))
Row(testData2ADev))

checkAnswer(
testData2.agg(stddev_samp('a)),
Row(testData2ADev))
Row(math.sqrt(4 / 5.0)))
}

test("zero stddev") {
Expand Down Expand Up @@ -255,19 +255,19 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {

checkAnswer(
emptyTableData.agg(var_samp('a)),
Row(Double.NaN))
Row(0.0))

checkAnswer(
emptyTableData.agg(var_pop('a)),
Row(0.0))

checkAnswer(
emptyTableData.agg(skewness('a)),
Row(Double.NaN))
Row(0.0))

checkAnswer(
emptyTableData.agg(kurtosis('a)),
Row(Double.NaN))
Row(0.0))
}

test("null moments") {
Expand All @@ -276,22 +276,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {

checkAnswer(
emptyTableData.agg(variance('a)),
Row(Double.NaN))
Row(null))

checkAnswer(
emptyTableData.agg(var_samp('a)),
Row(Double.NaN))
Row(null))

checkAnswer(
emptyTableData.agg(var_pop('a)),
Row(Double.NaN))
Row(null))

checkAnswer(
emptyTableData.agg(skewness('a)),
Row(Double.NaN))
Row(null))

checkAnswer(
emptyTableData.agg(kurtosis('a)),
Row(Double.NaN))
Row(null))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val describeResult = Seq(
Row("count", "4", "4"),
Row("mean", "33.0", "178.0"),
Row("stddev", "19.148542155126762", "11.547005383792516"),
Row("stddev", "16.583123951777", "10.0"),
Row("min", "16", "164"),
Row("max", "60", "192"))

Expand Down
Loading