Skip to content

Commit 2228ee5

Browse files
peter-toth authored and cloud-fan committed
[SPARK-26572][SQL] fix aggregate codegen result evaluation
## What changes were proposed in this pull request? This PR is a correctness fix in `HashAggregateExec` code generation. It forces evaluation of result expressions before calling `consume()` to avoid multiple executions. This PR fixes a use case where an aggregate is nested into a broadcast join and appears on the "stream" side. The issue is that the broadcast join generates its own loop, and without forcing evaluation of the `resultExpressions` of `HashAggregateExec` before the join's loop, these expressions can be executed multiple times, giving incorrect results. ## How was this patch tested? A new unit test was added. Closes apache#23731 from peter-toth/SPARK-26572. Authored-by: Peter Toth <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent ac9c053 commit 2228ee5

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,18 @@ trait CodegenSupport extends SparkPlan {
290290
evaluateVars.toString()
291291
}
292292

293+
/**
294+
* Returns source code to evaluate the variables for non-deterministic expressions, and clear the
295+
* code of evaluated variables, to prevent them from being evaluated twice.
296+
*/
297+
protected def evaluateNondeterministicVariables(
298+
attributes: Seq[Attribute],
299+
variables: Seq[ExprCode],
300+
expressions: Seq[NamedExpression]): String = {
301+
val nondeterministicAttrs = expressions.filterNot(_.deterministic).map(_.toAttribute)
302+
evaluateRequiredVariables(attributes, variables, AttributeSet(nondeterministicAttrs))
303+
}
304+
293305
/**
294306
* The subset of inputSet those should be evaluated before this plan.
295307
*

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,10 +466,13 @@ case class HashAggregateExec(
466466
val resultVars = bindReferences[Expression](
467467
resultExpressions,
468468
inputAttrs).map(_.genCode(ctx))
469+
val evaluateNondeterministicResults =
470+
evaluateNondeterministicVariables(output, resultVars, resultExpressions)
469471
s"""
470472
$evaluateKeyVars
471473
$evaluateBufferVars
472474
$evaluateAggResults
475+
$evaluateNondeterministicResults
473476
${consume(ctx, resultVars)}
474477
"""
475478
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
@@ -506,10 +509,15 @@ case class HashAggregateExec(
506509
// generate result based on grouping key
507510
ctx.INPUT_ROW = keyTerm
508511
ctx.currentVars = null
509-
val eval = bindReferences[Expression](
512+
val resultVars = bindReferences[Expression](
510513
resultExpressions,
511514
groupingAttributes).map(_.genCode(ctx))
512-
consume(ctx, eval)
515+
val evaluateNondeterministicResults =
516+
evaluateNondeterministicVariables(output, resultVars, resultExpressions)
517+
s"""
518+
$evaluateNondeterministicResults
519+
${consume(ctx, resultVars)}
520+
"""
513521
}
514522
ctx.addNewFunction(funcName,
515523
s"""

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
2525
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
2626
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
2727
import org.apache.spark.sql.expressions.scalalang.typed
28-
import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max}
28+
import org.apache.spark.sql.functions._
2929
import org.apache.spark.sql.internal.SQLConf
3030
import org.apache.spark.sql.test.SharedSQLContext
3131
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -339,4 +339,32 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
339339

340340
checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
341341
}
342+
343+
test("SPARK-26572: evaluate non-deterministic expressions for aggregate results") {
344+
withSQLConf(
345+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString,
346+
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
347+
val baseTable = Seq(1, 1).toDF("idx")
348+
349+
// BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions
350+
val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id())
351+
.join(baseTable, "idx")
352+
assert(distinctWithId.queryExecution.executedPlan.collectFirst {
353+
case WholeStageCodegenExec(
354+
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true
355+
}.isDefined)
356+
checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0)))
357+
358+
// BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate
359+
// expression
360+
val groupByWithId =
361+
baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id())
362+
.join(baseTable, "idx")
363+
assert(groupByWithId.queryExecution.executedPlan.collectFirst {
364+
case WholeStageCodegenExec(
365+
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true
366+
}.isDefined)
367+
checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0)))
368+
}
369+
}
342370
}

0 commit comments

Comments
 (0)