Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor logics in EnsureRequirements
  • Loading branch information
maropu committed Aug 19, 2016
commit 5de4871d5670b156c1cfab54779257624eb4243c
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ object AggUtils {
}

private[execution] def addMapSideAggregate(operator: SparkPlan)
: (SparkPlan, Seq[SparkPlan]) = operator match {
: (SparkPlan, SparkPlan) = operator match {
case agg @ HashAggregateExec(
requiredChildDistributionExpressions,
groupingExpressions,
Expand All @@ -69,14 +69,14 @@ object AggUtils {
initialInputBufferOffset,
resultExpressions,
child) =>
val newChild = createPartialAggregate(groupingExpressions, aggregateExpressions, child)
val parent = agg.copy(
val mapSideAgg = createPartialAggregate(groupingExpressions, aggregateExpressions, child)
val mergeAgg = agg.copy(
groupingExpressions = groupingExpressions.map(_.toAttribute),
aggregateExpressions =
aggregateExpressions.map(e => e.copy(mode = updateAggregateMode(e.mode))),
initialInputBufferOffset = groupingExpressions.length
)
(parent, newChild :: Nil)
initialInputBufferOffset = groupingExpressions.length)

(mergeAgg, mapSideAgg)

case agg @ SortAggregateExec(
requiredChildDistributionExpressions,
Expand All @@ -86,14 +86,14 @@ object AggUtils {
initialInputBufferOffset,
resultExpressions,
child) =>
val newChild = createPartialAggregate(groupingExpressions, aggregateExpressions, child)
val parent = agg.copy(
val mapSideAgg = createPartialAggregate(groupingExpressions, aggregateExpressions, child)
val mergeAgg = agg.copy(
groupingExpressions = groupingExpressions.map(_.toAttribute),
aggregateExpressions =
aggregateExpressions.map(e => e.copy(mode = updateAggregateMode(e.mode))),
initialInputBufferOffset = groupingExpressions.length
)
(parent, newChild :: Nil)
initialInputBufferOffset = groupingExpressions.length)

(mergeAgg, mapSideAgg)
}

private def createAggregate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,25 +159,28 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
// Ensure that the operator's children satisfy their output distribution requirements:
val childrenWithDist = operator.children.zip(requiredChildDistributions)

// If necessary, add map-side aggregates
var (parent, children) = if (AggUtils.isAggregateExec(operator)) {
// If an aggregation needs a shuffle to satisfy its distribution, a map-side partial
// aggregation and a shuffle are added as children.
val (child, distribution) = childrenWithDist.head
if (!child.outputPartitioning.satisfies(distribution)) {
AggUtils.addMapSideAggregate(operator)
val (mergeAgg, mapSideAgg) = AggUtils.addMapSideAggregate(operator)
val newChild = ShuffleExchange(
createPartitioning(distribution, defaultNumPreShufflePartitions), mapSideAgg)
(mergeAgg, newChild :: Nil)
} else {
(operator, child :: Nil)
}
} else {
(operator, operator.children)
}

children = children.zip(requiredChildDistributions).map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
val newChildren = childrenWithDist.map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
}
(operator, newChildren)
}

// If the operator has multiple children and specifies child output distributions (e.g. join),
Expand Down