@@ -62,18 +62,18 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
 
-  private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)()
+  private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()
 
   private lazy val zero = Literal.default(sumDataType)
 
   override lazy val aggBufferAttributes = resultType match {
-    case _: DecimalType => sum :: isEmptyOrNulls :: Nil
+    case _: DecimalType => sum :: isEmpty :: Nil
     case _ => sum :: Nil
   }
 
   override lazy val initialValues: Seq[Expression] = resultType match {
-    case _: DecimalType => Seq(zero, Literal.create(true, BooleanType))
-    case other => Seq(Literal.create(null, other))
+    case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType))
+    case _ => Seq(Literal(null, resultType))
   }
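Aside: after this change the decimal buffer starts as (sum = null, isEmpty = true), and the rest of the patch hinges on the three reachable end states of that pair. A minimal sketch of how they decode, assuming Option models SQL NULL, `ansi` stands in for SQLConf.get.ansiEnabled, and the helper name is hypothetical:

def decodeFinalBuffer(sum: Option[BigDecimal], isEmpty: Boolean, ansi: Boolean): Option[BigDecimal] =
  (sum, isEmpty) match {
    case (_, true)     => None  // no rows, or only null inputs: SUM is NULL
    case (None, false) =>       // rows were seen but sum is null: an overflow occurred
      if (ansi) throw new ArithmeticException("Overflow in sum of decimals.") else None
    case (s, false)    => s     // normal case: the in-range sum
  }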

@@ -97,29 +97,18 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
    */
   override lazy val updateExpressions: Seq[Expression] = {
     if (child.nullable) {
+      val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
       resultType match {
-        case d: DecimalType =>
-          Seq(
-            /* sum */
-            If(IsNull(sum), sum,
-              If(IsNotNull(child.cast(sumDataType)),
-                CheckOverflow(sum + child.cast(sumDataType), d, true), sum)),
-            /* isEmptyOrNulls */
-            If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
-          )
-        case _ =>
-          Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum))
+        case _: DecimalType =>
+          Seq(updateSumExpr, isEmpty && child.isNull)
+        case _ => Seq(updateSumExpr)
       }
     } else {
+      val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType)
       resultType match {
-        case d: DecimalType =>
-          Seq(
-            /* sum */
-            If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)),
-            /* isEmptyOrNulls */
-            false
-          )
-        case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType))
+        case _: DecimalType =>
+          Seq(updateSumExpr, Literal(false, BooleanType))
+        case _ => Seq(updateSumExpr)
       }
     }
   }
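Aside: a tiny executable model of the nullable-child update above, assuming Option models SQL NULL. It mirrors only the expression tree; the actual overflow-to-null happens when the widened decimal no longer fits the buffer's precision:

type Dec = Option[BigDecimal]
// SQL + is null-propagating: the result is null if either side is null.
def plus(a: Dec, b: Dec): Dec = for { x <- a; y <- b } yield x + y

// Mirrors Seq(coalesce(coalesce(sum, zero) + child, sum), isEmpty && child.isNull)
def update(sum: Dec, isEmpty: Boolean, child: Dec): (Dec, Boolean) = {
  val zero: Dec = Some(BigDecimal(0))
  (plus(sum.orElse(zero), child).orElse(sum), isEmpty && child.isEmpty)
}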
@@ -138,19 +127,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
    * If the values from bufferLeft and bufferRight are both true, then this will be true.
    */
   override lazy val mergeExpressions: Seq[Expression] = {
+    val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
     resultType match {
-      case d: DecimalType =>
+      case _: DecimalType =>
+        val inputOverflow = !isEmpty.right && sum.right.isNull
+        val bufferOverflow = !isEmpty.left && sum.left.isNull
         Seq(
-          /* sum = */
-          If(And(IsNull(sum.left), EqualTo(isEmptyOrNulls.left, false)) ||
-              And(IsNull(sum.right), EqualTo(isEmptyOrNulls.right, false)),
-            Literal.create(null, resultType),
-            CheckOverflow(sum.left + sum.right, d, true)),
-          /* isEmptyOrNulls = */
-          And(isEmptyOrNulls.left, isEmptyOrNulls.right)
-        )
-      case _ =>
-        Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left))
+          If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr),
+          isEmpty.left && isEmpty.right)
+      case _ => Seq(mergeSumExpr)
     }
   }
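Aside: the same kind of model for the merge step. The key point is that a null sum paired with isEmpty == false marks an overflowed partial result, which forces the merged sum to null:

type Dec = Option[BigDecimal]
def merge(lSum: Dec, lEmpty: Boolean, rSum: Dec, rEmpty: Boolean): (Dec, Boolean) = {
  val bufferOverflow = !lEmpty && lSum.isEmpty  // rows were seen, but the buffered sum is null
  val inputOverflow = !rEmpty && rSum.isEmpty
  val zero: Dec = Some(BigDecimal(0))
  val merged =
    if (bufferOverflow || inputOverflow) None
    else (for { x <- lSum.orElse(zero); y <- rSum } yield x + y).orElse(lSum)
  (merged, lEmpty && rEmpty)
}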

@@ -163,11 +148,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
    */
   override lazy val evaluateExpression: Expression = resultType match {
     case d: DecimalType =>
-      If(EqualTo(isEmptyOrNulls, true),
-        Literal.create(null, sumDataType),
-        If(And(SQLConf.get.ansiEnabled, IsNull(sum)),
-          OverflowException(resultType, "Arithmetic Operation overflow"), sum))
+      If(isEmpty, Literal.create(null, sumDataType),
+        CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
     case _ => sum
   }
 
 }
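Aside: end-to-end, the intended behavior can be sketched like this in a local session (assumes the ANSI flag key is spark.sql.ansi.enabled, i.e. SQLConf.ANSI_ENABLED):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

val spark = SparkSession.builder().master("local[1]").appName("sum-overflow").getOrCreate()
// Two rows of the maximal decimal(38,0) value; their sum cannot fit in 38 digits.
val df = spark.range(0, 2).selectExpr(s"cast('${"9" * 38}' as decimal(38,0)) as d")

spark.conf.set("spark.sql.ansi.enabled", "false")
df.agg(sum("d")).show()  // expected: NULL rather than a truncated value

spark.conf.set("spark.sql.ansi.enabled", "true")
try df.agg(sum("d")).collect()  // expected: ArithmeticException on overflow
catch { case e: Exception => println(e.getMessage) }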
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -146,22 +146,53 @@ case class CheckOverflow(
   override def sql: String = child.sql
 }
 
-case class OverflowException(dtype: DataType, msg: String) extends LeafExpression {
-
-  override def dataType: DataType = dtype
+// A variant of `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
+case class CheckOverflowInSum(
+    child: Expression,
+    dataType: DecimalType,
+    nullOnOverflow: Boolean) extends UnaryExpression {
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = true
 
-  def eval(input: InternalRow): Any = {
-    Decimal.throwArithmeticException(msg)
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      if (nullOnOverflow) null else throw new ArithmeticException("Overflow in sum of decimals.")
+    } else {
+      value.asInstanceOf[Decimal].toPrecision(
+        dataType.precision,
+        dataType.scale,
+        Decimal.ROUND_HALF_UP,
+        nullOnOverflow)
+    }
   }
 
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    ev.copy(code = code"""
-       |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
-       |${ev.value} = Decimal.throwArithmeticException("${msg}");
-       |""", isNull = FalseLiteral)
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val nullHandling = if (nullOnOverflow) {
+      ""
+    } else {
+      s"""
+         |throw new ArithmeticException("Overflow in sum of decimals.");
+         |""".stripMargin
+    }
+    val code = code"""
+       |${childGen.code}
+       |boolean ${ev.isNull} = ${childGen.isNull};
+       |Decimal ${ev.value} = null;
+       |if (${childGen.isNull}) {
+       |  $nullHandling
+       |} else {
+       |  ${ev.value} = ${childGen.value}.toPrecision(
+       |    ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow);
+       |  ${ev.isNull} = ${ev.value} == null;
+       |}
+       |""".stripMargin
+
+    ev.copy(code = code)
   }
 
-  override def toString: String = "OverflowException"
+  override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"
 
   override def sql: String = child.sql
 }
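Aside: the new expression leans entirely on Decimal.toPrecision returning null when the value cannot fit. A quick sketch of that contract, assuming the signature with a trailing nullOnOverflow flag as used at the call sites above:

import org.apache.spark.sql.types.Decimal

val fits = Decimal("9" * 38)  // 38 digits: representable at decimal(38, 0)
assert(fits.toPrecision(38, 0, Decimal.ROUND_HALF_UP, nullOnOverflow = true) != null)

val tooBig = Decimal(BigDecimal("9" * 39))  // 39 digits: cannot fit
assert(tooBig.toPrecision(38, 0, Decimal.ROUND_HALF_UP, nullOnOverflow = true) == null)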
@@ -651,9 +651,4 @@ object Decimal {
     override def quot(x: Decimal, y: Decimal): Decimal = x quot y
     override def rem(x: Decimal, y: Decimal): Decimal = x % y
   }
-
-
-  def throwArithmeticException(msg: String): Decimal = {
-    throw new ArithmeticException(msg)
-  }
 }
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (31 additions, 20 deletions)
@@ -192,6 +192,28 @@ class DataFrameSuite extends QueryTest
       structDf.select(xxhash64($"a", $"record.*")))
   }
 
+  private def assertDecimalSumOverflow(
+      df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
+    if (!ansiEnabled) {
+      try {
+        checkAnswer(df, expectedAnswer)
+      } catch {
+        case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] =>
+          // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail
+          // to read it.
+          assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
+      }
+    } else {
+      val e = intercept[SparkException] {
+        df.collect()
+      }
+      assert(e.getCause.isInstanceOf[ArithmeticException])
+      assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
+        e.getCause.getMessage.contains("Overflow in sum of decimals") ||
+        e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
+    }
+  }

[Review thread on the catch block above]
Owner: This has changed the tests to say it is OK to throw an exception even when ansiEnabled is false. Our ansiEnabled flag then isn't doing what it says it is supposed to do, right? That is a bug.
Author: This is a bug in the unsafe row writer, and I think we should fix it. According to apache#27627 (comment), this bug has been there for a long time. It is a less critical bug, as Spark fails instead of returning a wrong result.
Owner: Just to clarify, the issue we have is that an UnsafeRow accepts an overflowed decimal value but throws an exception when the value is fetched.
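Aside: a hypothetical use of this helper, in the style of the tests below, pinning down both behaviors of an overflowing decimal sum:

test("sum of decimals overflows to null, or throws under ANSI mode") {
  Seq(true, false).foreach { ansiEnabled =>
    withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
      val df = spark.range(0, 2, 1, 1).select(expr(s"cast('${"9" * 38}' as decimal(38,0)) as d"))
      assertDecimalSumOverflow(df.agg(sum($"d")), ansiEnabled, Row(null))
    }
  }
}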

test("SPARK-28224: Aggregate sum big decimal overflow") {
val largeDecimals = spark.sparkContext.parallelize(
DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
Expand All @@ -200,24 +222,12 @@ class DataFrameSuite extends QueryTest
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
checkAnsi(structDf, ansiEnabled, Row(null))
}
}
}

private def checkAnsi(df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row ): Unit = {
if (!ansiEnabled) {
checkAnswer(df, expectedAnswer)
} else {
val e = intercept[SparkException] {
df.collect()
assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
}
assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
assert(e.getCause.getMessage.contains("Arithmetic Operation overflow"))
}
}

test("test sum on null decimal values") {
test("SPARK-28067: sum of null decimal values") {
Seq("true", "false").foreach { wholeStageEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
Seq("true", "false").foreach { ansiEnabled =>
@@ -254,26 +264,27 @@ class DataFrameSuite extends QueryTest
           join(df, "intNum").agg(sum("decNum"))
 
         val expectedAnswer = Row(null)
-        checkAnsi(df2, ansiEnabled, expectedAnswer)
+        assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
 
         val decStr = "1" + "0" * 19
         val d1 = spark.range(0, 12, 1, 1)
         val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
-        checkAnsi(d2, ansiEnabled, expectedAnswer)
+        assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
 
         val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
         val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
-        checkAnsi(d4, ansiEnabled, expectedAnswer)
+        assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
 
         val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
           lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
-        checkAnsi(d5, ansiEnabled, expectedAnswer)
+        assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
 
         val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))
 
         val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
           toDF("d")
-        checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)
+        assertDecimalSumOverflow(
+          nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)
 
         val df3 = Seq(
           (BigDecimal("10000000000000000000"), 1),
@@ -293,7 +304,7 @@ class DataFrameSuite extends QueryTest
         val df6 = df3.union(df4).union(df5)
         val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
           filter("intNum == 1")
-        checkAnsi(df7, ansiEnabled, Row(1, null, 2))
+        assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
       }
     }
   }