diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 87f1a4f02e4f..82b184dadebd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -57,6 +57,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } + private lazy val wrongResultDueToOverflow = false private lazy val sumDataType = resultType private lazy val sum = AttributeReference("sum", sumDataType)() @@ -73,7 +74,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast if (child.nullable) { Seq( /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + resultType match { + case d: DecimalType => coalesce( + CheckOverflow( + coalesce(sum, zero) + child.cast(sumDataType), d, wrongResultDueToOverflow), + sum) + case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + } ) } else { Seq( @@ -86,7 +93,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + resultType match { + case d: DecimalType => coalesce( + CheckOverflow(coalesce(sum.left, zero) + sum.right, d, wrongResultDueToOverflow), + sum.left) + case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d2d58a83ded5..c81fef0a691a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -169,18 +169,14 @@ class DataFrameSuite extends QueryTest DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF() - Seq(true, false).foreach { ansiEnabled => - withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + Seq("true", "false").foreach { codegenEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegenEnabled)) { val structDf = largeDecimals.select("a").agg(sum("a")) - if (!ansiEnabled) { - checkAnswer(structDf, Row(null)) - } else { - val e = intercept[SparkException] { - structDf.collect - } - assert(e.getCause.getClass.equals(classOf[ArithmeticException])) - assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + val e = intercept[SparkException] { + structDf.collect } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 11f9724e587f..981f41194090 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3319,6 +3319,32 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } + test("SPARK-28067 - Aggregate sum should not return wrong results") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + val df = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum")) + val e = intercept[SparkException] { + df2.collect() + } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + } + } + } + test("SPARK-29213: FilterExec should not throw NPE") { withTempView("t1", "t2", "t3") { sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")