Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add tests
  • Loading branch information
maropu committed Aug 22, 2016
commit 74a14d7cb1db8fc0a5dcd18453f7151ba8edb22d
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,3 @@ case class SortAggregateExec(
}
}
}

object SortAggregateExec

Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
if (!child.outputPartitioning.satisfies(distribution)) {
if (AggUtils.supportPartialAggregate(operator)) {
// If an aggregation needs a shuffle and support partial aggregations, a map-side partial
// an aggregation and a shuffle are added as children.
// aggregation and a shuffle are added as children.
val (mergeAgg, mapSideAgg) = AggUtils.createPartialAggregate(operator)
(mergeAgg, createShuffleExchange(distribution, mapSideAgg) :: Nil)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,24 +70,31 @@ class PlannerSuite extends SharedSQLContext {
s"The plan of query $query does not have partial aggregations.")
}

test("non-partial aggregation for distinct aggregates") {
test("non-partial aggregation for aggregates") {
withTempView("testNonPartialAggregation") {
val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
val row = Row.fromSeq(Seq.fill(1)(null))
val rowRDD = sparkContext.parallelize(row :: Nil)
spark.createDataFrame(rowRDD, schema).createOrReplaceTempView("testNonPartialAggregation")
spark.createDataFrame(rowRDD, schema).repartition($"value")
.createOrReplaceTempView("testNonPartialAggregation")

val planned = sql(
val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
.queryExecution.executedPlan

// If input data are already partitioned and the same columns are used in grouping keys and
// aggregation values, no partial aggregation exist in query plans.
val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")

val planned2 = sql(
"""
|SELECT t.value, SUM(DISTINCT t.value)
|FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
|GROUP BY t.value
""".stripMargin).queryExecution.executedPlan

// If input data are already partitioned and the same columns are used in grouping keys and
// aggregation values, no partial aggregation exist in query plans.
val aggOps = planned.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps.size == 2, s"The plan $planned has partial aggregations.")
val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
}
}

Expand Down