Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix the incorrect results for sum for decimal overflow, support repor…
…ting null for it when ansiEnabled is false and throw exception if ansiEnabled is true
  • Loading branch information
skambha committed Apr 27, 2020
commit edb32b91cfd1d8e249b43ed6dc06810de702b3df
Original file line number Diff line number Diff line change
Expand Up @@ -61,38 +61,104 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
private lazy val sumDataType = resultType

private lazy val sum = AttributeReference("sum", sumDataType)()
private lazy val overflow = AttributeReference("overflow", BooleanType)()

private lazy val zero = Literal.default(sumDataType)

override lazy val aggBufferAttributes = sum :: Nil
override lazy val aggBufferAttributes = sum :: overflow :: Nil

override lazy val initialValues: Seq[Expression] = Seq(
/* sum = */ Literal.create(null, sumDataType)
/* sum = */ Literal.create(null, sumDataType),
/* overflow = */ Literal.create(false, BooleanType)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We keep track of overflow using this aggBufferAttributes - overflow to know if any of the intermediate add operations in updateExpressions and/or mergeExpressions overflow'd. If the overflow is true and if spark.sql.ansi.enabled flag is false, then we return null for the sum operation in evaluateExpression.

)

override lazy val updateExpressions: Seq[Expression] = {
if (child.nullable) {
Seq(
/* sum = */
coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
)
if (!SQLConf.get.ansiEnabled) {
Seq(
/* sum = */
resultType match {
case d: DecimalType => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
},
/* overflow = */
resultType match {
case d: DecimalType =>
If (overflow, true, HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d))
case _ => If(overflow, true, false)
})
} else {
Seq(
/* sum = */
resultType match {
case d: DecimalType => coalesce(
CheckOverflow(
coalesce(sum, zero) + child.cast(sumDataType), d, !SQLConf.get.ansiEnabled), sum)
case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
},
/* overflow = */
false
)
}
} else {
Seq(
/* sum = */
coalesce(sum, zero) + child.cast(sumDataType)
)
if (!SQLConf.get.ansiEnabled) {
Seq(
/* sum = */
resultType match {
case d: DecimalType => coalesce(sum, zero) + child.cast(sumDataType)
case _ => coalesce(sum, zero) + child.cast(sumDataType)
},
/* overflow = */
resultType match {
case d: DecimalType =>
If(overflow, true, HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d))
case _ => If(overflow, true, false)
})
} else {
Seq(
/* sum = */
resultType match {
case d: DecimalType => coalesce(
CheckOverflow(
coalesce(sum, zero) + child.cast(sumDataType), d, !SQLConf.get.ansiEnabled), sum)
case _ => coalesce(sum, zero) + child.cast(sumDataType)
},
/* overflow = */
false
)
}
}
}

override lazy val mergeExpressions: Seq[Expression] = {
Seq(
/* sum = */
coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
resultType match {
case d: DecimalType =>
if (!SQLConf.get.ansiEnabled) {
coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
} else {
coalesce(CheckOverflow(
coalesce(sum.left, zero) + sum.right, d, !SQLConf.get.ansiEnabled), sum.left)
}
case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
},
/* overflow = */
resultType match {
case d: DecimalType =>
if (!SQLConf.get.ansiEnabled) {
If(overflow.left || overflow.right,
true, HasOverflow(coalesce(sum.left, zero) + sum.right, d))
} else {
If(overflow.left || overflow.right, true, false)
}
}
)
}

override lazy val evaluateExpression: Expression = resultType match {
case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled)
case d: DecimalType => If(overflow && !SQLConf.get.ansiEnabled,
Literal.create(null, sumDataType) , sum)
case _ => sum
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,31 @@ case class CheckOverflow(

override def sql: String = child.sql
}

case class HasOverflow(
child: Expression,
inputType: DecimalType) extends UnaryExpression {

override def dataType: DataType= BooleanType

override def nullable: Boolean = false

override def nullSafeEval(input: Any): Any =
input.asInstanceOf[Decimal].changePrecision(
inputType.precision,
inputType.scale,
Decimal.ROUND_HALF_UP)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
s"""
|${ev.value} = $eval.changePrecision(
| ${inputType.precision}, ${inputType.scale}, Decimal.ROUND_HALF_UP());
""".stripMargin
})
}

override def toString: String = s"HasOverflow($child, $inputType)"

override def sql: String = child.sql
}