Add a superclass for *AggregateExec
maropu committed Aug 22, 2016
commit e37ef6afd47e9dd325a7f9e6d0826a3cb66c8e2e
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.execution.aggregate.{Aggregate => AggregateExec}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
 
@@ -27,20 +28,11 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
  */
 object AggUtils {
 
-  private[execution] def isAggregate(operator: SparkPlan): Boolean = {
-    operator.isInstanceOf[HashAggregateExec] || operator.isInstanceOf[SortAggregateExec]
-  }
-
-  private[execution] def supportPartialAggregate(operator: SparkPlan): Boolean = {
-    assert(isAggregate(operator))
-    def supportPartial(exprs: Seq[AggregateExpression]) =
-      exprs.map(_.aggregateFunction).forall(_.supportsPartial)
-    operator match {
-      case agg @ HashAggregateExec(_, _, aggregateExpressions, _, _, _, _) =>
-        supportPartial(aggregateExpressions)
-      case agg @ SortAggregateExec(_, _, aggregateExpressions, _, _, _, _) =>
-        supportPartial(aggregateExpressions)
-    }
+  private[execution] def supportPartialAggregate(operator: SparkPlan): Boolean = operator match {
+    case agg: AggregateExec =>
+      agg.aggregateExpressions.map(_.aggregateFunction).forall(_.supportsPartial)
+    case _ =>
+      false
   }
 
   private def createPartialAggregateExec(
@@ -86,23 +78,18 @@ object AggUtils {
 
   private[execution] def createPartialAggregate(operator: SparkPlan)
Contributor: A lot of duplication here. It would be nice if we had a parent for the *AggregateExec nodes.

Member Author: How about this change?

Contributor: Much better.

Contributor: Could we make this public instead of private[execution]? We just opened up a lot of similar APIs.

Contributor: Small: the name of the function is also quite misleading. It returns a map-side and merge aggregate pair, so createMapMergeAggregatePair? Please also add a little bit of documentation.

Member Author: Fixed.
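
To make the suggestion concrete, here is a minimal sketch of the proposed rename plus documentation. The name createMapMergeAggregatePair comes from the review; the doc wording is illustrative, and the body simply delegates to the existing method so the sketch stays self-contained:

  /**
   * Splits an aggregate operator into a pair of (merge-side aggregate, map-side aggregate).
   * The merge aggregate is not a complete plan on its own: callers such as EnsureRequirements
   * are expected to insert a shuffle between the two operators.
   */
  private[execution] def createMapMergeAggregatePair(operator: SparkPlan): (SparkPlan, SparkPlan) =
    createPartialAggregate(operator)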

     : (SparkPlan, SparkPlan) = operator match {
-    case agg @ HashAggregateExec(_, groupingExpressions, aggregateExpressions, _, _, _, child) =>
-      val mapSideAgg = createPartialAggregateExec(
-        groupingExpressions, aggregateExpressions, child)
-      val mergeAgg = agg.copy(
-        groupingExpressions = groupingExpressions.map(_.toAttribute),
-        aggregateExpressions = updateMergeAggregateMode(aggregateExpressions),
-        initialInputBufferOffset = groupingExpressions.length)
-
-      (mergeAgg, mapSideAgg)
-
-    case agg @ SortAggregateExec(_, groupingExpressions, aggregateExpressions, _, _, _, child) =>
+    case agg: Aggregate =>
       val mapSideAgg = createPartialAggregateExec(
-        groupingExpressions, aggregateExpressions, child)
-      val mergeAgg = agg.copy(
-        groupingExpressions = groupingExpressions.map(_.toAttribute),
-        aggregateExpressions = updateMergeAggregateMode(aggregateExpressions),
-        initialInputBufferOffset = groupingExpressions.length)
+        agg.groupingExpressions, agg.aggregateExpressions, agg.child)
+      val mergeAgg = createAggregateExec(
+        requiredChildDistributionExpressions = agg.requiredChildDistributionExpressions,
+        groupingExpressions = agg.groupingExpressions.map(_.toAttribute),
+        aggregateExpressions = updateMergeAggregateMode(agg.aggregateExpressions),
+        aggregateAttributes = agg.aggregateAttributes,
+        initialInputBufferOffset = agg.groupingExpressions.length,
+        resultExpressions = agg.resultExpressions,
+        child = agg.child
Contributor: mapSideAgg?

Member Author: In fact, the final plan is [MergeAgg] <- [Shuffle] <- [MapSideAgg]. So this function just returns the two aggregations separately, and the plan is built in EnsureRequirements. Is this a bad idea?

Contributor: It violates the principle of least surprise. The mergeAgg is not usable without the mapSideAgg. This is fine for usage in EnsureRequirements because it gets straightened out anyway, but it can be very surprising if someone uses it in a different way.

Member Author (@maropu, Aug 25, 2016): I fixed it. Is this fix okay?

+      )
 
       (mergeAgg, mapSideAgg)
   }
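
Given the discussion above, a hedged caller-side sketch of how the returned pair is meant to be assembled (it mirrors the EnsureRequirements change below; requiredDistribution stands in for the distribution the caller computed, and createPartitioning/defaultNumPreShufflePartitions are helpers visible in that rule):

  // Sketch only: the merge aggregate becomes executable once a shuffle over the
  // map-side aggregate is attached as its child: [MergeAgg] <- [Shuffle] <- [MapSideAgg].
  val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator)
  val shuffle = ShuffleExchange(
    createPartitioning(requiredDistribution, defaultNumPreShufflePartitions), mapSideAgg)
  val assembled = mergeAgg.withNewChildren(shuffle :: Nil)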
@@ -0,0 +1,58 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.SparkPlan

/**
 * A base trait for aggregate implementations.
 */
trait Aggregate {
Contributor: Why a trait and not a superclass?

Member Author: I just used a trait, in line with the HashJoin trait. Is a superclass better?

Contributor: Well, I think a superclass makes a bit more sense. A trait to me is a way to bolt on functionality. The Aggregate contains core functionality for both the hash- and sort-based versions, and is the natural parent class of both.

I do have to admit that this is more a personal preference.

  self: SparkPlan =>

  val requiredChildDistributionExpressions: Option[Seq[Expression]]
  val groupingExpressions: Seq[NamedExpression]
  val aggregateExpressions: Seq[AggregateExpression]
  val aggregateAttributes: Seq[Attribute]
  val initialInputBufferOffset: Int
  val resultExpressions: Seq[NamedExpression]
  val child: SparkPlan

  protected[this] val aggregateBufferAttributes = {
    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
  }

  override def producedAttributes: AttributeSet =
    AttributeSet(aggregateAttributes) ++
    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
    AttributeSet(aggregateBufferAttributes)

  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

  override def requiredChildDistribution: List[Distribution] = {
    requiredChildDistributionExpressions match {
      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
      case None => UnspecifiedDistribution :: Nil
    }
  }
}
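
For comparison with the trait-versus-superclass discussion above, a minimal sketch of the shape the reviewer prefers (hypothetical, not part of this commit; the name AggregateExecBase is invented here, and child would come from UnaryExecNode):

  // The same contract as the Aggregate trait, expressed as an abstract parent
  // node instead of a mixin. Case-class fields of subclasses implement the defs.
  abstract class AggregateExecBase extends UnaryExecNode {
    def requiredChildDistributionExpressions: Option[Seq[Expression]]
    def groupingExpressions: Seq[NamedExpression]
    def aggregateExpressions: Seq[AggregateExpression]
    def aggregateAttributes: Seq[Attribute]
    def initialInputBufferOffset: Int
    def resultExpressions: Seq[NamedExpression]

    protected lazy val aggregateBufferAttributes: Seq[Attribute] =
      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)

    override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
  }

  // Subclasses would then extend it directly, e.g.:
  //   case class HashAggregateExec(...) extends AggregateExecBase with CodegenSupport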
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -42,11 +41,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
+  extends UnaryExecNode with Aggregate with CodegenSupport {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -60,21 +55,6 @@ case class HashAggregateExec(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
 
   // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
   // map and/or the sort-based aggregation once it has processed a given number of input rows.
   private val testFallbackStartsAt: Option[(Int, Int)] = {
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.Utils
@@ -38,30 +37,11 @@ case class SortAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode {
-
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
+  extends UnaryExecNode with Aggregate {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
     groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
   }
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.aggregate.AggUtils
+import org.apache.spark.sql.execution.aggregate.{Aggregate, AggUtils}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -155,36 +155,28 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     assert(requiredChildDistributions.length == operator.children.length)
     assert(requiredChildOrderings.length == operator.children.length)
 
-    // Ensure that the operator's children satisfy their output distribution requirements:
-    val childrenWithDist = operator.children.zip(requiredChildDistributions)
-
     def createShuffleExchange(dist: Distribution, child: SparkPlan) =
       ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)
 
-    var (parent, children) = if (!AggUtils.isAggregate(operator)) {
-      val newChildren = childrenWithDist.map {
-        case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
-          child
-        case (child, BroadcastDistribution(mode)) =>
-          BroadcastExchangeExec(mode, child)
-        case (child, distribution) =>
-          createShuffleExchange(distribution, child)
-      }
-      (operator, newChildren)
-    } else {
-      val (child, distribution) = childrenWithDist.head
-      if (!child.outputPartitioning.satisfies(distribution)) {
-        if (AggUtils.supportPartialAggregate(operator)) {
-          // If an aggregation needs a shuffle and supports partial aggregation, a map-side
-          // partial aggregation and a shuffle are added as children.
-          val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator)
-          (mergeAgg, createShuffleExchange(distribution, mapSideAgg) :: Nil)
-        } else {
-          (operator, createShuffleExchange(distribution, child) :: Nil)
+    var (parent, children) = operator match {
+      case agg if AggUtils.supportPartialAggregate(agg) &&
+          !operator.outputPartitioning.satisfies(requiredChildDistributions.head) =>
+        // If an aggregation needs a shuffle and supports partial aggregation, a map-side
+        // partial aggregation and a shuffle are added as children.
+        val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator)
+        (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
+      case _ =>
+        // Ensure that the operator's children satisfy their output distribution requirements:
+        val childrenWithDist = operator.children.zip(requiredChildDistributions)
+        val newChildren = childrenWithDist.map {
+          case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+            child
+          case (child, BroadcastDistribution(mode)) =>
+            BroadcastExchangeExec(mode, child)
+          case (child, distribution) =>
+            createShuffleExchange(distribution, child)
         }
-      } else {
-        (operator, child :: Nil)
-      }
+        (operator, newChildren)
     }
 
     // If the operator has multiple children and specifies child output distributions (e.g. join),
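
As a concrete illustration (not from the PR itself), for a partial-aggregatable query such as SELECT k, count(*) FROM t GROUP BY k whose child is not already partitioned on k, the rewritten rule now yields a plan of roughly this shape:

  HashAggregateExec(Final)               <- mergeAgg, returned as `parent`
  +- ShuffleExchange(hashpartitioning(k))
     +- HashAggregateExec(Partial)       <- mapSideAgg
        +- Scan t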