diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d2daaac72fc8..d442549f20e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -71,23 +71,39 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   )
 
+  /**
+   * Wraps `partialSum` in [[CheckOverflow]] when the result type is a decimal; any other
+   * result type is returned unchanged.
+   *
+   * `nullOnOverflow` stays false so that an overflow raises an error. Where callers wrap
+   * this value in `coalesce(..., <previous sum>)`, a null result would silently fall back
+   * to the previous buffer value.
+   */
+  private def checkDecimalOverflow(partialSum: Expression): Expression = resultType match {
+    case d: DecimalType => CheckOverflow(partialSum, d, nullOnOverflow = false)
+    case _ => partialSum
+  }
+
   override lazy val updateExpressions: Seq[Expression] = {
+    val sumWithChild = checkDecimalOverflow(coalesce(sum, zero) + child.cast(sumDataType))
+
     if (child.nullable) {
       Seq(
         /* sum = */
-        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
+        coalesce(sumWithChild, sum)
       )
     } else {
       Seq(
         /* sum = */
-        coalesce(sum, zero) + child.cast(sumDataType)
+        sumWithChild
       )
     }
   }
 
   override lazy val mergeExpressions: Seq[Expression] = {
+    val sumWithRight = checkDecimalOverflow(coalesce(sum.left, zero) + sum.right)
     Seq(
       /* sum = */
-      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
+      coalesce(sumWithRight, sum.left)
     )
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 54327b38c100..8c0358e205b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.Matchers.the
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1044,6 +1045,61 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
+
+  /**
+   * Asserts that collecting `df` fails with a [[SparkException]] whose cause message says
+   * the value "cannot be represented as Decimal(38, 18)".
+   */
+  private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
+    val e = intercept[SparkException] {
+      df.collect()
+    }
+    // getCause/getMessage can be null; wrap them so that case fails the assertion with a
+    // readable clue instead of a NullPointerException.
+    val causeMsg = Option(e.getCause).flatMap(c => Option(c.getMessage))
+    assert(causeMsg.exists(_.contains("cannot be represented as Decimal(38, 18)")),
+      s"Unexpected failure: $e")
+  }
+
+  test("SPARK-32018: Throw exception on decimal overflow at partial aggregate phase") {
+    // Each row is 10^19. decimal(38, 18) has 20 integer digits (max < 10^20), so summing
+    // the 11 rows of the second range already overflows.
+    val decimalString = "1" + "0" * 19
+    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
+    val hashAgg = union
+      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
+      .groupBy("key")
+      .agg(sum($"d").alias("sumD"))
+      .select($"sumD")
+    exceptionOnDecimalOverflow(hashAgg)
+
+    // NOTE(review): the extra string min() presumably steers planning to a sort-based
+    // aggregate, as the name suggests -- confirm.
+    val sortAgg = union
+      .select(
+        expr(s"cast('$decimalString' as decimal (38, 18)) as d"),
+        lit("a").as("str"),
+        lit("1").as("key"))
+      .groupBy("key")
+      .agg(sum($"d").alias("sumD"), min($"str").alias("minStr"))
+      .select($"sumD", $"minStr")
+    exceptionOnDecimalOverflow(sortAgg)
+  }
+
+  test("SPARK-32018: Throw exception on decimal overflow at merge aggregation phase") {
+    // Each row is 5 * 10^19, which fits on its own; only the sum across the three
+    // single-row ranges exceeds what decimal(38, 18) can hold.
+    val decimalString = "5" + "0" * 19
+    val union = spark.range(0, 1, 1, 1)
+      .union(spark.range(0, 1, 1, 1))
+      .union(spark.range(0, 1, 1, 1))
+    val agg = union
+      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
+      .groupBy("key")
+      .agg(sum($"d").alias("sumD"))
+      .select($"sumD")
+    exceptionOnDecimalOverflow(agg)
+  }
 }
 
 case class B(c: Option[Double])