Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) {
Platform.putLong(baseObject, baseOffset + cursor, 0L);
Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);

if (value == null) {
if (value == null || !value.changePrecision(precision, value.scale())) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks to @allisonwang-db for catching this bug!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks valid for branch-3.0 and branch-2.4. Do you think we need to backport to branch-2.4?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we should

setNullAt(ordinal);
// keep the offset for future update
Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,39 +58,50 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
case _ => DoubleType
}

private lazy val sumDataType = resultType

private lazy val sum = AttributeReference("sum", sumDataType)()
private lazy val sum = AttributeReference("sum", resultType)()

private lazy val isEmpty = AttributeReference("isEmpty", BooleanType, nullable = false)()

private lazy val zero = Literal.default(sumDataType)
private lazy val zero = Literal.default(resultType)

override lazy val aggBufferAttributes = resultType match {
case _: DecimalType => sum :: isEmpty :: Nil
case _ => sum :: Nil
}

override lazy val initialValues: Seq[Expression] = resultType match {
case _: DecimalType => Seq(Literal(null, resultType), Literal(true, BooleanType))
case _: DecimalType => Seq(zero, Literal(true, BooleanType))
case _ => Seq(Literal(null, resultType))
}

override lazy val updateExpressions: Seq[Expression] = {
if (child.nullable) {
val updateSumExpr = coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
resultType match {
case _: DecimalType =>
Seq(updateSumExpr, isEmpty && child.isNull)
case _ => Seq(updateSumExpr)
}
} else {
val updateSumExpr = coalesce(sum, zero) + child.cast(sumDataType)
resultType match {
case _: DecimalType =>
Seq(updateSumExpr, Literal(false, BooleanType))
case _ => Seq(updateSumExpr)
}
resultType match {
case _: DecimalType =>
// For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if
// the input is null, as SUM function ignores null input. The `sum` can only be null if
// overflow happens under non-ansi mode.
val sumExpr = if (child.nullable) {
If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType))
} else {
sum + child.cast(resultType)
}
// The buffer becomes non-empty after seeing the first not-null input.
val isEmptyExpr = if (child.nullable) {
isEmpty && child.isNull
} else {
Literal(false, BooleanType)
}
Seq(sumExpr, isEmptyExpr)
case _ =>
// For non-decimal type, the initial value of `sum` is null, which indicates no value.
// We need `coalesce(sum, zero)` to start summing values. And we need an outer `coalesce`
// in case the input is nullable. The `sum` can only be null if there is no value, as
// non-decimal type can produce overflowed value under non-ansi mode.
if (child.nullable) {
Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum))
} else {
Seq(coalesce(sum, zero) + child.cast(resultType))
}
}
}

Expand All @@ -107,15 +118,20 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
* means we have seen at least one value that was not null.
*/
override lazy val mergeExpressions: Seq[Expression] = {
val mergeSumExpr = coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
resultType match {
case _: DecimalType =>
val inputOverflow = !isEmpty.right && sum.right.isNull
val bufferOverflow = !isEmpty.left && sum.left.isNull
val inputOverflow = !isEmpty.right && sum.right.isNull
Seq(
If(inputOverflow || bufferOverflow, Literal.create(null, sumDataType), mergeSumExpr),
If(
bufferOverflow || inputOverflow,
Literal.create(null, resultType),
// If both the buffer and the input do not overflow, just add them, as they can't be
// null. See the comments inside `updateExpressions`: `sum` can only be null if
// overflow happens.
KnownNotNull(sum.left) + KnownNotNull(sum.right)),
isEmpty.left && isEmpty.right)
case _ => Seq(mergeSumExpr)
case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left))
}
}

Expand All @@ -128,7 +144,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
*/
override lazy val evaluateExpression: Expression = resultType match {
case d: DecimalType =>
If(isEmpty, Literal.create(null, sumDataType),
If(isEmpty, Literal.create(null, resultType),
CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled))
case _ => sum
}
Expand Down
14 changes: 3 additions & 11 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,22 +195,14 @@ class DataFrameSuite extends QueryTest
private def assertDecimalSumOverflow(
df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
if (!ansiEnabled) {
try {
checkAnswer(df, expectedAnswer)
} catch {
case e: SparkException if e.getCause.isInstanceOf[ArithmeticException] =>
// This is an existing bug that we can write overflowed decimal to UnsafeRow but fail
// to read it.
assert(e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
}
checkAnswer(df, expectedAnswer)
} else {
val e = intercept[SparkException] {
df.collect
df.collect()
}
assert(e.getCause.isInstanceOf[ArithmeticException])
assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
e.getCause.getMessage.contains("Overflow in sum of decimals") ||
e.getCause.getMessage.contains("Decimal precision 39 exceeds max precision 38"))
e.getCause.getMessage.contains("Overflow in sum of decimals"))
}
}

Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,14 @@ class UnsafeRowSuite extends SparkFunSuite {
// Makes sure hashCode on unsafe array won't crash
unsafeRow.getArray(0).hashCode()
}

test("SPARK-32018: setDecimal with overflowed value") {
val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18)
val row = InternalRow.apply(d1)
val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row)
assert(unsafeRow.getDecimal(0, 38, 18) === d1)
val d2 = (d1 * Decimal(10)).toPrecision(39, 18)
unsafeRow.setDecimal(0, d2, 38)
assert(unsafeRow.getDecimal(0, 38, 18) === null)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when ANSI mode is enabled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UnsafeRow is a low-level entity and doesn't respect the ANSI flag.

}
}