diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index d3a93f5eb395..443ce8cc46a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -290,6 +290,18 @@ trait CodegenSupport extends SparkPlan { evaluateVars.toString() } + /** + * Returns source code to evaluate the variables for non-deterministic expressions, and clear the + * code of evaluated variables, to prevent them to be evaluated twice. + */ + protected def evaluateNondeterministicVariables( + attributes: Seq[Attribute], + variables: Seq[ExprCode], + expressions: Seq[NamedExpression]): String = { + val nondeterministicAttrs = expressions.filterNot(_.deterministic).map(_.toAttribute) + evaluateRequiredVariables(attributes, variables, AttributeSet(nondeterministicAttrs)) + } + /** * The subset of inputSet those should be evaluated before this plan. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 19a47ffc6dd0..17cc7fde42bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -466,10 +466,13 @@ case class HashAggregateExec( val resultVars = bindReferences[Expression]( resultExpressions, inputAttrs).map(_.genCode(ctx)) + val evaluateNondeterministicResults = + evaluateNondeterministicVariables(output, resultVars, resultExpressions) s""" $evaluateKeyVars $evaluateBufferVars $evaluateAggResults + $evaluateNondeterministicResults ${consume(ctx, resultVars)} """ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { @@ -506,10 +509,15 @@ case class HashAggregateExec( // generate result based on grouping key ctx.INPUT_ROW = keyTerm ctx.currentVars = null - val eval = bindReferences[Expression]( + val resultVars = bindReferences[Expression]( resultExpressions, groupingAttributes).map(_.genCode(ctx)) - consume(ctx, eval) + val evaluateNondeterministicResults = + evaluateNondeterministicVariables(output, resultVars, resultExpressions) + s""" + $evaluateNondeterministicResults + ${consume(ctx, resultVars)} + """ } ctx.addNewFunction(funcName, s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index e03f08417162..3c9a0908147a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed -import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -339,4 +339,32 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row(1, 3), Row(2, 3))) } + + test("SPARK-26572: evaluate non-deterministic expressions for aggregate results") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val baseTable = Seq(1, 1).toDF("idx") + + // BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions + val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id()) + .join(baseTable, "idx") + assert(distinctWithId.queryExecution.executedPlan.collectFirst { + case WholeStageCodegenExec( + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true + }.isDefined) + checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0))) + + // BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate + // expression + val groupByWithId = + baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id()) + .join(baseTable, "idx") + assert(groupByWithId.queryExecution.executedPlan.collectFirst { + case WholeStageCodegenExec( + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true + }.isDefined) + checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0))) + } + } }