From 7aae81c73873b37494e6f6d393717aad068a4a9f Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 11 Feb 2020 13:29:29 -0800 Subject: [PATCH] Fix incorrect results during aggregate sum for decimal when there is overflow: throw an exception, making the behavior consistent with the case when whole-stage codegen is disabled. Also fix the affected test from SPARK-28224 --- .../catalyst/expressions/aggregate/Sum.scala | 16 ++++++++++-- .../org/apache/spark/sql/DataFrameSuite.scala | 16 +++++------- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 +++++++++++++++++++ 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 87f1a4f02e4f..82b184dadebd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -57,6 +57,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } + private lazy val wrongResultDueToOverflow = false private lazy val sumDataType = resultType private lazy val sum = AttributeReference("sum", sumDataType)() @@ -73,7 +74,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast if (child.nullable) { Seq( /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + resultType match { + case d: DecimalType => coalesce( + CheckOverflow( + coalesce(sum, zero) + child.cast(sumDataType), d, wrongResultDueToOverflow), + sum) + case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + } ) } else { Seq( @@ -86,7 +93,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) 
+ resultType match { + case d: DecimalType => coalesce( + CheckOverflow(coalesce(sum.left, zero) + sum.right, d, wrongResultDueToOverflow), + sum.left) + case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d2d58a83ded5..c81fef0a691a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -169,18 +169,14 @@ class DataFrameSuite extends QueryTest DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF() - Seq(true, false).foreach { ansiEnabled => - withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + Seq("true", "false").foreach { codegenEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegenEnabled)) { val structDf = largeDecimals.select("a").agg(sum("a")) - if (!ansiEnabled) { - checkAnswer(structDf, Row(null)) - } else { - val e = intercept[SparkException] { - structDf.collect - } - assert(e.getCause.getClass.equals(classOf[ArithmeticException])) - assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + val e = intercept[SparkException] { + structDf.collect } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 11f9724e587f..981f41194090 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3319,6 +3319,32 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } + 
test("SPARK-28067 - Aggregate sum should not return wrong results") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + val df = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum")) + val e = intercept[SparkException] { + df2.collect() + } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + } + } + } + test("SPARK-29213: FilterExec should not throw NPE") { withTempView("t1", "t2", "t3") { sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")