Skip to content

Commit 4f8c381

Browse files
committed
Fix null case.
1 parent 3b731e2 commit 4f8c381

File tree

2 files changed

+37
-25
lines changed

2 files changed

+37
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala

Lines changed: 30 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -592,31 +592,36 @@ case class Corr(
592592
}
593593

594594
override def update(buffer: MutableRow, input: InternalRow): Unit = {
595-
val x = left.eval(input).asInstanceOf[Double]
596-
val y = right.eval(input).asInstanceOf[Double]
597-
598-
var xAvg = buffer.getDouble(mutableAggBufferOffset)
599-
var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
600-
var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
601-
var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
602-
var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
603-
var count = buffer.getLong(mutableAggBufferOffsetPlus5)
604-
605-
val deltaX = x - xAvg
606-
val deltaY = y - yAvg
607-
count += 1
608-
xAvg += deltaX / count
609-
yAvg += deltaY / count
610-
Ck += deltaX * (y - yAvg)
611-
MkX += deltaX * (x - xAvg)
612-
MkY += deltaY * (y - yAvg)
613-
614-
buffer.setDouble(mutableAggBufferOffset, xAvg)
615-
buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
616-
buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
617-
buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
618-
buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
619-
buffer.setLong(mutableAggBufferOffsetPlus5, count)
595+
val leftEval = left.eval(input)
596+
val rightEval = right.eval(input)
597+
598+
if (leftEval != null && rightEval != null) {
599+
val x = leftEval.asInstanceOf[Double]
600+
val y = rightEval.asInstanceOf[Double]
601+
602+
var xAvg = buffer.getDouble(mutableAggBufferOffset)
603+
var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
604+
var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
605+
var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
606+
var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
607+
var count = buffer.getLong(mutableAggBufferOffsetPlus5)
608+
609+
val deltaX = x - xAvg
610+
val deltaY = y - yAvg
611+
count += 1
612+
xAvg += deltaX / count
613+
yAvg += deltaY / count
614+
Ck += deltaX * (y - yAvg)
615+
MkX += deltaX * (x - xAvg)
616+
MkY += deltaY * (y - yAvg)
617+
618+
buffer.setDouble(mutableAggBufferOffset, xAvg)
619+
buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
620+
buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
621+
buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
622+
buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
623+
buffer.setLong(mutableAggBufferOffsetPlus5, count)
624+
}
620625
}
621626

622627
// Merge counters from other partitions. Formula can be found at:

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -589,6 +589,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
589589
val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
590590
assert(math.abs(corr6 + 1.0) < 1e-12)
591591

592+
val df5 = Seq[(Integer, Integer)](
593+
(1, null),
594+
(null, -60)).toDF("a", "b")
595+
596+
val corr7 = df5.groupBy().agg(corr("a", "b")).collect()(0)
597+
assert(corr7 == Row(null))
598+
592599
withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
593600
val errorMessage = intercept[SparkException] {
594601
val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")

0 commit comments

Comments
 (0)