|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.expressions.aggregate |
19 | 19 |
|
| 20 | +import org.apache.spark.sql.catalyst._ |
20 | 21 | import org.apache.spark.sql.catalyst.dsl.expressions._ |
21 | 22 | import org.apache.spark.sql.catalyst.expressions._ |
22 | 23 | import org.apache.spark.sql.types._ |
@@ -302,3 +303,102 @@ case class Sum(child: Expression) extends AlgebraicAggregate { |
302 | 303 |
|
303 | 304 | override val evaluateExpression = Cast(currentSum, resultType) |
304 | 305 | } |
| 306 | + |
| 307 | +case class Corr(left: Expression, right: Expression) extends AggregateFunction2 { |
| 308 | + |
| 309 | + def children: Seq[Expression] = Seq(left, right) |
| 310 | + |
| 311 | + def nullable: Boolean = false |
| 312 | + |
| 313 | + def dataType: DataType = DoubleType |
| 314 | + |
| 315 | + def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) |
| 316 | + |
| 317 | + def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) |
| 318 | + |
| 319 | + def cloneBufferAttributes: Seq[Attribute] = bufferAttributes.map(_.newInstance()) |
| 320 | + |
| 321 | + val bufferAttributes: Seq[AttributeReference] = Seq( |
| 322 | + AttributeReference("xAvg", DoubleType)(), |
| 323 | + AttributeReference("yAvg", DoubleType)(), |
| 324 | + AttributeReference("Ck", DoubleType)(), |
| 325 | + AttributeReference("MkX", DoubleType)(), |
| 326 | + AttributeReference("MkY", DoubleType)(), |
| 327 | + AttributeReference("count", LongType)()) |
| 328 | + |
| 329 | + override def initialize(buffer: MutableRow): Unit = { |
| 330 | + (0 until 5).map(idx => buffer.setDouble(mutableBufferOffset + idx, 0.0)) |
| 331 | + buffer.setLong(mutableBufferOffset + 5, 0L) |
| 332 | + } |
| 333 | + |
| 334 | + override def update(buffer: MutableRow, input: InternalRow): Unit = { |
| 335 | + val x = left.eval(input).asInstanceOf[Double] |
| 336 | + val y = right.eval(input).asInstanceOf[Double] |
| 337 | + |
| 338 | + var xAvg = buffer.getDouble(mutableBufferOffset) |
| 339 | + var yAvg = buffer.getDouble(mutableBufferOffset + 1) |
| 340 | + var Ck = buffer.getDouble(mutableBufferOffset + 2) |
| 341 | + var MkX = buffer.getDouble(mutableBufferOffset + 3) |
| 342 | + var MkY = buffer.getDouble(mutableBufferOffset + 4) |
| 343 | + var count = buffer.getLong(mutableBufferOffset + 5) |
| 344 | + |
| 345 | + val deltaX = x - xAvg |
| 346 | + val deltaY = y - yAvg |
| 347 | + count += 1 |
| 348 | + xAvg += deltaX / count |
| 349 | + yAvg += deltaY / count |
| 350 | + Ck += deltaX * (y - yAvg) |
| 351 | + MkX += deltaX * (x - xAvg) |
| 352 | + MkY += deltaY * (y - yAvg) |
| 353 | + |
| 354 | + buffer.setDouble(mutableBufferOffset, xAvg) |
| 355 | + buffer.setDouble(mutableBufferOffset + 1, yAvg) |
| 356 | + buffer.setDouble(mutableBufferOffset + 2, Ck) |
| 357 | + buffer.setDouble(mutableBufferOffset + 3, MkX) |
| 358 | + buffer.setDouble(mutableBufferOffset + 4, MkY) |
| 359 | + buffer.setLong(mutableBufferOffset + 5, count) |
| 360 | + } |
| 361 | + |
| 362 | + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { |
| 363 | + val count2 = buffer2.getLong(inputBufferOffset + 5) |
| 364 | + |
| 365 | + if (count2 > 0) { |
| 366 | + var xAvg = buffer1.getDouble(mutableBufferOffset) |
| 367 | + var yAvg = buffer1.getDouble(mutableBufferOffset + 1) |
| 368 | + var Ck = buffer1.getDouble(mutableBufferOffset + 2) |
| 369 | + var MkX = buffer1.getDouble(mutableBufferOffset + 3) |
| 370 | + var MkY = buffer1.getDouble(mutableBufferOffset + 4) |
| 371 | + var count = buffer1.getLong(mutableBufferOffset + 5) |
| 372 | + |
| 373 | + val xAvg2 = buffer2.getDouble(inputBufferOffset) |
| 374 | + val yAvg2 = buffer2.getDouble(inputBufferOffset + 1) |
| 375 | + val Ck2 = buffer2.getDouble(inputBufferOffset + 2) |
| 376 | + val MkX2 = buffer2.getDouble(inputBufferOffset + 3) |
| 377 | + val MkY2 = buffer2.getDouble(inputBufferOffset + 4) |
| 378 | + |
| 379 | + val totalCount = count + count2 |
| 380 | + val deltaX = xAvg - xAvg2 |
| 381 | + val deltaY = yAvg - yAvg2 |
| 382 | + Ck += Ck2 + deltaX * deltaY * count / totalCount * count2 |
| 383 | + xAvg = (xAvg * count + xAvg2 * count2) / totalCount |
| 384 | + yAvg = (yAvg * count + yAvg2 * count2) / totalCount |
| 385 | + MkX += MkX2 + deltaX * deltaX * count / totalCount * count2 |
| 386 | + MkY += MkY2 + deltaY * deltaY * count / totalCount * count2 |
| 387 | + count = totalCount |
| 388 | + |
| 389 | + buffer1.setDouble(mutableBufferOffset, xAvg) |
| 390 | + buffer1.setDouble(mutableBufferOffset + 1, yAvg) |
| 391 | + buffer1.setDouble(mutableBufferOffset + 2, Ck) |
| 392 | + buffer1.setDouble(mutableBufferOffset + 3, MkX) |
| 393 | + buffer1.setDouble(mutableBufferOffset + 4, MkY) |
| 394 | + buffer1.setLong(mutableBufferOffset + 5, count) |
| 395 | + } |
| 396 | + } |
| 397 | + |
| 398 | + override def eval(buffer: InternalRow): Any = { |
| 399 | + val Ck = buffer.getDouble(mutableBufferOffset + 2) |
| 400 | + val MkX = buffer.getDouble(mutableBufferOffset + 3) |
| 401 | + val MkY = buffer.getDouble(mutableBufferOffset + 4) |
| 402 | + Ck / math.sqrt(MkX * MkY) |
| 403 | + } |
| 404 | +} |
0 commit comments