Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix issue at merge aggregation phase
  • Loading branch information
gengliangwang committed Aug 11, 2020
commit a8be9e1a093c1d221e07a6156a1a408d477621f2
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
}

override lazy val mergeExpressions: Seq[Expression] = {
val sumWithRight = resultType match {
case d: DecimalType =>
CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow = false)

case _ => coalesce(sum.left, zero) + sum.right
}
Seq(
/* sum = */
coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
coalesce(sumWithRight, sum.left)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1046,27 +1046,40 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(sql(queryTemplate("LAST")), Row(3))
}

test("SPARK-32018: Throw exception on decimal overflow") {
private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
val msg = intercept[SparkException] {
df.collect()
}.getCause.getMessage
assert(msg.contains("cannot be represented as Decimal(38, 18)"))
}

test("SPARK-32018: Throw exception on decimal overflow at partial aggregate phase") {
val decimalString = "1" + "0" * 19
val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
val hashAgg = union
.select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
.groupBy("key")
.agg(sum($"d").alias("sumD"))
.select($"sumD")
var msg = intercept[SparkException] {
hashAgg.collect()
}.getCause.getMessage
assert(msg.contains("cannot be represented as Decimal(38, 18)"))
exceptionOnDecimalOverflow(hashAgg)

val sortAgg = union
.select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("a").as("str"),
lit("1").as("key")).groupBy("key")
.agg(sum($"d").alias("sumD"), min($"str").alias("minStr")).select($"sumD", $"minStr")
msg = intercept[SparkException] {
sortAgg.collect()
}.getCause.getMessage
assert(msg.contains("cannot be represented as Decimal(38, 18)"))
exceptionOnDecimalOverflow(sortAgg)
}

test("SPARK-32018: Throw exception on decimal overflow at merge aggregation phase") {
val decimalString = "5" + "0" * 19
val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1))
.union(spark.range(0, 1, 1, 1))
val agg = union
.select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
.groupBy("key")
.agg(sum($"d").alias("sumD"))
.select($"sumD")
exceptionOnDecimalOverflow(agg)
}
}

Expand Down