Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,38 +62,74 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast

private lazy val sum = AttributeReference("sum", sumDataType)()

private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()

private lazy val zero = Literal.default(sumDataType)

override lazy val aggBufferAttributes = sum :: Nil
override lazy val aggBufferAttributes = resultType match {
case _: DecimalType => sum :: isEmpty :: Nil
case _ => sum :: Nil
}

override lazy val initialValues: Seq[Expression] = Seq(
/* sum = */ Literal.create(null, sumDataType)
)
override lazy val initialValues: Seq[Expression] = resultType match {
case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType))
case _ => Seq(Literal(null, resultType))
}

override lazy val updateExpressions: Seq[Expression] = {
if (child.nullable) {
Seq(
/* sum = */
coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
)
val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
resultType match {
case _: DecimalType =>
Seq(updateSumExpr, isEmpty && child.isNull)
case _ => Seq(updateSumExpr)
}
} else {
Seq(
/* sum = */
coalesce(sum, zero) + child.cast(sumDataType)
)
val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType)
resultType match {
case _: DecimalType =>
Seq(updateSumExpr, Literal(false, BooleanType))
case _ => Seq(updateSumExpr)
}
}
}

/**
* For decimal type:
* If isEmpty is false and if sum is null, then it means we have had an overflow.
*
* update of the sum is as follows:
* Check if either portion of the left.sum or right.sum has overflowed
* If it has, then the sum value will remain null.
* If it did not have overflow, then add the sum.left and sum.right
*
* isEmpty: Set to false if either one of the left or right is set to false. This
* means we have seen atleast a value that was not null.
*/
override lazy val mergeExpressions: Seq[Expression] = {
Seq(
/* sum = */
coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
)
val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
resultType match {
case _: DecimalType =>
val inputOverflow = !isEmpty.right && sum.right.isNull
val bufferOverflow = !isEmpty.left && sum.left.isNull
Seq(
If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr),
isEmpty.left && isEmpty.right)
case _ => Seq(mergeSumExpr)
}
}

/**
* If the isEmpty is true, then it means there were no values to begin with or all the values
* were null, so the result will be null.
* If the isEmpty is false, then if sum is null that means an overflow has happened.
* So now, if ansi is enabled, then throw exception, if not then return null.
* If sum is not null, then return the sum.
*/
override lazy val evaluateExpression: Expression = resultType match {
case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled)
case d: DecimalType =>
If(isEmpty, Literal.create(null, sumDataType),
CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
case _ => sum
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -144,3 +145,54 @@ case class CheckOverflow(

override def sql: String = child.sql
}

// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
case class CheckOverflowInSum(
child: Expression,
dataType: DecimalType,
nullOnOverflow: Boolean) extends UnaryExpression {

override def nullable: Boolean = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we override nullable with false as doGenCode() does?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the child can be nullable, the input value can be null. Making nullable to false in that case will not work, as it may result in npe. We can change the doGenCode() here to make the check for null for that, but since the nullSafeCodeGen in UnaryExpression already takes care of the if nullable checks, it seems there is no need to add if null checks here.


override def eval(input: InternalRow): Any = {
val value = child.eval(input)
if (value == null) {
if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.")
} else {
value.asInstanceOf[Decimal].toPrecision(
dataType.precision,
dataType.scale,
Decimal.ROUND_HALF_UP,
nullOnOverflow)
}
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
val nullHandling = if (nullOnOverflow) {
""
} else {
s"""
|throw new ArithmeticException("Overflow in sum of decimals.");
|""".stripMargin
}
val code = code"""
|${childGen.code}
|boolean ${ev.isNull} = ${childGen.isNull};
|Decimal ${ev.value} = null;
|if (${childGen.isNull}) {
| $nullHandling
|} else {
| ${ev.value} = ${childGen.value}.toPrecision(
| ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
| ${ev.isNull} = ${ev.value} == null;
|}
|""".stripMargin

ev.copy(code = code)
}

override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"

override def sql: String = child.sql
}
112 changes: 105 additions & 7 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,28 @@ class DataFrameSuite extends QueryTest
structDf.select(xxhash64($"a", $"record.*")))
}

private def assertDecimalSumOverflow(
df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
if (!ansiEnabled) {
try {
checkAnswer(df, expectedAnswer)
} catch {
case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] =>
// This is an existing bug that we can write overflowed decimal to UnsafeRow but fail
// to read it.
assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
}
} else {
val e = intercept[SparkException] {
df.collect
}
assert(e.getCause.isInstanceOf[ArithmeticException])
assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
e.getCause.getMessage.contains("Overflow in sum of decimals") ||
e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
}
}

test("SPARK-28224: Aggregate sum big decimal overflow") {
val largeDecimals = spark.sparkContext.parallelize(
DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
Expand All @@ -200,14 +222,90 @@ class DataFrameSuite extends QueryTest
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
if (!ansiEnabled) {
checkAnswer(structDf, Row(null))
} else {
val e = intercept[SparkException] {
structDf.collect
assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
}
}
}

test("SPARK-28067: sum of null decimal values") {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
Seq("true", "false").foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled)) {
val df = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
checkAnswer(df.agg(sum($"d")), Row(null))
}
}
}
}
}

test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val df0 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
val df1 = Seq(
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
val df = df0.union(df1)
val df2 = df.withColumnRenamed("decNum", "decNum2").
join(df, "intNum").agg(sum("decNum"))

val expectedAnswer = Row(null)
assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)

val decStr = "1" + "0" * 19
val d1 = spark.range(0, 12, 1, 1)
val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)

val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)

val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)

val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))

val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
toDF("d")
assertDecimalSumOverflow(
nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)

val df3 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("50000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")

val df4 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")

val df5 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")

val df6 = df3.union(df4).union(df5)
val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
filter("intNum == 1")
assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
}
assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
}
}
}
Expand Down