Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
For comments.
  • Loading branch information
viirya committed Oct 29, 2015
commit 5fbcf9115e8e9677ea49e621804e18ae4a7a41df
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ object FunctionRegistry {

// aggregate functions
expression[Average]("avg"),
expression[Corr]("corr"),
expression[Count]("count"),
expression[First]("first"),
expression[First]("first_value"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ case class Corr(

def dataType: DataType = DoubleType

def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

Expand All @@ -562,27 +562,45 @@ case class Corr(
AttributeReference("MkY", DoubleType)(),
AttributeReference("count", LongType)())

// Local cache of mutableAggBufferOffset(s) that will be used in update and merge
private[this] val mutableAggBufferOffsetPlus1 = mutableAggBufferOffset + 1
private[this] val mutableAggBufferOffsetPlus2 = mutableAggBufferOffset + 2
private[this] val mutableAggBufferOffsetPlus3 = mutableAggBufferOffset + 3
private[this] val mutableAggBufferOffsetPlus4 = mutableAggBufferOffset + 4
private[this] val mutableAggBufferOffsetPlus5 = mutableAggBufferOffset + 5

// Local cache of inputAggBufferOffset(s) that will be used in update and merge
private[this] val inputAggBufferOffsetPlus1 = inputAggBufferOffset + 1
private[this] val inputAggBufferOffsetPlus2 = inputAggBufferOffset + 2
private[this] val inputAggBufferOffsetPlus3 = inputAggBufferOffset + 3
private[this] val inputAggBufferOffsetPlus4 = inputAggBufferOffset + 4
private[this] val inputAggBufferOffsetPlus5 = inputAggBufferOffset + 5

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def initialize(buffer: MutableRow): Unit = {
(0 until 5).map(idx => buffer.setDouble(mutableAggBufferOffset + idx, 0.0))
buffer.setLong(mutableAggBufferOffset + 5, 0L)
buffer.setDouble(mutableAggBufferOffset, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus1, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus2, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus3, 0.0)
buffer.setDouble(mutableAggBufferOffsetPlus4, 0.0)
buffer.setLong(mutableAggBufferOffsetPlus5, 0L)
}

override def update(buffer: MutableRow, input: InternalRow): Unit = {
val x = left.eval(input).asInstanceOf[Double]
val y = right.eval(input).asInstanceOf[Double]

var xAvg = buffer.getDouble(mutableAggBufferOffset)
var yAvg = buffer.getDouble(mutableAggBufferOffset + 1)
var Ck = buffer.getDouble(mutableAggBufferOffset + 2)
var MkX = buffer.getDouble(mutableAggBufferOffset + 3)
var MkY = buffer.getDouble(mutableAggBufferOffset + 4)
var count = buffer.getLong(mutableAggBufferOffset + 5)
var yAvg = buffer.getDouble(mutableAggBufferOffsetPlus1)
var Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
var MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
var MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
var count = buffer.getLong(mutableAggBufferOffsetPlus5)

val deltaX = x - xAvg
val deltaY = y - yAvg
Expand All @@ -594,31 +612,34 @@ case class Corr(
MkY += deltaY * (y - yAvg)

buffer.setDouble(mutableAggBufferOffset, xAvg)
buffer.setDouble(mutableAggBufferOffset + 1, yAvg)
buffer.setDouble(mutableAggBufferOffset + 2, Ck)
buffer.setDouble(mutableAggBufferOffset + 3, MkX)
buffer.setDouble(mutableAggBufferOffset + 4, MkY)
buffer.setLong(mutableAggBufferOffset + 5, count)
buffer.setDouble(mutableAggBufferOffsetPlus1, yAvg)
buffer.setDouble(mutableAggBufferOffsetPlus2, Ck)
buffer.setDouble(mutableAggBufferOffsetPlus3, MkX)
buffer.setDouble(mutableAggBufferOffsetPlus4, MkY)
buffer.setLong(mutableAggBufferOffsetPlus5, count)
}

// Merge counters from other partitions. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
val count2 = buffer2.getLong(inputAggBufferOffset + 5)
val count2 = buffer2.getLong(inputAggBufferOffsetPlus5)

// We only go to merge two buffers if there is at least one record aggregated in buffer2.
// We don't need to check count in buffer1 because if count2 is more than zero, totalCount
// is more than zero too, then we won't get a divide by zero exception.
if (count2 > 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to assume that the count2 in buffer1 is non zero? There is - currently - no documentation on this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only need to consider count in buffer2. I will add document for it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment. Now it is obvious, I wasn't thinking...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to add comment for it?

var xAvg = buffer1.getDouble(mutableAggBufferOffset)
var yAvg = buffer1.getDouble(mutableAggBufferOffset + 1)
var Ck = buffer1.getDouble(mutableAggBufferOffset + 2)
var MkX = buffer1.getDouble(mutableAggBufferOffset + 3)
var MkY = buffer1.getDouble(mutableAggBufferOffset + 4)
var count = buffer1.getLong(mutableAggBufferOffset + 5)
var yAvg = buffer1.getDouble(mutableAggBufferOffsetPlus1)
var Ck = buffer1.getDouble(mutableAggBufferOffsetPlus2)
var MkX = buffer1.getDouble(mutableAggBufferOffsetPlus3)
var MkY = buffer1.getDouble(mutableAggBufferOffsetPlus4)
var count = buffer1.getLong(mutableAggBufferOffsetPlus5)

val xAvg2 = buffer2.getDouble(inputAggBufferOffset)
val yAvg2 = buffer2.getDouble(inputAggBufferOffset + 1)
val Ck2 = buffer2.getDouble(inputAggBufferOffset + 2)
val MkX2 = buffer2.getDouble(inputAggBufferOffset + 3)
val MkY2 = buffer2.getDouble(inputAggBufferOffset + 4)
val yAvg2 = buffer2.getDouble(inputAggBufferOffsetPlus1)
val Ck2 = buffer2.getDouble(inputAggBufferOffsetPlus2)
val MkX2 = buffer2.getDouble(inputAggBufferOffsetPlus3)
val MkY2 = buffer2.getDouble(inputAggBufferOffsetPlus4)

val totalCount = count + count2
val deltaX = xAvg - xAvg2
Expand All @@ -631,20 +652,20 @@ case class Corr(
count = totalCount

buffer1.setDouble(mutableAggBufferOffset, xAvg)
buffer1.setDouble(mutableAggBufferOffset + 1, yAvg)
buffer1.setDouble(mutableAggBufferOffset + 2, Ck)
buffer1.setDouble(mutableAggBufferOffset + 3, MkX)
buffer1.setDouble(mutableAggBufferOffset + 4, MkY)
buffer1.setLong(mutableAggBufferOffset + 5, count)
buffer1.setDouble(mutableAggBufferOffsetPlus1, yAvg)
buffer1.setDouble(mutableAggBufferOffsetPlus2, Ck)
buffer1.setDouble(mutableAggBufferOffsetPlus3, MkX)
buffer1.setDouble(mutableAggBufferOffsetPlus4, MkY)
buffer1.setLong(mutableAggBufferOffsetPlus5, count)
}
}

override def eval(buffer: InternalRow): Any = {
val count = buffer.getLong(mutableAggBufferOffset + 5)
val count = buffer.getLong(mutableAggBufferOffsetPlus5)
if (count > 0) {
val Ck = buffer.getDouble(mutableAggBufferOffset + 2)
val MkX = buffer.getDouble(mutableAggBufferOffset + 3)
val MkY = buffer.getDouble(mutableAggBufferOffset + 4)
val Ck = buffer.getDouble(mutableAggBufferOffsetPlus2)
val MkX = buffer.getDouble(mutableAggBufferOffsetPlus3)
val MkY = buffer.getDouble(mutableAggBufferOffsetPlus4)
Ck / math.sqrt(MkX * MkY)
} else {
Double.NaN
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,30 @@ object Utils {
}
case other => None
}

def mustNewAggregation(aggregate: Aggregate): Unit = {
val onlyForAggregateExpression2 = aggregate.aggregateExpressions.flatMap { expr =>
expr.collect {
// If an aggregate expression only extends AggregateExpression
// without AggregateExpression1, it indicates it only supports AggregateExpression2
case agg: expressions.AggregateExpression
if !agg.isInstanceOf[expressions.AggregateExpression1] =>
agg
}
}
if (onlyForAggregateExpression2.nonEmpty) {
val invalidFunctions = {
if (onlyForAggregateExpression2.length > 1) {
s"${onlyForAggregateExpression2.tail.map(_.nodeName).mkString(",")} " +
s"and ${onlyForAggregateExpression2.head.nodeName} are"
} else {
s"${onlyForAggregateExpression2.head.nodeName} is"
}
}
val errorMessage =
s"${invalidFunctions} only implemented based on the new Aggregate Function " +
s"interface and it cannot be used when spark.sql.useAggregate2 = false."
throw new AnalysisException(errorMessage)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -752,12 +752,12 @@ case class LastFunction(
* Only support AggregateExpression2.
*
*/
case class Corr(
left: Expression,
right: Expression) extends BinaryExpression with AggregateExpression {
case class Corr(left: Expression, right: Expression)
extends BinaryExpression with AggregateExpression with ImplicitCastInputTypes {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a place holder, right? Can we change it to with AggregateExpression1 then we throw an exception (UnsupportedOperatorException) in the newInstance method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But how do we check spark.sql.useAggregate2=false at this expression? Catalyst expressions seems being independent from SQLConf. In newInstance method, we can't refer a conf object.

Sorry. I think I know what you meant.

override def nullable: Boolean = false
override def dataType: DoubleType.type = DoubleType
override def toString: String = s"CORRELATION($left, $right)"
override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will be the error message if we call this function when spark.sql.useAggregate2=false? It will be good to provide a meaning error message.


// Compute standard deviation based on online algorithm specified here:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
groupingExpressions,
partialComputation,
child) if !canBeConvertedToNewAggregation(plan) =>
Utils.mustNewAggregation(plan.asInstanceOf[logical.Aggregate])
execution.Aggregate(
partial = false,
namedGroupingAttributes,
Expand Down Expand Up @@ -294,25 +295,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}


object BroadcastNestedLoopJoin extends Strategy {
object BroadcastNestedLoop extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
joins.BuildRight
} else {
joins.BuildLeft
}
joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case logical.Join(
CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi =>
execution.joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil
case logical.Join(
left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi =>
execution.joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil
case _ => Nil
}
}

object CartesianProduct extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, _, None) =>
// TODO CartesianProduct doesn't support the Left Semi Join
case logical.Join(left, right, joinType, None) if joinType != LeftSemi =>
execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
case logical.Join(left, right, Inner, Some(condition)) =>
execution.Filter(condition,
Expand All @@ -321,6 +321,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

object DefaultJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
val buildSide =
if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
joins.BuildRight
} else {
joins.BuildLeft
}
joins.BroadcastNestedLoopJoin(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
case _ => Nil
}
}

protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)

object TakeOrderedAndProject extends Strategy {
Expand Down Expand Up @@ -379,6 +394,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
leftGroup, rightGroup, left, right) =>
execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
planLater(left), planLater(right)) :: Nil

case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
Expand Down Expand Up @@ -414,6 +433,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Nil
} else {
Utils.checkInvalidAggregateFunction2(a)
Utils.mustNewAggregation(a)
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,20 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(corr4.isNaN)

val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c")
val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr5 - 1.0) < 1e-12)
val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0)
assert(math.abs(corr6 + 1.0) < 1e-12)

withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
val errorMessage = intercept[AnalysisException] {
val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
}.getMessage
assert(errorMessage.contains("Corr is only implemented based on the new Aggregate Function"))
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will happen if the data type of input parameters are not double?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add ImplicitCastInputTypes to case class Corr. So the other NumericType can be automatically casting to double.


test("test Last implemented based on AggregateExpression1") {
Expand Down