From 0a232794422aa69af4fccd8c024ee065022604ad Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 11 Aug 2020 13:55:39 +0800 Subject: [PATCH 1/4] throw exception on decimal overflow --- .../sql/catalyst/expressions/UnsafeRow.java | 2 +- .../catalyst/expressions/aggregate/Sum.scala | 37 ++++++++++++++----- .../spark/sql/DataFrameAggregateSuite.scala | 24 ++++++++++++ 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 4dc5ce1de047..034894bd8608 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) { Platform.putLong(baseObject, baseOffset + cursor, 0L); Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); - if (value == null || !value.changePrecision(precision, value.scale())) { + if (value == null) { setNullAt(ordinal); // keep the offset for future update Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index d2daaac72fc8..901ea8f7126a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -71,16 +71,33 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast ) override lazy val updateExpressions: Seq[Expression] = { - if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - ) - } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) - ) + resultType match { + case d: DecimalType => + val sumWithChild = + CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, nullOnOverflow = false) + if (child.nullable) { + Seq( + /* sum = */ + coalesce(sumWithChild, sum) + ) + } else { + Seq( + /* sum = */ + sumWithChild + ) + } + case _ => + if (child.nullable) { + Seq( + /* sum = */ + coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + ) + } else { + Seq( + /* sum = */ + coalesce(sum, zero) + child.cast(sumDataType) + ) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 54327b38c100..7781d253468e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.Matchers.the +import org.apache.spark.SparkException import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} @@ -1044,6 +1045,29 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(sql(queryTemplate("FIRST")), Row(1)) checkAnswer(sql(queryTemplate("LAST")), Row(3)) } + + test("Throw exception on decimal overflow") { + val decimalString = "1" + "0" * 19 + val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) + val hashAgg = union + .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key")) + .groupBy("key") + .agg(sum($"d").alias("sumD")) + .select($"sumD") + var msg = intercept[SparkException] { + hashAgg.collect() + }.getCause.getMessage + assert(msg.contains("cannot be represented as Decimal(38, 18)")) + + val sortAgg = union + .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("a").as("str"), + lit("1").as("key")).groupBy("key") + .agg(sum($"d").alias("sumD"), min($"str").alias("minStr")).select($"sumD", $"minStr") + msg = intercept[SparkException] { + sortAgg.collect() + }.getCause.getMessage + assert(msg.contains("cannot be represented as Decimal(38, 18)")) + } } case class B(c: Option[Double]) From f21f1a0af29e4ed360a7af0d3bd824490d646ba1 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 11 Aug 2020 14:55:45 +0800 Subject: [PATCH 2/4] update --- .../catalyst/expressions/aggregate/Sum.scala | 40 +++++++------------ .../spark/sql/DataFrameAggregateSuite.scala | 2 +- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 901ea8f7126a..2b3399a3d572 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -71,33 +71,23 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast ) override lazy val updateExpressions: Seq[Expression] = { - resultType match { + val sumWithChild = resultType match { case d: DecimalType => - val sumWithChild = - CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, nullOnOverflow = false) - if (child.nullable) { - Seq( - /* sum = */ - coalesce(sumWithChild, sum) - ) - } else { - Seq( - /* sum = */ - sumWithChild - ) - } + CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, nullOnOverflow = false) case _ => - if (child.nullable) { - Seq( - /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) - ) - } else { - Seq( - /* sum = */ - coalesce(sum, zero) + child.cast(sumDataType) - ) - } + coalesce(sum, zero) + child.cast(sumDataType) + } + + if (child.nullable) { + Seq( + /* sum = */ + coalesce(sumWithChild, sum) + ) + } else { + Seq( + /* sum = */ + sumWithChild + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7781d253468e..d9f7cdb316a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1046,7 +1046,7 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(sql(queryTemplate("LAST")), Row(3)) } - test("Throw exception on decimal overflow") { + test("SPARK-32018: Throw exception on decimal overflow") { val decimalString = "1" + "0" * 19 val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) val hashAgg = union From b9af4f57a3c0a0fb35e8c13c9b61b042578b177e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 11 Aug 2020 15:06:43 +0800 Subject: [PATCH 3/4] update --- .../org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 034894bd8608..4dc5ce1de047 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -288,7 +288,7 @@ public void setDecimal(int ordinal, Decimal value, int precision) { Platform.putLong(baseObject, baseOffset + cursor, 0L); Platform.putLong(baseObject, baseOffset + cursor + 8, 0L); - if (value == null) { + if (value == null || !value.changePrecision(precision, value.scale())) { setNullAt(ordinal); // keep the offset for future update Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32); From a8be9e1a093c1d221e07a6156a1a408d477621f2 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 11 Aug 2020 16:27:56 +0800 Subject: [PATCH 4/4] fix issue at merge aggregation phase --- .../catalyst/expressions/aggregate/Sum.scala | 8 ++++- .../spark/sql/DataFrameAggregateSuite.scala | 31 +++++++++++++------ 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 2b3399a3d572..d442549f20e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -92,9 +92,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast } override lazy val mergeExpressions: Seq[Expression] = { + val sumWithRight = resultType match { + case d: DecimalType => + CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow = false) + + case _ => coalesce(sum.left, zero) + sum.right + } Seq( /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + coalesce(sumWithRight, sum.left) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d9f7cdb316a3..8c0358e205b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1046,7 +1046,14 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(sql(queryTemplate("LAST")), Row(3)) } - test("SPARK-32018: Throw exception on decimal overflow") { + private def exceptionOnDecimalOverflow(df: DataFrame): Unit = { + val msg = intercept[SparkException] { + df.collect() + }.getCause.getMessage + assert(msg.contains("cannot be represented as Decimal(38, 18)")) + } + + test("SPARK-32018: Throw exception on decimal overflow at partial aggregate phase") { val decimalString = "1" + "0" * 19 val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1)) val hashAgg = union @@ -1054,19 +1061,25 @@ class DataFrameAggregateSuite extends QueryTest .groupBy("key") .agg(sum($"d").alias("sumD")) .select($"sumD") - var msg = intercept[SparkException] { - hashAgg.collect() - }.getCause.getMessage - assert(msg.contains("cannot be represented as Decimal(38, 18)")) + exceptionOnDecimalOverflow(hashAgg) val sortAgg = union .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("a").as("str"), lit("1").as("key")).groupBy("key") .agg(sum($"d").alias("sumD"), min($"str").alias("minStr")).select($"sumD", $"minStr") - msg = intercept[SparkException] { - sortAgg.collect() - }.getCause.getMessage - assert(msg.contains("cannot be represented as Decimal(38, 18)")) + exceptionOnDecimalOverflow(sortAgg) + } + + test("SPARK-32018: Throw exception on decimal overflow at merge aggregation phase") { + val decimalString = "5" + "0" * 19 + val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1)) + .union(spark.range(0, 1, 1, 1)) + val agg = union + .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key")) + .groupBy("key") + .agg(sum($"d").alias("sumD")) + .select($"sumD") + exceptionOnDecimalOverflow(agg) } }