Changes from 3 commits

@@ -290,6 +290,18 @@ trait CodegenSupport extends SparkPlan {
     evaluateVars.toString()
   }
 
+  /**
+   * Returns source code to evaluate the variables for non-deterministic expressions, and to
+   * clear the code of the evaluated variables, preventing them from being evaluated twice.
+   */
+  protected def evaluateNondeterministicVariables(

Contributor:

Nitpick on naming: "variables" are never non-deterministic, only expressions can have the property of being deterministic or not. Two options:

  • I'd prefer naming this utility function evaluateNondeterministicResults to emphasize this should (mostly) be used on the results of an output projection list.
  • But the existing utility function evaluateRequiredVariables uses the "variable" notion, so keeping consistency there is fine too.

I'm fine either way.

Also, historically Spark SQL's WSCG would use variable names like eval for the ExprCode type, e.g. evals: Seq[ExprCode]. Not sure why it started that way but you can see that naming pattern throughout the WSCG code base.
Again, your new utility function follows the same names used in evaluateRequiredVariables so that's fine. Local consistency is good enough.

Member:

To keep the naming consistent, +1 for evaluateNondeterministicVariables.

+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      expressions: Seq[NamedExpression]): String = {
+    val nondeterministicAttrs = expressions.filterNot(_.deterministic).map(_.toAttribute)
+    evaluateRequiredVariables(attributes, variables, AttributeSet(nondeterministicAttrs))
+  }
+
   /**
    * The subset of inputSet that should be evaluated before this plan.
    *
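For context, not part of this patch: the existing evaluateRequiredVariables helper that the new function delegates to is what implements the evaluate-once behavior. Paraphrased from the surrounding Spark source (details approximate), it ends with the two context lines visible at the top of the hunk above:

    protected def evaluateRequiredVariables(
        attributes: Seq[Attribute],
        variables: Seq[ExprCode],
        required: AttributeSet): String = {
      val evaluateVars = new StringBuilder
      variables.zipWithIndex.foreach { case (ev, i) =>
        // Emit the variable's computation once, then blank its code block so
        // any later consumer reuses the computed value instead of re-running
        // the expression.
        if (ev.code.nonEmpty && required.contains(attributes(i))) {
          evaluateVars.append(ev.code.toString + "\n")
          ev.code = EmptyBlock
        }
      }
      evaluateVars.toString()
    }

Clearing ev.code is what rules out double evaluation: ExprCode.code carries the generated Java statements, and once it is empty, downstream operators can only reference the already-assigned ev.value.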
@@ -466,10 +466,13 @@ case class HashAggregateExec(
       val resultVars = bindReferences[Expression](
         resultExpressions,
         inputAttrs).map(_.genCode(ctx))
+      val evaluateNondeterministicAggResults =
+        evaluateNondeterministicVariables(output, resultVars, resultExpressions)
       s"""
       $evaluateKeyVars
       $evaluateBufferVars
       $evaluateAggResults
+      $evaluateNondeterministicAggResults
       ${consume(ctx, resultVars)}
       """
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
@@ -506,10 +509,15 @@ case class HashAggregateExec(
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
       ctx.currentVars = null
-      val eval = bindReferences[Expression](
+      val resultVars = bindReferences[Expression](
         resultExpressions,
         groupingAttributes).map(_.genCode(ctx))
-      consume(ctx, eval)
+      val evaluateNondeterministicAggResults =
+        evaluateNondeterministicVariables(output, resultVars, resultExpressions)
+      s"""
+      $evaluateNondeterministicAggResults
+      ${consume(ctx, resultVars)}
+      """
     }
     ctx.addNewFunction(funcName,
       s"""
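To make the symptom concrete, a minimal sketch of the failure mode these hunks fix (hypothetical spark-shell session; the exact id values depend on partitioning):

    val base = Seq(1, 1).toDF("idx")
    val withId = base.distinct()
      .withColumn("id", monotonically_increasing_id())
      .join(base, "idx")
    withId.show()
    // Before the fix, the generated result function could splice the
    // monotonically_increasing_id code into more than one consumer, so the
    // two joined rows could report different ids for the single aggregated
    // row. After the fix, the expression is evaluated once before consume(),
    // and both rows carry the same id.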
26 changes: 26 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -30,11 +30,13 @@ import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.Uuid
+import org.apache.spark.sql.catalyst.expressions.aggregate.Final
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
 import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
@@ -2110,4 +2112,28 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       checkAnswer(res, Row("1-1", 6, 6))
     }
   }
+
+  test("SPARK-26572: fix aggregate codegen result evaluation") {

Contributor:

Since this is a problem with whole stage codegen, what about moving this test to WholeStageCodegenSuite? And adding an assert that whole stage codegen is actually used, i.e. the HashAggregate is a child of WholeStageCodegenExec?

Contributor (author):

I'm fine with moving it to WholeStageCodegenSuite but the plan looks like:

*(3) Project [idx#4, id#6L]
+- *(3) BroadcastHashJoin [idx#4], [idx#9], Inner, BuildRight
   :- *(3) HashAggregate(keys=[idx#4], functions=[], output=[idx#4, id#6L])
   :  +- Exchange hashpartitioning(idx#4, 1)
   :     +- *(1) HashAggregate(keys=[idx#4], functions=[], output=[idx#4])
   :        +- *(1) Project [value#1 AS idx#4]
   :           +- LocalTableScan [value#1]
   +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)))
      +- *(2) Project [value#1 AS idx#9]
         +- LocalTableScan [value#1]

so I guess you mean checking WholeStageCodegenExec has a ProjectExec child that has a BroadcastHashJoinExec child?

Contributor (author):

Moved and added WholeStageCodegenExec check.
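
For illustration only, one plausible shape of such a check, using the distinctWithId DataFrame from the test below (a hypothetical sketch; the merged test lives in WholeStageCodegenSuite and may differ):

    val plan = distinctWithId.queryExecution.executedPlan
    assert(plan.collectFirst {
      // require the broadcast join (and thus the aggregate feeding it) to sit
      // inside a whole-stage-codegen region
      case w: WholeStageCodegenExec
          if w.collectFirst { case b: BroadcastHashJoinExec => b }.isDefined => w
    }.isDefined)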

+    val baseTable = Seq((1), (1)).toDF("idx")
+
+    // BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions
+    val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id())

Contributor:

I'm not sure how stable the results are going to be if you use monotonically_increasing_id here with an unspecified number of shuffle partitions. Since you're checking the exact value of the resulting id, if the number of shuffle partitions changes (let's say if someone decides to change the default shuffle partitions setting in all tests), this test can become fragile and fail unnecessarily.

It might be worth setting the shuffle partition count to 1 explicitly inside this test case. Or go back to grouping by id instead of checking the exact value of id, or just assert that the ids are equal.

Member (@maropu, Feb 13, 2019):

Also, how about wrapping with withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) as a safeguard?

Contributor (author):

Thanks. Fixed both.
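
A sketch of how both suggestions could be applied together (hypothetical; the exact final form is not shown in this view of the diff):

    withSQLConf(
        SQLConf.SHUFFLE_PARTITIONS.key -> "1",
        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) {
      // A single shuffle partition keeps monotonically_increasing_id values
      // stable, and the high threshold keeps the join a broadcast hash join.
      // ... test body as in the diff below ...
    }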

+      .join(baseTable, "idx")
+    assert(distinctWithId.queryExecution.executedPlan.collectFirst {
+      case BroadcastHashJoinExec(_, _, _, _, _, HashAggregateExec(_, _, Seq(), _, _, _, _), _) =>

Member:

How about this?

    assert(distinctWithId.queryExecution.executedPlan.collectFirst {
      case j: BroadcastHashJoinExec if j.left.isInstanceOf[HashAggregateExec] => true
    }.isDefined)

Do we need to strictly check aggregate exprs? It seems baseTable.distinct() obviously has no aggregate expr?

Contributor (author, @peter-toth, Feb 13, 2019):

I prefer avoiding isInstanceOf if possible, but changed it a bit.

+        true
+    }.isDefined)
+    checkAnswer(distinctWithId, Seq(Row(1, 25769803776L), Row(1, 25769803776L)))
+
+    // BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate
+    // expression
+    val groupByWithId =
+      baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id())
+        .join(baseTable, "idx")
+    assert(groupByWithId.queryExecution.executedPlan.collectFirst {
+      case BroadcastHashJoinExec(_, _, _, _, _, HashAggregateExec(_, _, ae, _, _, _, _), _)
+          if ae.exists(_.mode == Final) => true
+    }.isDefined)
+    checkAnswer(groupByWithId, Seq(Row(1, 2, 25769803776L), Row(1, 2, 25769803776L)))
+  }
 }
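
A note on the expected ids, for readers puzzled by the magic number (my arithmetic, not from the PR): monotonically_increasing_id packs the partition ID into the upper bits and a per-partition record counter into the lower 33 bits, so 25769803776L decodes as the first record of partition 3 under this test session's shuffle-partition setting:

    scala> 3L << 33
    res0: Long = 25769803776

Both result rows carry the same id because the aggregate emits a single distinct row, which the join then duplicates; that shared id is exactly what the test asserts, and the dependence on which partition the row lands in is why the reviewers flagged the hard-coded value as fragile.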