-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow #27627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
edb32b9
8fe9aa4
1adc512
d90c790
136e6dc
6979e8d
4119e02
cc2fec0
23739c9
fa45378
fbd80a6
de2d68f
8339e28
59a00c4
7795888
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
…s well and also to use it for identifying overflow scenarios
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,91 +62,113 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast | |
|
|
||
| private lazy val sum = AttributeReference("sum", sumDataType)() | ||
|
|
||
| private lazy val overflow = AttributeReference("overflow", BooleanType, false)() | ||
| private lazy val isEmptyOrNulls = AttributeReference("isEmptyOrNulls", BooleanType, false)() | ||
|
|
||
| private lazy val zero = Literal.default(sumDataType) | ||
|
|
||
| override lazy val aggBufferAttributes = sum :: overflow :: Nil | ||
| override lazy val aggBufferAttributes = sum :: isEmptyOrNulls :: Nil | ||
|
|
||
| override lazy val initialValues: Seq[Expression] = Seq( | ||
| /* sum = */ Literal.create(null, sumDataType), | ||
| /* overflow = */ Literal.create(false, BooleanType) | ||
| /* sum = */ zero, | ||
| /* isEmptyOrNulls = */ Literal.create(true, BooleanType) | ||
| ) | ||
|
|
||
| /** | ||
| * For decimal types, update will do the following: | ||
| * We have a overflow flag and when it is true, it indicates overflow has happened | ||
| * 1. Start initial state with overflow = false, sum = null | ||
| * 2. Set sum to null if the value overflows else sum contains the intermediate sum | ||
| * 3. If overflow flag is true, keep sum as null | ||
| * 4. If overflow happened, then set overflow flag to true | ||
| * For decimal types and when child is nullable: | ||
| * isEmptyOrNulls flag is a boolean to represent if there are no rows or if all rows that | ||
| * have been seen are null. This will be used to identify if the end result of sum in | ||
| * evaluateExpression should be null or not. | ||
| * | ||
| * Update of the isEmptyOrNulls flag: | ||
| * If this flag is false, then keep it as is. | ||
| * If this flag is true, then check if the incoming value is null and if it is null, keep it | ||
| * as true else update it to false. | ||
| * Once this flag is switched to false, it will remain false. | ||
| * | ||
| * The update of the sum is as follows: | ||
| * If sum is null, then we have a case of overflow, so keep sum as is. | ||
| * If sum is not null, and the incoming value is not null, then perform the addition along | ||
| * with the overflow checking. Note, that if overflow occurs, then sum will be null here. | ||
|
||
| * If the new incoming value is null, we will keep the sum in buffer as is and skip this | ||
| * incoming null | ||
| */ | ||
| override lazy val updateExpressions: Seq[Expression] = { | ||
| if (child.nullable) { | ||
| resultType match { | ||
| case d: DecimalType => | ||
| Seq( | ||
| If(overflow, sum, coalesce( | ||
| CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true), sum)), | ||
| overflow || | ||
| coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) | ||
| /* sum */ | ||
| If(IsNull(sum), sum, | ||
| If(IsNotNull(child.cast(sumDataType)), | ||
| CheckOverflow(sum + child.cast(sumDataType), d, true), sum)), | ||
| /* isEmptyOrNulls */ | ||
| If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) | ||
| ) | ||
| case _ => | ||
| Seq( | ||
| coalesce(sum + child.cast(sumDataType), sum), | ||
| If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls) | ||
| ) | ||
| case _ => Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum), false) | ||
| } | ||
| } else { | ||
| resultType match { | ||
| case d: DecimalType => | ||
| Seq( | ||
| If(overflow, sum, | ||
| CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, true)), | ||
| overflow || | ||
| coalesce(HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d), false) | ||
| /* sum */ | ||
| If(IsNull(sum), sum, CheckOverflow(sum + child.cast(sumDataType), d, true)), | ||
| /* isEmptyOrNulls */ | ||
| false | ||
| ) | ||
| case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType), false) | ||
| case _ => Seq(sum + child.cast(sumDataType), false) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * For decimal type: | ||
| * update of the sum is as follows: | ||
| * Check if either portion of the left.sum or right.sum has overflowed | ||
skambha marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| * If it has, then the sum value will remain null. | ||
| * If it did not have overflow, then add the sum.left and sum.right and check for overflow. | ||
cloud-fan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
skambha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| * | ||
| * Decimal handling: | ||
| * If any of the left or right portion of the agg buffers has the overflow flag to true, | ||
| * then sum is set to null else sum is added for both sum.left and sum.right | ||
| * and if the value overflows it is set to null. | ||
| * If we have already seen overflow , then set overflow to true, else check if the addition | ||
| * overflowed and update the overflow buffer. | ||
| * isEmptyOrNulls: Set to false if either one of the left or right is set to false. This | ||
| * means we have seen atleast a row that was not null. | ||
skambha marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| * If the value from bufferLeft and bufferRight are both true, then this will be true. | ||
| */ | ||
| override lazy val mergeExpressions: Seq[Expression] = { | ||
| resultType match { | ||
| case d: DecimalType => | ||
| Seq( | ||
| If(coalesce(overflow.left, false) || coalesce(overflow.right, false), | ||
| Literal.create(null, d), | ||
| coalesce(CheckOverflow(coalesce(sum.left, zero) + sum.right, d, true), sum.left)), | ||
| If(coalesce(overflow.left, false) || coalesce(overflow.right, false), | ||
| true, HasOverflow(coalesce(sum.left, zero) + sum.right, d)) | ||
| ) | ||
| /* sum = */ | ||
| If(And(IsNull(sum.left), EqualTo(isEmptyOrNulls.left, false)) || | ||
| And(IsNull(sum.right), EqualTo(isEmptyOrNulls.right, false)), | ||
| Literal.create(null, resultType), | ||
| CheckOverflow(sum.left + sum.right, d, true)), | ||
| /* isEmptyOrNulls = */ | ||
| And(isEmptyOrNulls.left, isEmptyOrNulls.right) | ||
| ) | ||
| case _ => | ||
| Seq( | ||
| coalesce(coalesce(sum.left, zero) + sum.right, sum.left), | ||
| false | ||
| coalesce(sum.left + sum.right, sum.left), | ||
| And(isEmptyOrNulls.left, isEmptyOrNulls.right) | ||
| ) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Decimal handling: | ||
| * If overflow buffer is true, and if ansiEnabled is true then throw exception, else return null | ||
| * If overflow did not happen, then return the sum value | ||
| * If the isEmptyOrNulls is true, then it means either there are no rows, or all the rows were | ||
| * null, so the result will be null. | ||
| * If the isEmptyOrNulls is false, then if sum is null that means an overflow has happened. | ||
| * So now, if ansi is enabled, then throw exception, if not then return null. | ||
| * If sum is not null, then return the sum. | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| */ | ||
| override lazy val evaluateExpression: Expression = resultType match { | ||
| case d: DecimalType => | ||
| If(EqualTo(overflow, true), | ||
| If(!SQLConf.get.ansiEnabled, | ||
| Literal.create(null, sumDataType), | ||
| OverflowException(resultType, "Arithmetic Operation overflow")), | ||
| sum) | ||
| case _ => sum | ||
| If(EqualTo(isEmptyOrNulls, true), | ||
| Literal.create(null, sumDataType), | ||
| If(And(SQLConf.get.ansiEnabled, IsNull(sum)), | ||
| OverflowException(resultType, "Arithmetic Operation overflow"), sum)) | ||
| case _ => If(EqualTo(isEmptyOrNulls, true), Literal.create(null, resultType), sum) | ||
| } | ||
|
|
||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.