Skip to content

Commit cb34a95

Browse files
committed
Add corr aggregate function.
1 parent 67580f1 commit cb34a95

File tree

5 files changed

+161
-0
lines changed

5 files changed

+161
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20+
import org.apache.spark.sql.catalyst._
2021
import org.apache.spark.sql.catalyst.dsl.expressions._
2122
import org.apache.spark.sql.catalyst.expressions._
2223
import org.apache.spark.sql.types._
@@ -302,3 +303,102 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
302303

303304
override val evaluateExpression = Cast(currentSum, resultType)
304305
}
306+
307+
/**
 * Computes the Pearson correlation coefficient between `left` and `right`
 * using a single-pass, mergeable update of running means and (co)moments
 * (Welford-style updates; Chan et al. pairwise combination in `merge`).
 *
 * Aggregation buffer layout (slots relative to `mutableBufferOffset`):
 *   +0 xAvg  : running mean of x (left)
 *   +1 yAvg  : running mean of y (right)
 *   +2 Ck    : running co-moment  sum_i (x_i - xAvg)(y_i - yAvg)
 *   +3 MkX   : running 2nd moment sum_i (x_i - xAvg)^2
 *   +4 MkY   : running 2nd moment sum_i (y_i - yAvg)^2
 *   +5 count : number of rows folded into this buffer
 *
 * The final result is Ck / sqrt(MkX * MkY); when no rows were seen, or either
 * column is constant, this is 0.0 / 0.0 = NaN.
 */
case class Corr(left: Expression, right: Expression) extends AggregateFunction2 {

  def children: Seq[Expression] = Seq(left, right)

  def nullable: Boolean = false

  def dataType: DataType = DoubleType

  // Fix: this function has TWO children, so it must declare two input types
  // (one per child). The original declared only a single DoubleType, which
  // leaves the second argument unchecked/uncoerced by input type checking.
  def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

  def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)

  def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance())

  val bufferAttributes: Seq[AttributeReference] = Seq(
    AttributeReference("xAvg", DoubleType)(),
    AttributeReference("yAvg", DoubleType)(),
    AttributeReference("Ck", DoubleType)(),
    AttributeReference("MkX", DoubleType)(),
    AttributeReference("MkY", DoubleType)(),
    AttributeReference("count", LongType)())

  override def initialize(buffer: MutableRow): Unit = {
    // foreach (not map): only the side effect of zeroing the slots is wanted.
    (0 until 5).foreach(idx => buffer.setDouble(mutableBufferOffset + idx, 0.0))
    buffer.setLong(mutableBufferOffset + 5, 0L)
  }

  override def update(buffer: MutableRow, input: InternalRow): Unit = {
    // NOTE(review): null inputs are not filtered here; `asInstanceOf[Double]`
    // on a null eval result unboxes to 0.0, silently treating nulls as zero —
    // confirm that nulls are filtered (or cannot occur) upstream.
    val x = left.eval(input).asInstanceOf[Double]
    val y = right.eval(input).asInstanceOf[Double]

    var xAvg = buffer.getDouble(mutableBufferOffset)
    var yAvg = buffer.getDouble(mutableBufferOffset + 1)
    var Ck = buffer.getDouble(mutableBufferOffset + 2)
    var MkX = buffer.getDouble(mutableBufferOffset + 3)
    var MkY = buffer.getDouble(mutableBufferOffset + 4)
    var count = buffer.getLong(mutableBufferOffset + 5)

    // Welford-style incremental update: the moment updates deliberately mix
    // the pre-update delta with the post-update average.
    val deltaX = x - xAvg
    val deltaY = y - yAvg
    count += 1
    xAvg += deltaX / count
    yAvg += deltaY / count
    Ck += deltaX * (y - yAvg)   // (y - yAvg) uses the *updated* yAvg
    MkX += deltaX * (x - xAvg)  // (x - xAvg) uses the *updated* xAvg
    MkY += deltaY * (y - yAvg)

    buffer.setDouble(mutableBufferOffset, xAvg)
    buffer.setDouble(mutableBufferOffset + 1, yAvg)
    buffer.setDouble(mutableBufferOffset + 2, Ck)
    buffer.setDouble(mutableBufferOffset + 3, MkX)
    buffer.setDouble(mutableBufferOffset + 4, MkY)
    buffer.setLong(mutableBufferOffset + 5, count)
  }

  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
    val count2 = buffer2.getLong(inputBufferOffset + 5)

    // Merging an empty partial state is a no-op (and avoids division by a
    // zero totalCount when both sides are empty).
    if (count2 > 0) {
      var xAvg = buffer1.getDouble(mutableBufferOffset)
      var yAvg = buffer1.getDouble(mutableBufferOffset + 1)
      var Ck = buffer1.getDouble(mutableBufferOffset + 2)
      var MkX = buffer1.getDouble(mutableBufferOffset + 3)
      var MkY = buffer1.getDouble(mutableBufferOffset + 4)
      var count = buffer1.getLong(mutableBufferOffset + 5)

      val xAvg2 = buffer2.getDouble(inputBufferOffset)
      val yAvg2 = buffer2.getDouble(inputBufferOffset + 1)
      val Ck2 = buffer2.getDouble(inputBufferOffset + 2)
      val MkX2 = buffer2.getDouble(inputBufferOffset + 3)
      val MkY2 = buffer2.getDouble(inputBufferOffset + 4)

      // Pairwise combination of the two partial states.
      // `deltaX * deltaY * count / totalCount * count2` evaluates left to
      // right starting from a Double, so no integer division occurs.
      val totalCount = count + count2
      val deltaX = xAvg - xAvg2
      val deltaY = yAvg - yAvg2
      Ck += Ck2 + deltaX * deltaY * count / totalCount * count2
      xAvg = (xAvg * count + xAvg2 * count2) / totalCount
      yAvg = (yAvg * count + yAvg2 * count2) / totalCount
      MkX += MkX2 + deltaX * deltaX * count / totalCount * count2
      MkY += MkY2 + deltaY * deltaY * count / totalCount * count2
      count = totalCount

      buffer1.setDouble(mutableBufferOffset, xAvg)
      buffer1.setDouble(mutableBufferOffset + 1, yAvg)
      buffer1.setDouble(mutableBufferOffset + 2, Ck)
      buffer1.setDouble(mutableBufferOffset + 3, MkX)
      buffer1.setDouble(mutableBufferOffset + 4, MkY)
      buffer1.setLong(mutableBufferOffset + 5, count)
    }
  }

  override def eval(buffer: InternalRow): Any = {
    val Ck = buffer.getDouble(mutableBufferOffset + 2)
    val MkX = buffer.getDouble(mutableBufferOffset + 3)
    val MkY = buffer.getDouble(mutableBufferOffset + 4)
    // NaN when count == 0 or either column has zero variance (0.0 / 0.0).
    Ck / math.sqrt(MkX * MkY)
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ object Utils {
9696
aggregateFunction = aggregate.Sum(child),
9797
mode = aggregate.Complete,
9898
isDistinct = true)
99+
100+
case expressions.Corr(left, right) =>
101+
aggregate.AggregateExpression2(
102+
aggregateFunction = aggregate.Corr(left, right),
103+
mode = aggregate.Complete,
104+
isDistinct = false)
99105
}
100106
// Check if there is any expressions.AggregateExpression1 left.
101107
// If so, we cannot convert this plan.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,3 +691,16 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
691691
result
692692
}
693693
}
694+
695+
/**
 * Pearson correlation coefficient of the two child expressions.
 *
 * This node is only a logical placeholder: it is supported exclusively via
 * the AggregateExpression2 code path, where it is rewritten into the
 * corresponding imperative aggregate implementation.
 */
case class Corr(left: Expression, right: Expression)
  extends BinaryExpression with AggregateExpression {

  override def nullable: Boolean = false
  override def dataType: DoubleType.type = DoubleType
  override def toString: String = s"CORRELATION($left, $right)"
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,24 @@ object functions {
172172
*/
173173
def avg(columnName: String): Column = avg(Column(columnName))
174174

175+
/**
 * Aggregate function: returns the Pearson Correlation Coefficient for two columns.
 *
 * @group agg_funcs
 * @since 1.6.0
 */
def corr(column1: Column, column2: Column): Column = {
  Corr(column1.expr, column2.expr)
}
183+
184+
/**
 * Aggregate function: returns the Pearson Correlation Coefficient for two columns.
 *
 * @group agg_funcs
 * @since 1.6.0
 */
def corr(columnName1: String, columnName2: String): Column = {
  corr(Column(columnName1), Column(columnName2))
}
192+
175193
/**
176194
* Aggregate function: returns the number of items in a group.
177195
*

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll
2121

2222
import org.apache.spark.sql._
2323
import org.apache.spark.sql.execution.aggregate
24+
import org.apache.spark.sql.functions._
2425
import org.apache.spark.sql.hive.test.TestHive
2526
import org.apache.spark.sql.test.SQLTestUtils
2627
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -480,6 +481,29 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
480481
Row(0, null, 1, 1, null, 0) :: Nil)
481482
}
482483

484+
test("pearson correlation") {
  val triples = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")

  // Perfectly positively correlated columns (b = 2a) => corr == 1.
  // repartition(2) also exercises the merge path.
  val positive = triples.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
  assert(math.abs(positive - 1.0) < 1e-12)

  // Perfectly negatively correlated columns (c = -a) => corr == -1.
  val negative = triples.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
  assert(math.abs(negative + 1.0) < 1e-12)

  // Non-trivial case, cross-checked externally.
  // Python (scipy):
  //   >>> from scipy.stats import pearsonr
  //   >>> import numpy as np
  //   >>> a = np.array(range(20))
  //   >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
  //   >>> pearsonr(a, b)
  //   (0.95723391394758572, 3.8902121417802199e-11)
  // R:
  //   > a <- 0:19
  //   > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
  //   > cor(a, b)
  //   [1] 0.957233913947585835
  val quadratic = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
  val observed = quadratic.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
  assert(math.abs(observed - 0.95723391394758572) < 1e-12)
}
506+
483507
test("test Last implemented based on AggregateExpression1") {
484508
// TODO: Remove this test once we remove AggregateExpression1.
485509
import org.apache.spark.sql.functions._

0 commit comments

Comments
 (0)