Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
case _ => DoubleType
}

// NOTE(review): always false; presumably forwarded as CheckOverflow's
// null-on-overflow flag (see mergeExpressions), so an overflowing decimal sum
// raises an exception rather than returning null — TODO confirm against
// CheckOverflow's signature; the name reads backwards for that meaning.
private lazy val wrongResultDueToOverflow = false
private lazy val sumDataType = resultType

private lazy val sum = AttributeReference("sum", sumDataType)()
Expand All @@ -73,7 +74,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
if (child.nullable) {
Seq(
/* sum = */
coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
resultType match {
case d: DecimalType => coalesce(
CheckOverflow(
coalesce(sum, zero) + child.cast(sumDataType), d, wrongResultDueToOverflow),
sum)
case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
}
)
} else {
Seq(
Expand All @@ -86,7 +93,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
/**
 * Merges two partial sums. For decimal results the merged sum is wrapped in
 * CheckOverflow so that combining partial aggregates cannot silently wrap
 * around; `wrongResultDueToOverflow` (false) is passed as the overflow-handling
 * flag — presumably CheckOverflow's null-on-overflow parameter, TODO confirm.
 * For non-decimal types the merge is the plain null-tolerant addition.
 *
 * NOTE: the stale pre-change expression line left over from the diff
 * (`coalesce(coalesce(sum.left, zero) + sum.right, sum.left)` duplicated before
 * the match) has been removed — with it present the Seq contained two
 * juxtaposed expressions and did not compile.
 */
override lazy val mergeExpressions: Seq[Expression] = {
  Seq(
    /* sum = */
    resultType match {
      case d: DecimalType => coalesce(
        CheckOverflow(coalesce(sum.left, zero) + sum.right, d, wrongResultDueToOverflow),
        sum.left)
      case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
    }
  )
}

Expand Down
16 changes: 6 additions & 10 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,14 @@ class DataFrameSuite extends QueryTest
DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) ::
DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF()

Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
Seq("true", "false").foreach { codegenEnabled =>
withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegenEnabled)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
if (!ansiEnabled) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SPARK-28224 addressed decimal overflow for sum only partially: it works for sums of 2 values. In the test case added as part of SPARK-28224, adding one more row to the dataset produces an incorrect result instead of returning null on overflow.

In this PR we address decimal overflow in aggregate sum by throwing an exception. Hence this test has been modified.

checkAnswer(structDf, Row(null))
} else {
val e = intercept[SparkException] {
structDf.collect
}
assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
val e = intercept[SparkException] {
structDf.collect
}
assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
}
}
}
Expand Down
26 changes: 26 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3319,6 +3319,32 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}

test("SPARK-28067 - Aggregate sum should not return wrong results") {
  // Run with whole-stage codegen both enabled and disabled so both execution
  // paths are exercised.
  for (codegenFlag <- Seq("true", "false")) {
    withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegenFlag)) {
      // 2 rows keyed 1 and 10 rows keyed 2, each carrying a decimal large
      // enough that summing the join result overflows DecimalType.
      val big = BigDecimal("10000000000000000000")
      val rows = Seq.fill(2)((big, 1)) ++ Seq.fill(10)((big, 2))
      val source = rows.toDF("decNum", "intNum")
      val overflowing = source
        .withColumnRenamed("decNum", "decNum2")
        .join(source, "intNum")
        .agg(sum("decNum"))
      // The overflow must surface as an ArithmeticException, not a wrong value.
      val err = intercept[SparkException] {
        overflowing.collect()
      }
      assert(err.getCause.getClass == classOf[ArithmeticException])
      assert(err.getCause.getMessage.contains("cannot be represented as Decimal"))
    }
  }
}

test("SPARK-29213: FilterExec should not throw NPE") {
withTempView("t1", "t2", "t3") {
sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")
Expand Down