sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -290,6 +290,18 @@ trait CodegenSupport extends SparkPlan {
    evaluateVars.toString()
  }
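For orientation: the helper whose tail appears above, evaluateRequiredVariables, is what the new function below delegates to. A self-contained sketch of its behavior (with plain String/Set stand-ins for Spark's Attribute, ExprCode, and AttributeSet types; not the verbatim Spark source):

    // Sketch: an ExprCode-like holder of generated code plus the result variable.
    case class ExprCodeSketch(var code: String, value: String)

    def evaluateRequiredVariablesSketch(
        attributes: Seq[String],            // stand-in for Seq[Attribute]
        variables: Seq[ExprCodeSketch],     // stand-in for Seq[ExprCode]
        required: Set[String]): String = {  // stand-in for AttributeSet
      val evaluated = new StringBuilder
      variables.zip(attributes).foreach { case (ev, attr) =>
        if (ev.code.nonEmpty && required.contains(attr)) {
          evaluated.append(ev.code).append('\n')
          ev.code = "" // clear it, so the same code is never emitted twice downstream
        }
      }
      evaluated.toString
    }

    // E.g. forcing only the (hypothetical) non-deterministic "id" slot:
    val vars = Seq(
      ExprCodeSketch("long v0 = nextMonotonicId();", "v0"), // non-deterministic
      ExprCodeSketch("int v1 = idx + 1;", "v1"))            // deterministic
    val forced = evaluateRequiredVariablesSketch(Seq("id", "plusOne"), vars, Set("id"))
    // forced == "long v0 = nextMonotonicId();\n", and vars.head.code is now empty.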

+  /**
+   * Returns source code to evaluate the variables for non-deterministic expressions, and clear
+   * the code of evaluated variables, to prevent them from being evaluated twice.
+   */
+  protected def evaluateNondeterministicVariables(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      expressions: Seq[NamedExpression]): String = {
+    val nondeterministicAttrs = expressions.filterNot(_.deterministic).map(_.toAttribute)
+    evaluateRequiredVariables(attributes, variables, AttributeSet(nondeterministicAttrs))
+  }

Review comment (Contributor), on evaluateNondeterministicVariables:

Nitpick on naming: "variables" are never non-deterministic; only expressions can have the property of being deterministic or not. Two options:

  • I'd prefer naming this utility function evaluateNondeterministicResults, to emphasize that it should (mostly) be used on the results of an output projection list.
  • But the existing utility function evaluateRequiredVariables uses the "variable" notion, so keeping consistency there is fine too.

I'm fine either way.

Also, historically Spark SQL's WSCG would use variable names like eval for the ExprCode type, e.g. evals: Seq[ExprCode]. Not sure why it started that way, but you can see that naming pattern throughout the WSCG code base. Again, your new utility function follows the same names used in evaluateRequiredVariables, so that's fine. Local consistency is good enough.

Reply (Member):

To keep the naming consistent, +1 for evaluateNondeterministicVariables.
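Stepping back from the naming discussion, here is a minimal repro of SPARK-26572 in a spark-shell (a hedged sketch; assumes a session where spark.implicits._ is in scope, and mirrors the regression test added below). Before this patch, whole-stage codegen could re-evaluate the non-deterministic expression once per consumer, so the duplicate rows could receive different ids:

    import org.apache.spark.sql.functions._

    val base = Seq(1, 1).toDF("idx")
    val withId = base.distinct().withColumn("id", monotonically_increasing_id())
    // The join makes the aggregate's output be consumed inside generated code;
    // both output rows should carry the id computed once per distinct key.
    withId.join(base, "idx").show()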

  /**
   * The subset of inputSet that should be evaluated before this plan.
   *
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -466,10 +466,13 @@ case class HashAggregateExec(
       val resultVars = bindReferences[Expression](
         resultExpressions,
         inputAttrs).map(_.genCode(ctx))
+      val evaluateNondeterministicResults =
+        evaluateNondeterministicVariables(output, resultVars, resultExpressions)
       s"""
       $evaluateKeyVars
       $evaluateBufferVars
       $evaluateAggResults
+      $evaluateNondeterministicResults
       ${consume(ctx, resultVars)}
       """
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
@@ -506,10 +509,15 @@
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
       ctx.currentVars = null
-      val eval = bindReferences[Expression](
+      val resultVars = bindReferences[Expression](
         resultExpressions,
         groupingAttributes).map(_.genCode(ctx))
-      consume(ctx, eval)
+      val evaluateNondeterministicResults =
+        evaluateNondeterministicVariables(output, resultVars, resultExpressions)
+      s"""
+      $evaluateNondeterministicResults
+      ${consume(ctx, resultVars)}
+      """
     }
     ctx.addNewFunction(funcName,
       s"""
sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
-import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -339,4 +339,32 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {

     checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
   }

test("SPARK-26572: evaluate non-deterministic expressions for aggregate results") {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString,
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val baseTable = Seq(1, 1).toDF("idx")

// BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions
val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
assert(distinctWithId.queryExecution.executedPlan.collectFirst {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true
}.isDefined)
checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0)))

// BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate
// expression
val groupByWithId =
baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
assert(groupByWithId.queryExecution.executedPlan.collectFirst {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true
}.isDefined)
checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0)))
}
}
}
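To inspect the generated Java this suite exercises, Spark's codegen debug helpers can be used (a hedged sketch; assumes a spark-shell session with spark.implicits._ in scope):

    import org.apache.spark.sql.execution.debug._
    import org.apache.spark.sql.functions._

    val base = Seq(1, 1).toDF("idx")
    base.distinct().withColumn("id", monotonically_increasing_id())
      .join(base, "idx")
      .debugCodegen() // prints each WholeStageCodegen subtree with its generated source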