fix review findings
peter-toth committed Feb 13, 2019
commit 5ae9add508a9341c1ca781ffd54598d560d16ed1
26 changes: 0 additions & 26 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -30,13 +30,11 @@ import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.expressions.aggregate.Final
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
@@ -2112,28 +2110,4 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(res, Row("1-1", 6, 6))
}
}

test("SPARK-26572: fix aggregate codegen result evaluation") {
val baseTable = Seq((1), (1)).toDF("idx")

// BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions
val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
assert(distinctWithId.queryExecution.executedPlan.collectFirst {
case BroadcastHashJoinExec(_, _, _, _, _, HashAggregateExec(_, _, Seq(), _, _, _, _), _) =>
true
}.isDefined)
checkAnswer(distinctWithId, Seq(Row(1, 25769803776L), Row(1, 25769803776L)))

// BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate
// expression
val groupByWithId =
baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
assert(groupByWithId.queryExecution.executedPlan.collectFirst {
case BroadcastHashJoinExec(_, _, _, _, _, HashAggregateExec(_, _, ae, _, _, _, _), _)
if ae.exists(_.mode == Final) => true
}.isDefined)
checkAnswer(groupByWithId, Seq(Row(1, 2, 25769803776L), Row(1, 2, 25769803776L)))
}
}
sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -339,4 +339,32 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {

checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
}

test("SPARK-26572: evaluate non-deterministic expressions for aggregate results") {
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString,
SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val baseTable = Seq(1, 1).toDF("idx")

// BroadcastHashJoinExec with a HashAggregateExec child containing no aggregate expressions
val distinctWithId = baseTable.distinct().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
assert(distinctWithId.queryExecution.executedPlan.collectFirst {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true
}.isDefined)
checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0)))

// BroadcastHashJoinExec with a HashAggregateExec child containing a Final mode aggregate
// expression
val groupByWithId =
baseTable.groupBy("idx").sum().withColumn("id", monotonically_increasing_id())
.join(baseTable, "idx")
assert(groupByWithId.queryExecution.executedPlan.collectFirst {
case WholeStageCodegenExec(
ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true
}.isDefined)
checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0)))
}
}
}
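
For context, below is a minimal spark-shell style sketch (not part of this commit) of the scenario the moved test covers: distinct() plans a HashAggregateExec whose output feeds the build side of a broadcast hash join, and a non-deterministic column (monotonically_increasing_id) is added on top of the aggregate output. The SparkSession setup, app name, and explain/show calls are illustrative assumptions; the shuffle-partition setting mirrors the SQLConf used in the test above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.monotonically_increasing_id

// Illustrative local session; one shuffle partition keeps the generated ids stable,
// matching SQLConf.SHUFFLE_PARTITIONS in the test.
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("SPARK-26572-sketch")
  .config("spark.sql.shuffle.partitions", "1")
  .getOrCreate()
import spark.implicits._

val baseTable = Seq(1, 1).toDF("idx")

// distinct() plans a HashAggregateExec with no aggregate expressions; the id column
// must be evaluated on its output before the broadcast hash join consumes the rows.
val distinctWithId = baseTable.distinct()
  .withColumn("id", monotonically_increasing_id())
  .join(baseTable, "idx")

distinctWithId.explain() // expected to show a BroadcastHashJoin over a HashAggregate
distinctWithId.show()    // expected: two rows with idx = 1 sharing the same generated id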