diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 034894bd8608..4dc5ce1de047 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) {
       Platform.putLong(baseObject, baseOffset + cursor, 0L);
       Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);
 
-      if (value == null) {
+      if (value == null || !value.changePrecision(precision, value.scale())) {
        setNullAt(ordinal);
         // keep the offset for future update
         Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
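Note: the guard above works because `Decimal.changePrecision` signals overflow through its boolean return value instead of throwing. A minimal sketch of that contract (illustrative values, not part of the patch):

    import org.apache.spark.sql.types.Decimal

    // 38 digits fit the maximum decimal precision; changePrecision keeps the value.
    val fits = Decimal(BigDecimal("1" * 38))
    assert(fits.changePrecision(38, 0))
    // 39 digits cannot fit; changePrecision returns false, and the patched
    // setDecimal then stores null instead of an unreadable overflowed value.
    val overflowed = Decimal(BigDecimal("1" + "0" * 38))
    assert(!overflowed.changePrecision(38, 0))
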
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 6e850267100f..a29ae2c8b65a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -58,13 +58,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
     case _ => DoubleType
   }
 
-  private lazy val sumDataType = resultType
-
-  private lazy val sum = AttributeReference("sum", sumDataType)()
+  private lazy val sum = AttributeReference("sum", resultType)()
 
   private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()
 
-  private lazy val zero = Literal.default(sumDataType)
+  private lazy val zero = Literal.default(resultType)
 
   override lazy val aggBufferAttributes = resultType match {
     case _: DecimalType => sum :: isEmpty :: Nil
@@ -72,25 +70,38 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   }
 
   override lazy val initialValues: Seq[Expression] = resultType match {
-    case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType))
+    case _: DecimalType => Seq(zero, Literal(true, BooleanType))
     case _ => Seq(Literal(null, resultType))
   }
 
   override lazy val updateExpressions: Seq[Expression] = {
-    if (child.nullable) {
-      val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
-      resultType match {
-        case _: DecimalType =>
-          Seq(updateSumExpr, isEmpty && child.isNull)
-        case _ => Seq(updateSumExpr)
-      }
-    } else {
-      val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType)
-      resultType match {
-        case _: DecimalType =>
-          Seq(updateSumExpr, Literal(false, BooleanType))
-        case _ => Seq(updateSumExpr)
-      }
+    resultType match {
+      case _: DecimalType =>
+        // For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if
+        // the input is null, as SUM function ignores null input. The `sum` can only be null if
+        // overflow happens under non-ansi mode.
+        val sumExpr = if (child.nullable) {
+          If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType))
+        } else {
+          sum + child.cast(resultType)
+        }
+        // The buffer becomes non-empty after seeing the first not-null input.
+        val isEmptyExpr = if (child.nullable) {
+          isEmpty && child.isNull
+        } else {
+          Literal(false, BooleanType)
+        }
+        Seq(sumExpr, isEmptyExpr)
+      case _ =>
+        // For non-decimal type, the initial value of `sum` is null, which indicates no value.
+        // We need `coalesce(sum, zero)` to start summing values. And we need an outer `coalesce`
+        // in case the input is nullable. The `sum` can only be null if there is no value, as
+        // non-decimal type can produce overflowed value under non-ansi mode.
+        if (child.nullable) {
+          Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum))
+        } else {
+          Seq(coalesce(sum, zero) + child.cast(resultType))
+        }
     }
   }
 
@@ -107,15 +118,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
    * means we have seen atleast a value that was not null.
    */
   override lazy val mergeExpressions: Seq[Expression] = {
-    val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
     resultType match {
       case _: DecimalType =>
-        val inputOverflow = !isEmpty.right && sum.right.isNull
         val bufferOverflow = !isEmpty.left && sum.left.isNull
+        val inputOverflow = !isEmpty.right && sum.right.isNull
         Seq(
-          If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr),
+          If(
+            bufferOverflow || inputOverflow,
+            Literal.create(null, resultType),
+            // If both the buffer and the input do not overflow, just add them, as they can't be
+            // null. See the comments inside `updateExpressions`: `sum` can only be null if
+            // overflow happens.
+            KnownNotNull(sum.left) + KnownNotNull(sum.right)),
           isEmpty.left && isEmpty.right)
-      case _ => Seq(mergeSumExpr)
+      case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left))
     }
   }
 
@@ -128,7 +144,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
    */
   override lazy val evaluateExpression: Expression = resultType match {
     case d: DecimalType =>
-      If(isEmpty, Literal.create(null, sumDataType),
+      If(isEmpty, Literal.create(null, resultType),
         CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
     case _ => sum
   }
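The decimal branches above implement a small state machine over the (`sum`, `isEmpty`) buffer: the buffer starts at (0, true), `isEmpty` flips to false on the first non-null input, and `sum` becomes null only on overflow under non-ansi mode. A plain-Scala model of the merge and evaluate semantics (hypothetical `SumBuffer` type, illustrative only, not Spark code):

    // sum == None encodes overflow once isEmpty is false; a buffer starts at (Some(0), true).
    final case class SumBuffer(sum: Option[BigDecimal], isEmpty: Boolean)

    def merge(left: SumBuffer, right: SumBuffer): SumBuffer = {
      val bufferOverflow = !left.isEmpty && left.sum.isEmpty
      val inputOverflow = !right.isEmpty && right.sum.isEmpty
      val merged =
        if (bufferOverflow || inputOverflow) None // overflow is sticky
        // Neither side overflowed, so both sums are present and can be added
        // directly, mirroring KnownNotNull(sum.left) + KnownNotNull(sum.right).
        else Some(left.sum.get + right.sum.get)
      SumBuffer(merged, left.isEmpty && right.isEmpty)
    }

    def evaluate(buf: SumBuffer): Option[BigDecimal] =
      if (buf.isEmpty) None // no inputs: SUM returns null
      else buf.sum          // None here means overflow (null under non-ansi mode)
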
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 8359dff674a8..52ef5895ed9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -195,22 +195,14 @@ class DataFrameSuite extends QueryTest
   private def assertDecimalSumOverflow(
       df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
     if (!ansiEnabled) {
-      try {
-        checkAnswer(df, expectedAnswer)
-      } catch {
-        case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] =>
-          // This is an existing bug that we can write overflowed decimal to UnsafeRow but fail
-          // to read it.
-          assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
-      }
+      checkAnswer(df, expectedAnswer)
     } else {
       val e = intercept[SparkException] {
-        df.collect
+        df.collect()
       }
       assert(e.getCause.isInstanceOf[ArithmeticException])
       assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
-        e.getCause.getMessage.contains("Overflow in sum of decimals") ||
-        e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
+        e.getCause.getMessage.contains("Overflow in sum of decimals"))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index a5f904c621e6..9daa69ce9f15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -178,4 +178,14 @@ class UnsafeRowSuite extends SparkFunSuite {
     // Makes sure hashCode on unsafe array won't crash
     unsafeRow.getArray(0).hashCode()
   }
+
+  test("SPARK-32018: setDecimal with overflowed value") {
+    val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18)
+    val row = InternalRow.apply(d1)
+    val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row)
+    assert(unsafeRow.getDecimal(0, 38, 18) === d1)
+    val d2 = (d1 * Decimal(10)).toPrecision(39, 18)
+    unsafeRow.setDecimal(0, d2, 38)
+    assert(unsafeRow.getDecimal(0, 38, 18) === null)
+  }
 }
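For reference, a hedged end-to-end illustration of the fixed behavior (assumes a live `SparkSession` named `spark`; the data and config calls are illustrative, not from the patch):

    import org.apache.spark.sql.functions.{expr, sum}

    // Twelve copies of a 37-digit decimal: their sum needs 39 digits, which
    // overflows the maximum decimal precision of 38.
    val df = spark.range(0, 12, 1, 1)
      .select(expr("cast(repeat('9', 37) as decimal(38, 0))").as("d"))

    spark.conf.set("spark.sql.ansi.enabled", "false")
    df.agg(sum("d")).collect() // Array(Row(null)): overflow yields null, not a read failure

    spark.conf.set("spark.sql.ansi.enabled", "true")
    df.agg(sum("d")).collect() // throws a wrapped ArithmeticException: "Overflow in sum of decimals"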