Make Corr extends AggregateExpression1.
viirya committed Oct 29, 2015
commit 3b731e2c9b08dbade38da73bcff94cf1b2cd7636
@@ -527,7 +527,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate {

/**
* Compute Pearson correlation between two expressions.
- * When applied on empty data (i.e., count is zero), it returns NaN.
+ * When applied on empty data (i.e., count is zero), it returns NULL.
*
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
@@ -668,7 +668,7 @@ case class Corr(
val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
Ck / math.sqrt(MkX * MkY)
} else {
-      Double.NaN
+      null
}
}
}
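For illustration, a minimal DataFrame-level sketch of the new null-on-empty behavior (the in-scope `sqlContext` and the `import sqlContext.implicits._` line are assumptions mirroring the test change further down, not part of this diff):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.corr
import sqlContext.implicits._  // assumes an in-scope SQLContext named sqlContext

// An empty two-column DataFrame: count is zero, so the aggregation buffer is never updated.
val empty = Seq.empty[(Double, Double)].toDF("a", "b")

// With this change the aggregate returns NULL rather than NaN, so the collected
// row is Row(null) and getDouble(0) on it is no longer meaningful.
val result = empty.groupBy().agg(corr("a", "b")).collect()(0)
assert(result == Row(null))
```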
@@ -194,30 +194,4 @@ object Utils {
}
case other => None
}

-  def mustNewAggregation(aggregate: Aggregate): Unit = {
-    val onlyForAggregateExpression2 = aggregate.aggregateExpressions.flatMap { expr =>
-      expr.collect {
-        // If an aggregate expression only extends AggregateExpression
-        // without AggregateExpression1, it indicates it only supports AggregateExpression2
-        case agg: expressions.AggregateExpression
-          if !agg.isInstanceOf[expressions.AggregateExpression1] =>
-          agg
-      }
-    }
-    if (onlyForAggregateExpression2.nonEmpty) {
-      val invalidFunctions = {
-        if (onlyForAggregateExpression2.length > 1) {
-          s"${onlyForAggregateExpression2.tail.map(_.nodeName).mkString(",")} " +
-            s"and ${onlyForAggregateExpression2.head.nodeName} are"
-        } else {
-          s"${onlyForAggregateExpression2.head.nodeName} is"
-        }
-      }
-      val errorMessage =
-        s"${invalidFunctions} only implemented based on the new Aggregate Function " +
-          s"interface and it cannot be used when spark.sql.useAggregate2 = false."
-      throw new AnalysisException(errorMessage)
-    }
-  }
}
@@ -753,11 +753,16 @@ case class LastFunction(
*
*/
case class Corr(left: Expression, right: Expression)
-  extends BinaryExpression with AggregateExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes {
override def nullable: Boolean = false
override def dataType: DoubleType.type = DoubleType
override def toString: String = s"CORRELATION($left, $right)"
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+  override def newInstance(): AggregateFunction1 = {
+    throw new UnsupportedOperationException(
+      "Corr only supports the new AggregateExpression2 and can only be used " +
+        "when spark.sql.useAggregate2 = true")
+  }
}
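For context, a hedged sketch of the error path this `newInstance()` creates when the old aggregation code path is selected (the `sqlContext` name and `setConf` call are assumptions, not part of this diff; the same behavior is exercised by the test update below):

```scala
import scala.util.Try

import org.apache.spark.SparkException
import org.apache.spark.sql.functions.corr
import sqlContext.implicits._  // assumes an in-scope SQLContext named sqlContext

// With spark.sql.useAggregate2 = false the planner falls back to the old
// execution.Aggregate operator, which instantiates AggregateFunction1s via
// newInstance(); the UnsupportedOperationException thrown above then surfaces
// to the driver wrapped in a SparkException.
sqlContext.setConf("spark.sql.useAggregate2", "false")
val df = Seq((1.0, 2.0), (2.0, 4.0), (3.0, 6.0)).toDF("a", "b")
val failure = Try(df.groupBy().agg(corr("a", "b")).collect()).failed.get
assert(failure.isInstanceOf[SparkException])
assert(failure.getMessage.contains("Corr only supports the new AggregateExpression2"))
```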
Contributor:
What will be the error message if we call this function when spark.sql.useAggregate2=false? It would be good to provide a meaningful error message.


// Compute standard deviation based on online algorithm specified here:
@@ -156,7 +156,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
groupingExpressions,
partialComputation,
child) if !canBeConvertedToNewAggregation(plan) =>
-        Utils.mustNewAggregation(plan.asInstanceOf[logical.Aggregate])
execution.Aggregate(
partial = false,
namedGroupingAttributes,
@@ -433,7 +432,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Nil
} else {
Utils.checkInvalidAggregateFunction2(a)
-          Utils.mustNewAggregation(a)
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
}
}
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution

import scala.collection.JavaConverters._

+import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.aggregate
@@ -579,8 +580,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)

val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
-    val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
-    assert(corr4.isNaN)
+    val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0)
+    assert(corr4 == Row(null))

val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c")
val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
@@ -589,11 +590,12 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
assert(math.abs(corr6 + 1.0) < 1e-12)

withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
-      val errorMessage = intercept[AnalysisException] {
+      val errorMessage = intercept[SparkException] {
val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
}.getMessage
-      assert(errorMessage.contains("Corr is only implemented based on the new Aggregate Function"))
+      assert(errorMessage.contains("java.lang.UnsupportedOperationException: " +
+        "Corr only supports the new AggregateExpression2"))
}
}
Contributor:

What will happen if the data types of the input parameters are not double?

Member Author:

I will add ImplicitCastInputTypes to case class Corr, so other NumericType inputs can be automatically cast to double.
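A brief sketch of that implicit casting, assuming an in-scope `sqlContext` (hypothetical usage, mirroring the integer-column case already covered by `df4` in the test above):

```scala
import org.apache.spark.sql.functions.corr
import sqlContext.implicits._  // assumes an in-scope SQLContext named sqlContext

// IntegerType inputs: ImplicitCastInputTypes lets the analyzer insert a cast to
// DoubleType for each child, so no explicit cast is needed in user code.
// The data is perfectly linear, so the correlation is 1.0.
val intDf = Seq.tabulate(10)(i => (i, 2 * i)).toDF("x", "y")
val r = intDf.groupBy().agg(corr("x", "y")).collect()(0).getDouble(0)
assert(math.abs(r - 1.0) < 1e-12)
```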

