Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add new test and also make the agg buffer changes only for decimal type
  • Loading branch information
skambha committed May 15, 2020
commit fbd80a65a6c02d2513dc978ed02bd6da09609dd5
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast

private lazy val zero = Literal.default(sumDataType)

override lazy val aggBufferAttributes = sum :: isEmptyOrNulls :: Nil
override lazy val aggBufferAttributes = resultType match {
case _: DecimalType => sum :: isEmptyOrNulls :: Nil
case _ => sum :: Nil
}

override lazy val initialValues: Seq[Expression] = Seq(
/* sum = */ zero,
/* isEmptyOrNulls = */ Literal.create(true, BooleanType)
)
override lazy val initialValues: Seq[Expression] = resultType match {
case _: DecimalType => Seq(zero, Literal.create(true, BooleanType))
case other => Seq(Literal.create(null, other))
}

/**
* For decimal types and when child is nullable:
Expand Down Expand Up @@ -105,10 +108,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
)
case _ =>
Seq(
coalesce(sum + child.cast(sumDataType), sum),
If(isEmptyOrNulls, IsNull(child.cast(sumDataType)), isEmptyOrNulls)
)
Seq(coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum))
}
} else {
resultType match {
Expand All @@ -119,13 +119,15 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
/* isEmptyOrNulls */
false
)
case _ => Seq(sum + child.cast(sumDataType), false)
case _ => Seq(coalesce(sum, zero) + child.cast(sumDataType))
}
}
}

/**
* For decimal type:
* If isEmptyOrNulls is false and if sum is null, then it means we have an overflow.
*
* update of the sum is as follows:
* Check if either portion of the left.sum or right.sum has overflowed
* If it has, then the sum value will remain null.
Expand All @@ -148,10 +150,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
And(isEmptyOrNulls.left, isEmptyOrNulls.right)
)
case _ =>
Seq(
coalesce(sum.left + sum.right, sum.left),
And(isEmptyOrNulls.left, isEmptyOrNulls.right)
)
Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left))
}
}

Expand All @@ -168,7 +167,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
Literal.create(null, sumDataType),
If(And(SQLConf.get.ansiEnabled, IsNull(sum)),
OverflowException(resultType, "Arithmetic Operation overflow"), sum))
case _ => If(EqualTo(isEmptyOrNulls, true), Literal.create(null, resultType), sum)
case _ => sum
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -709,15 +709,15 @@ ReadSchema: struct<key:int,val:int>
Input [2]: [key#x, val#x]
Keys: []
Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Aggregate Attributes [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Results [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Aggregate Attributes [3]: [count#xL, sum#xL, count#xL]
Results [3]: [count#xL, sum#xL, count#xL]

(3) Exchange
Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Input [3]: [count#xL, sum#xL, count#xL]
Arguments: SinglePartition, true, [id=#x]

(4) HashAggregate
Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Input [3]: [count#xL, sum#xL, count#xL]
Keys: []
Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
Expand Down
8 changes: 4 additions & 4 deletions sql/core/src/test/resources/sql-tests/results/explain.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -918,15 +918,15 @@ Input [2]: [key#x, val#x]
Input [2]: [key#x, val#x]
Keys: []
Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
Aggregate Attributes [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Results [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Aggregate Attributes [3]: [count#xL, sum#xL, count#xL]
Results [3]: [count#xL, sum#xL, count#xL]

(4) Exchange
Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Input [3]: [count#xL, sum#xL, count#xL]
Arguments: SinglePartition, true, [id=#x]

(5) HashAggregate [codegen id : 2]
Input [4]: [count#xL, sum#xL, isEmptyOrNulls#x, count#xL]
Input [3]: [count#xL, sum#xL, count#xL]
Keys: []
Functions [3]: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
Aggregate Attributes [3]: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
Expand Down
38 changes: 30 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,14 @@ class DataFrameSuite extends QueryTest
Seq(true, false).foreach { ansiEnabled =>
withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
val structDf = largeDecimals.select("a").agg(sum("a"))
checkAnsi(structDf, ansiEnabled)
checkAnsi(structDf, ansiEnabled, Row(null))
}
}
}

private def checkAnsi(df: DataFrame, ansiEnabled: Boolean): Unit = {
private def checkAnsi(df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row ): Unit = {
if (!ansiEnabled) {
checkAnswer(df, Row(null))
checkAnswer(df, expectedAnswer)
} else {
val e = intercept[SparkException] {
df.collect()
Expand Down Expand Up @@ -252,26 +252,48 @@ class DataFrameSuite extends QueryTest
val df = df0.union(df1)
val df2 = df.withColumnRenamed("decNum", "decNum2").
join(df, "intNum").agg(sum("decNum"))
checkAnsi(df2, ansiEnabled)

val expectedAnswer = Row(null)
checkAnsi(df2, ansiEnabled, expectedAnswer)

val decStr = "1" + "0" * 19
val d1 = spark.range(0, 12, 1, 1)
val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
checkAnsi(d2, ansiEnabled)
checkAnsi(d2, ansiEnabled, expectedAnswer)

val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
checkAnsi(d4, ansiEnabled)
checkAnsi(d4, ansiEnabled, expectedAnswer)

val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
checkAnsi(d5, ansiEnabled)
checkAnsi(d5, ansiEnabled, expectedAnswer)

val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))

val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
toDF("d")
checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled)
checkAnsi(nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)

val df3 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("50000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")

val df4 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")

val df5 = Seq(
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("10000000000000000000"), 1),
(BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")

val df6 = df3.union(df4).union(df5)
val df7 = df6.groupBy("intNum").agg(sum("decNum"), countDistinct("decNum")).
filter("intNum == 1")
checkAnsi(df7, ansiEnabled, Row(1, null, 2))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,12 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
|""".stripMargin,
s"""
|(11) ShuffleQueryStage
|Output [6]: [k#x, count#xL, sum#xL, isEmptyOrNulls#x, sum#x, count#xL]
|Output [5]: [k#x, count#xL, sum#xL, sum#x, count#xL]
|Arguments: 1
|""".stripMargin,
s"""
|(12) CustomShuffleReader
|Input [6]: [k#x, count#xL, sum#xL, isEmptyOrNulls#x, sum#x, count#xL]
|Input [5]: [k#x, count#xL, sum#xL, sum#x, count#xL]
|Arguments: coalesced
|""".stripMargin,
s"""
Expand Down