Add corr aggregate function.
viirya committed Sep 3, 2015
commit cb34a95e3dea152250b6409827fc869bd7fae407
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
@@ -302,3 +303,102 @@ case class Sum(child: Expression) extends AlgebraicAggregate {

override val evaluateExpression = Cast(currentSum, resultType)
}

case class Corr(left: Expression, right: Expression) extends AggregateFunction2 {

def children: Seq[Expression] = Seq(left, right)

def nullable: Boolean = false

def dataType: DataType = DoubleType

def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

Member Author: For Pearson correlation, I think the return data type should be fixed to DoubleType.

Contributor: +1. The return type should always be double.

Contributor: Actually, I was not questioning the return type. My question is about the input type.


def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)

def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance())

val bufferAttributes: Seq[AttributeReference] = Seq(
AttributeReference("xAvg", DoubleType)(),
AttributeReference("yAvg", DoubleType)(),
AttributeReference("Ck", DoubleType)(),
AttributeReference("MkX", DoubleType)(),
AttributeReference("MkY", DoubleType)(),
AttributeReference("count", LongType)())

override def initialize(buffer: MutableRow): Unit = {
(0 until 5).foreach(idx => buffer.setDouble(mutableBufferOffset + idx, 0.0))
buffer.setLong(mutableBufferOffset + 5, 0L)
}

override def update(buffer: MutableRow, input: InternalRow): Unit = {
val x = left.eval(input).asInstanceOf[Double]
val y = right.eval(input).asInstanceOf[Double]

var xAvg = buffer.getDouble(mutableBufferOffset)
var yAvg = buffer.getDouble(mutableBufferOffset + 1)
var Ck = buffer.getDouble(mutableBufferOffset + 2)
var MkX = buffer.getDouble(mutableBufferOffset + 3)
var MkY = buffer.getDouble(mutableBufferOffset + 4)
var count = buffer.getLong(mutableBufferOffset + 5)

val deltaX = x - xAvg
val deltaY = y - yAvg
count += 1
xAvg += deltaX / count
yAvg += deltaY / count
Ck += deltaX * (y - yAvg)
MkX += deltaX * (x - xAvg)
MkY += deltaY * (y - yAvg)

buffer.setDouble(mutableBufferOffset, xAvg)
buffer.setDouble(mutableBufferOffset + 1, yAvg)
buffer.setDouble(mutableBufferOffset + 2, Ck)
buffer.setDouble(mutableBufferOffset + 3, MkX)
buffer.setDouble(mutableBufferOffset + 4, MkY)
buffer.setLong(mutableBufferOffset + 5, count)
}
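
The update above follows the standard single-pass (Welford-style) recurrences. Note the asymmetry: deltaX is taken against the old mean while (y - yAvg) uses the already-updated mean, which is exactly what keeps the co-moment update exact:

C_k^{(n)} = C_k^{(n-1)} + (x_n - \bar{x}_{n-1})(y_n - \bar{y}_n), \qquad
M_k^{X,(n)} = M_k^{X,(n-1)} + (x_n - \bar{x}_{n-1})(x_n - \bar{x}_n)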

override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
val count2 = buffer2.getLong(inputBufferOffset + 5)

if (count2 > 0) {

Contributor: Is it safe to assume that the count in buffer1 is non-zero? There is currently no documentation on this.

Member Author: We only need to consider the count in buffer2. I will add documentation for it.

Contributor: Thanks for the comment. Now it is obvious; I wasn't thinking...

Member Author: So there is no need to add a comment for it, then?

var xAvg = buffer1.getDouble(mutableBufferOffset)
var yAvg = buffer1.getDouble(mutableBufferOffset + 1)
var Ck = buffer1.getDouble(mutableBufferOffset + 2)
var MkX = buffer1.getDouble(mutableBufferOffset + 3)
var MkY = buffer1.getDouble(mutableBufferOffset + 4)
var count = buffer1.getLong(mutableBufferOffset + 5)

val xAvg2 = buffer2.getDouble(inputBufferOffset)
val yAvg2 = buffer2.getDouble(inputBufferOffset + 1)
val Ck2 = buffer2.getDouble(inputBufferOffset + 2)
val MkX2 = buffer2.getDouble(inputBufferOffset + 3)
val MkY2 = buffer2.getDouble(inputBufferOffset + 4)

val totalCount = count + count2
val deltaX = xAvg - xAvg2
val deltaY = yAvg - yAvg2
Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
xAvg = (xAvg * count + xAvg2 * count2) / totalCount
yAvg = (yAvg * count + yAvg2 * count2) / totalCount
MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
count = totalCount

buffer1.setDouble(mutableBufferOffset, xAvg)
buffer1.setDouble(mutableBufferOffset + 1, yAvg)
buffer1.setDouble(mutableBufferOffset + 2, Ck)
buffer1.setDouble(mutableBufferOffset + 3, MkX)
buffer1.setDouble(mutableBufferOffset + 4, MkY)
buffer1.setLong(mutableBufferOffset + 5, count)
}
}
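
The merge step uses the standard pairwise-combination identities for centered moments. As a sanity check, here is a minimal self-contained sketch in plain Scala (no Spark dependencies; CorrBuffer and all names are illustrative, mirroring the buffer fields above) showing that split-update-merge agrees with a single pass:

object CorrMergeCheck extends App {
  // Mirrors the aggregation buffer: running means, centered co-moment (ck),
  // centered second moments (mkX, mkY), and the row count.
  case class CorrBuffer(
      xAvg: Double = 0.0, yAvg: Double = 0.0,
      ck: Double = 0.0, mkX: Double = 0.0, mkY: Double = 0.0,
      count: Long = 0L) {

    // Single-row update (same recurrences as Corr.update).
    def update(x: Double, y: Double): CorrBuffer = {
      val n = count + 1
      val deltaX = x - xAvg
      val deltaY = y - yAvg
      val nxAvg = xAvg + deltaX / n
      val nyAvg = yAvg + deltaY / n
      CorrBuffer(nxAvg, nyAvg,
        ck + deltaX * (y - nyAvg),
        mkX + deltaX * (x - nxAvg),
        mkY + deltaY * (y - nyAvg), n)
    }

    // Pairwise combination (same identities as Corr.merge).
    def merge(o: CorrBuffer): CorrBuffer = {
      if (o.count == 0) this else {
        val n = count + o.count
        val dx = xAvg - o.xAvg
        val dy = yAvg - o.yAvg
        CorrBuffer(
          (xAvg * count + o.xAvg * o.count) / n,
          (yAvg * count + o.yAvg * o.count) / n,
          ck + o.ck + dx * dy * count * o.count / n,
          mkX + o.mkX + dx * dx * count * o.count / n,
          mkY + o.mkY + dy * dy * count * o.count / n, n)
      }
    }

    def corr: Double = ck / math.sqrt(mkX * mkY)
  }

  val data = Seq.tabulate(20)(i => (1.0 * i, i * i - 2.0 * i + 3.5))
  val singlePass = data.foldLeft(CorrBuffer()) { case (b, (x, y)) => b.update(x, y) }
  val (lhs, rhs) = data.splitAt(7)
  val merged = lhs.foldLeft(CorrBuffer()) { case (b, (x, y)) => b.update(x, y) }
    .merge(rhs.foldLeft(CorrBuffer()) { case (b, (x, y)) => b.update(x, y) })
  assert(math.abs(singlePass.corr - merged.corr) < 1e-12)
}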

override def eval(buffer: InternalRow): Any = {
val Ck = buffer.getDouble(mutableBufferOffset + 2)
val MkX = buffer.getDouble(mutableBufferOffset + 3)
val MkY = buffer.getDouble(mutableBufferOffset + 4)
Ck / math.sqrt(MkX * MkY)

Contributor: What if count is zero? Shall we return NaN?

Member Author: Reasonable. I will let it return NaN.

Member Author: From the failed HiveCompatibilitySuite tests, it looks like Hive returns NULL in this case. I think we should follow that behavior.

}
}
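
If the Hive-compatible behavior discussed above is adopted, eval would need to guard on the row count; a sketch of that change (assuming null is the right representation for SQL NULL in this code path):

override def eval(buffer: InternalRow): Any = {
  val count = buffer.getLong(mutableBufferOffset + 5)
  if (count == 0L) {
    // Match Hive: aggregating zero rows yields NULL rather than NaN.
    null
  } else {
    val Ck = buffer.getDouble(mutableBufferOffset + 2)
    val MkX = buffer.getDouble(mutableBufferOffset + 3)
    val MkY = buffer.getDouble(mutableBufferOffset + 4)
    Ck / math.sqrt(MkX * MkY)
  }
}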
@@ -96,6 +96,12 @@ object Utils {
aggregateFunction = aggregate.Sum(child),
mode = aggregate.Complete,
isDistinct = true)

case expressions.Corr(left, right) =>
aggregate.AggregateExpression2(
aggregateFunction = aggregate.Corr(left, right),
mode = aggregate.Complete,
isDistinct = false)
}
// Check if there is any expressions.AggregateExpression1 left.
// If so, we cannot convert this plan.
@@ -691,3 +691,16 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
result
}
}

/**
* Computes the Pearson correlation coefficient for the given columns.
* Only supported by the new aggregation code path (AggregateExpression2).
*/
case class Corr(
left: Expression,
right: Expression) extends BinaryExpression with AggregateExpression {
override def nullable: Boolean = false
override def dataType: DoubleType.type = DoubleType
override def toString: String = s"CORRELATION($left, $right)"
}

Contributor: What will be the error message if we call this function when spark.sql.useAggregate2=false? It would be good to provide a meaningful error message.

18 changes: 18 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -172,6 +172,24 @@ object functions {
*/
def avg(columnName: String): Column = avg(Column(columnName))

/**
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
*
* @group agg_funcs
* @since 1.6.0
*/
def corr(column1: Column, column2: Column): Column =
Corr(column1.expr, column2.expr)

/**
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
*
* @group agg_funcs
* @since 1.6.0
*/
def corr(columnName1: String, columnName2: String): Column =
corr(Column(columnName1), Column(columnName2))
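
Once this lands, the function should be usable like any other aggregate from the DataFrame API; a small usage sketch (column names and data are illustrative, and the implicits import assumes a SQLContext in scope):

import sqlContext.implicits._
import org.apache.spark.sql.functions._

val df = Seq((1.0, 2.1), (2.0, 4.0), (3.0, 6.2)).toDF("x", "y")
df.agg(corr("x", "y")).show()
df.groupBy().agg(corr(col("x"), col("y"))).show()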

/**
* Aggregate function: returns the number of items in a group.
*
@@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql._
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -480,6 +481,29 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
Row(0, null, 1, 1, null, 0) :: Nil)
}

test("pearson correlation") {
val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr1 - 1.0) < 1e-12)
val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
assert(math.abs(corr2 + 1.0) < 1e-12)
// non-trivial example. To reproduce in python, use:
// >>> from scipy.stats import pearsonr
// >>> import numpy as np
// >>> a = np.array(range(20))
// >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
// >>> pearsonr(a, b)
// (0.95723391394758572, 3.8902121417802199e-11)
// In R, use:
// > a <- 0:19
// > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
// > cor(a, b)
// [1] 0.957233913947585835
val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
}

Contributor: What happens if the data types of the input parameters are not double?

Member Author: I will add ImplicitCastInputTypes to case class Corr, so other NumericType inputs can be automatically cast to double.
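
A sketch of that planned change (the final shape is up to the author): mixing in ImplicitCastInputTypes and declaring one expected type per child lets the analyzer insert casts for other numeric inputs automatically:

case class Corr(left: Expression, right: Expression)
  extends AggregateFunction2 with ImplicitCastInputTypes {

  override def children: Seq[Expression] = Seq(left, right)
  // One expected type per child; integer and decimal inputs get cast to double.
  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

  // ... buffer layout and initialize/update/merge/eval as in the diff above ...
}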


test("test Last implemented based on AggregateExpression1") {
// TODO: Remove this test once we remove AggregateExpression1.
import org.apache.spark.sql.functions._