diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d2daaac72fc8..6e850267100f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -62,38 +62,74 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
+  private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()
+
   private lazy val zero = Literal.default(sumDataType)
 
-  override lazy val aggBufferAttributes = sum :: Nil
+  override lazy val aggBufferAttributes = resultType match {
+    case _: DecimalType => sum :: isEmpty :: Nil
+    case _ => sum :: Nil
+  }
 
-  override lazy val initialValues: Seq[Expression] = Seq(
-    /* sum = */ Literal.create(null, sumDataType)
-  )
+  override lazy val initialValues: Seq[Expression] = resultType match {
+    case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType))
+    case _ => Seq(Literal(null, resultType))
+  }
 
   override lazy val updateExpressions: Seq[Expression] = {
     if (child.nullable) {
-      Seq(
-        /* sum = */
-        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
-      )
+      val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
+      resultType match {
+        case _: DecimalType =>
+          Seq(updateSumExpr, isEmpty && child.isNull)
+        case _ => Seq(updateSumExpr)
+      }
     } else {
-      Seq(
-        /* sum = */
-        coalesce(sum, zero) + child.cast(sumDataType)
-      )
+      val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType)
+      resultType match {
+        case _: DecimalType =>
+          Seq(updateSumExpr, Literal(false, BooleanType))
+        case _ => Seq(updateSumExpr)
+      }
     }
   }
 
+  /**
+   * For decimal type:
+   * If isEmpty is false and sum is null, an overflow has occurred.
+   *
+   * The sum is merged as follows:
+   * Check whether either sum.left or sum.right has overflowed.
+   * If so, the merged sum remains null.
+   * If neither has overflowed, add sum.left and sum.right.
+   *
+   * isEmpty: set to false if either the left or the right side is false, meaning
+   * we have seen at least one value that was not null.
+   */
   override lazy val mergeExpressions: Seq[Expression] = {
-    Seq(
-      /* sum = */
-      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
-    )
+    val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
+    resultType match {
+      case _: DecimalType =>
+        val inputOverflow = !isEmpty.right && sum.right.isNull
+        val bufferOverflow = !isEmpty.left && sum.left.isNull
+        Seq(
+          If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr),
+          isEmpty.left && isEmpty.right)
+      case _ => Seq(mergeSumExpr)
+    }
  }
 
+  /**
+   * If isEmpty is true, there were no values to begin with or all the values were null,
+   * so the result is null.
+   * If isEmpty is false and sum is null, an overflow has occurred: if ANSI mode is
+   * enabled, throw an exception; otherwise return null.
+   * If sum is not null, return the sum.
+   */
   override lazy val evaluateExpression: Expression = resultType match {
-    case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled)
+    case d: DecimalType =>
+      If(isEmpty, Literal.create(null, sumDataType),
+        CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
     case _ => sum
   }
-
 }
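To see why the buffer needs the extra `isEmpty` slot, here is a minimal plain-Scala model of the `(sum, isEmpty)` semantics this patch introduces. Everything Spark-specific is stripped out: `Buffer`, `checkedAdd`, and the `MAX` bound are illustrative stand-ins (real overflow detection goes through `Decimal` precision checks), so this is a sketch of the intended semantics, not the actual implementation.

```scala
object SumBufferModel {
  // Hypothetical overflow bound standing in for decimal(38, 18)'s range.
  val MAX = BigDecimal("9" * 20)

  // sum = None encodes SQL NULL; isEmpty tracks "no non-null input seen yet".
  case class Buffer(sum: Option[BigDecimal], isEmpty: Boolean)
  val init = Buffer(None, isEmpty = true)

  // Stand-in for decimal addition that yields null on overflow (non-ANSI mode).
  private def checkedAdd(a: BigDecimal, b: BigDecimal): Option[BigDecimal] = {
    val s = a + b
    if (s.abs > MAX) None else Some(s)
  }

  // updateExpressions: null inputs leave the buffer alone; non-null inputs
  // add into coalesce(sum, zero) and clear isEmpty.
  def update(b: Buffer, input: Option[BigDecimal]): Buffer = input match {
    case None    => b
    case Some(v) => Buffer(checkedAdd(b.sum.getOrElse(BigDecimal(0)), v), isEmpty = false)
  }

  // mergeExpressions: a side with isEmpty = false but sum = null has already
  // overflowed, so the merged sum must stay null; otherwise the sides combine,
  // with an all-null (isEmpty = true) side contributing nothing.
  def merge(l: Buffer, r: Buffer): Buffer = {
    val overflowed = (!l.isEmpty && l.sum.isEmpty) || (!r.isEmpty && r.sum.isEmpty)
    val mergedSum =
      if (overflowed) None
      else (l.sum, r.sum) match {
        case (Some(a), Some(b)) => checkedAdd(a, b)
        case (a, b)             => a.orElse(b)
      }
    Buffer(mergedSum, l.isEmpty && r.isEmpty)
  }

  // evaluateExpression: empty means a NULL result; non-empty with a null sum
  // means overflow (NULL here; ANSI mode throws instead).
  def evaluate(b: Buffer): Option[BigDecimal] =
    if (b.isEmpty) None else b.sum
}
```

Without `isEmpty`, the old merge expression `coalesce(coalesce(sum.left, zero) + sum.right, sum.left)` could not tell an overflowed partition (sum = null) from an all-null one, so `coalesce` would silently restart the overflowed side from `zero` and produce a wrong non-null total, which is the wrong-results bug the new tests below exercise.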
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 9014ebfe2f96..9f0408a380f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -144,3 +145,54 @@ case class CheckOverflow(
 
   override def sql: String = child.sql
 }
+
+// A variant of `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
+case class CheckOverflowInSum(
+    child: Expression,
+    dataType: DecimalType,
+    nullOnOverflow: Boolean) extends UnaryExpression {
+
+  override def nullable: Boolean = true
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.")
+    } else {
+      value.asInstanceOf[Decimal].toPrecision(
+        dataType.precision,
+        dataType.scale,
+        Decimal.ROUND_HALF_UP,
+        nullOnOverflow)
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val nullHandling = if (nullOnOverflow) {
+      ""
+    } else {
+      s"""
+         |throw new ArithmeticException("Overflow in sum of decimals.");
+         |""".stripMargin
+    }
+    val code = code"""
+      |${childGen.code}
+      |boolean ${ev.isNull} = ${childGen.isNull};
+      |Decimal ${ev.value} = null;
+      |if (${childGen.isNull}) {
+      |  $nullHandling
+      |} else {
+      |  ${ev.value} = ${childGen.value}.toPrecision(
+      |    ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
+      |  ${ev.isNull} = ${ev.value} == null;
+      |}
+      |""".stripMargin
+
+    ev.copy(code = code)
+  }
+
+  override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"
+
+  override def sql: String = child.sql
+}
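The key behavioral difference from plain `CheckOverflow` is the `value == null` branch above: for `Sum`'s buffer, null no longer means "unknown" but "overflowed". A dependency-free sketch of that decision table, with `fitsTarget` as a hypothetical stand-in for `Decimal.toPrecision(precision, scale, ROUND_HALF_UP, nullOnOverflow)`:

```scala
object CheckOverflowInSumModel {
  import scala.math.BigDecimal.RoundingMode

  val precision = 38
  val scale = 18

  // Rounds to the target scale and returns None when the result needs more
  // digits than the target precision allows.
  private def fitsTarget(v: BigDecimal): Option[BigDecimal] = {
    val rounded = v.setScale(scale, RoundingMode.HALF_UP)
    if (rounded.precision <= precision) Some(rounded) else None
  }

  def eval(value: Option[BigDecimal], nullOnOverflow: Boolean): Option[BigDecimal] =
    value match {
      // Unlike CheckOverflow, a null input is itself treated as an overflow
      // signal, because Sum's merge nulls the buffer exactly on overflow.
      case None if nullOnOverflow => None
      case None => throw new ArithmeticException("Overflow in sum of decimals.")
      // Non-null values go through the usual precision check; in ANSI mode an
      // unrepresentable value is surfaced as an exception here too (Spark's
      // own toPrecision uses a different message).
      case Some(v) =>
        val r = fitsTarget(v)
        if (r.isEmpty && !nullOnOverflow) {
          throw new ArithmeticException("Overflow in sum of decimals.")
        }
        r
    }
}
```

The generated-code path in `doGenCode` implements the same table, inlining the throw when `nullOnOverflow` is false.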
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f20e684bf765..bbcb9df45550 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -192,6 +192,28 @@ class DataFrameSuite extends QueryTest
     structDf.select(xxhash64($"a", $"record.*")))
   }
 
+  private def assertDecimalSumOverflow(
+      df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
+    if (!ansiEnabled) {
+      try {
+        checkAnswer(df, expectedAnswer)
+      } catch {
+        case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] =>
+          // This is an existing bug: we can write an overflowed decimal to UnsafeRow but
+          // fail to read it back.
+          assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
+      }
+    } else {
+      val e = intercept[SparkException] {
+        df.collect
+      }
+      assert(e.getCause.isInstanceOf[ArithmeticException])
+      assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
+        e.getCause.getMessage.contains("Overflow in sum of decimals") ||
+        e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
+    }
+  }
+
   test("SPARK-28224: Aggregate sum big decimal overflow") {
     val largeDecimals = spark.sparkContext.parallelize(
       DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
@@ -200,14 +222,90 @@ class DataFrameSuite extends QueryTest
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
         val structDf = largeDecimals.select("a").agg(sum("a"))
-        if (!ansiEnabled) {
-          checkAnswer(structDf, Row(null))
-        } else {
-          val e = intercept[SparkException] {
-            structDf.collect
+        assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
+      }
+    }
+  }
+
+  test("SPARK-28067: sum of null decimal values") {
+    Seq("true", "false").foreach { wholeStageEnabled =>
+      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
+        Seq("true", "false").foreach { ansiEnabled =>
+          withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) {
+            val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
+            checkAnswer(df.agg(sum($"d")), Row(null))
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
+    Seq("true", "false").foreach { wholeStageEnabled =>
+      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
+        Seq(true, false).foreach { ansiEnabled =>
+          withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+            val df0 = Seq(
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+            val df1 = Seq(
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2),
+              (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+            val df = df0.union(df1)
+            val df2 = df.withColumnRenamed("decNum", "decNum2").
+              join(df, "intNum").agg(sum("decNum"))
+
+            val expectedAnswer = Row(null)
+            assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
+
+            val decStr = "1" + "0" * 19
+            val d1 = spark.range(0, 12, 1, 1)
+            val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+            assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
+
+            val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
+            val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+            assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
+
+            val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
+              lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
+            assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
+
+            val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
+
+            val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
+              toDF("d")
+            assertDecimalSumOverflow(
+              nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)
+
+            val df3 = Seq(
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("50000000000000000000"), 1),
+              (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+
+            val df4 = Seq(
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+
+            val df5 = Seq(
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("10000000000000000000"), 1),
+              (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")
+
+            val df6 = df3.union(df4).union(df5)
+            val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
+              filter("intNum == 1")
+            assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
           }
-          assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
-          assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
         }
       }
     }