Commit 12abfe7

beliefer authored and HyukjinKwon committed
[SPARK-34716][SQL] Support ANSI SQL intervals by the aggregate function sum
### What changes were proposed in this pull request?
Extend the `Sum` expression to support `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614.

Note: the expressions can throw an overflow exception independently of the SQL config `spark.sql.ansi.enabled`. In this way, the modified expressions always behave in ANSI mode for intervals.

### Why are the changes needed?
Extend `org.apache.spark.sql.catalyst.expressions.aggregate.Sum` to support `DayTimeIntervalType` and `YearMonthIntervalType`.

### Does this PR introduce _any_ user-facing change?
No. It should not, since the new types have not been released yet.

### How was this patch tested?
Jenkins tests.

Closes #32107 from beliefer/SPARK-34716.

Lead-authored-by: gengjiaan <[email protected]>
Co-authored-by: beliefer <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
1 parent d04b467 commit 12abfe7
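
As a quick illustration of the new behavior, a hypothetical spark-shell snippet (not part of the commit; assumes `spark.implicits._` and `org.apache.spark.sql.functions.sum` are in scope):

import java.time.{Duration, Period}

val df = Seq(
  (Period.ofMonths(10), Duration.ofDays(10)),
  (Period.ofMonths(1), Duration.ofDays(1))).toDF("ym", "dt")

// sum now accepts ANSI interval columns and returns interval results:
// the values come back as java.time.Period / java.time.Duration
// (11 months, 11 days for this data).
df.select(sum($"ym"), sum($"dt")).first()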

File tree

6 files changed: +81 −8 lines

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 3 additions & 1 deletion
@@ -90,7 +90,9 @@ public static int calculateBitSetWidthInBytes(int numFields) {
       FloatType,
       DoubleType,
       DateType,
-      TimestampType
+      TimestampType,
+      YearMonthIntervalType,
+      DayTimeIntervalType
     })));
   }
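
Both interval types map onto fixed-width primitives (a year-month interval is an Int count of months, a day-time interval a Long count of microseconds), which is why they can join UnsafeRow's set of mutable, fixed-width field types. A minimal sketch of that layout, not part of the commit, using the existing UnsafeRowWriter API:

import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter

val writer = new UnsafeRowWriter(2)
writer.reset()
writer.write(0, 14)            // year-month interval: 1 year 2 months, as Int months
writer.write(1, 864000000000L) // day-time interval: 10 days, as Long microseconds
val row = writer.getRow
assert(row.getInt(0) == 14 && row.getLong(1) == 864000000000L)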

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala

Lines changed: 10 additions & 4 deletions

@@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

@@ -46,15 +45,22 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   // Return data type.
   override def dataType: DataType = resultType

-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))

-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function sum")
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
+    case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
+    case other => TypeCheckResult.TypeCheckFailure(
+      s"function sum requires numeric or interval types, not ${other.catalogString}")
+  }

   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
     case _: IntegralType => LongType
+    case _: YearMonthIntervalType => YearMonthIntervalType
+    case _: DayTimeIntervalType => DayTimeIntervalType
     case _ => DoubleType
   }
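
With the new check in place, summing a non-numeric, non-interval column fails analysis with the updated message. An illustrative check, not from the patch (the wrapper text around the message may differ by version):

scala> spark.sql("SELECT sum(true)").collect()
org.apache.spark.sql.AnalysisException: cannot resolve 'sum(true)' due to data type mismatch:
function sum requires numeric or interval types, not boolean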

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -158,7 +158,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {

     assertError(Min(Symbol("mapField")), "min does not support ordering on type")
     assertError(Max(Symbol("mapField")), "max does not support ordering on type")
-    assertError(Sum(Symbol("booleanField")), "function sum requires numeric type")
+    assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
     assertError(Average(Symbol("booleanField")), "function average requires numeric type")
   }

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java

Lines changed: 2 additions & 2 deletions

@@ -541,14 +541,14 @@ protected void reserveInternal(int newCapacity) {
         shortData = newData;
       }
     } else if (type instanceof IntegerType || type instanceof DateType ||
-        DecimalType.is32BitDecimalType(type)) {
+        DecimalType.is32BitDecimalType(type) || type instanceof YearMonthIntervalType) {
       if (intData == null || intData.length < newCapacity) {
         int[] newData = new int[newCapacity];
         if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity);
         intData = newData;
       }
     } else if (type instanceof LongType || type instanceof TimestampType ||
-        DecimalType.is64BitDecimalType(type)) {
+        DecimalType.is64BitDecimalType(type) || type instanceof DayTimeIntervalType) {
       if (longData == null || longData.length < newCapacity) {
         long[] newData = new long[newCapacity];
         if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity);
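
The effect: year-month intervals now take the int storage path and day-time intervals the long path in on-heap column vectors. A small sanity sketch, not part of the commit (the interval types are case objects at this point, so they can be passed as values):

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}

val ym = new OnHeapColumnVector(4, YearMonthIntervalType)
ym.putInt(0, 14)            // 1 year 2 months, stored as months
val dt = new OnHeapColumnVector(4, DayTimeIntervalType)
dt.putLong(0, 86400000000L) // 1 day, stored as microseconds
assert(ym.getInt(0) == 14 && dt.getLong(0) == 86400000000L)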

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala

Lines changed: 24 additions & 0 deletions

@@ -87,6 +87,14 @@ sealed trait BufferSetterGetterUtils {
       (row: InternalRow, ordinal: Int) =>
         if (row.isNullAt(ordinal)) null else row.getLong(ordinal)

+    case YearMonthIntervalType =>
+      (row: InternalRow, ordinal: Int) =>
+        if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+    case DayTimeIntervalType =>
+      (row: InternalRow, ordinal: Int) =>
+        if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
     case other =>
       (row: InternalRow, ordinal: Int) =>
         if (row.isNullAt(ordinal)) null else row.get(ordinal, other)

@@ -187,6 +195,22 @@ sealed trait BufferSetterGetterUtils {
           row.setNullAt(ordinal)
         }

+    case YearMonthIntervalType =>
+      (row: InternalRow, ordinal: Int, value: Any) =>
+        if (value != null) {
+          row.setInt(ordinal, value.asInstanceOf[Int])
+        } else {
+          row.setNullAt(ordinal)
+        }
+
+    case DayTimeIntervalType =>
+      (row: InternalRow, ordinal: Int, value: Any) =>
+        if (value != null) {
+          row.setLong(ordinal, value.asInstanceOf[Long])
+        } else {
+          row.setNullAt(ordinal)
+        }
+
     case other =>
       (row: InternalRow, ordinal: Int, value: Any) =>
         if (value != null) {
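
A rough picture of what the new getter/setter pairs do, sketched (not from the commit) against a SpecificInternalRow buffer slot typed as a year-month interval:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.types.YearMonthIntervalType

val buffer: InternalRow = new SpecificInternalRow(Seq(YearMonthIntervalType))
buffer.setInt(0, 25)  // the setter path: store the interval as Int months
val got = if (buffer.isNullAt(0)) null else buffer.getInt(0) // the getter path
assert(got == 25)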

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 41 additions & 0 deletions

@@ -17,10 +17,13 @@

 package org.apache.spark.sql

+import java.time.{Duration, Period}
+
 import scala.util.Random

 import org.scalatest.matchers.must.Matchers.the

+import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}

@@ -1110,6 +1113,44 @@ class DataFrameAggregateSuite extends QueryTest
     val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count())
     assert(e.message.contains("requires integral type"))
   }
+
+  test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") {
+    val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
+      (2, Period.ofMonths(1), Duration.ofDays(1)),
+      (2, null, null),
+      (3, Period.ofMonths(-3), Duration.ofDays(-6)),
+      (3, Period.ofMonths(21), Duration.ofDays(-5)))
+      .toDF("class", "year-month", "day-time")
+
+    val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
+      (Period.ofMonths(10), Duration.ofDays(10)))
+      .toDF("year-month", "day-time")
+
+    val sumDF = df.select(sum($"year-month"), sum($"day-time"))
+    checkAnswer(sumDF, Row(Period.of(2, 5, 0), Duration.ofDays(0)))
+    assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(sumDF.schema == StructType(Seq(StructField("sum(year-month)", YearMonthIntervalType),
+      StructField("sum(day-time)", DayTimeIntervalType))))
+
+    val sumDF2 = df.groupBy($"class").agg(sum($"year-month"), sum($"day-time"))
+    checkAnswer(sumDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
+      Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
+      Row(3, Period.of(1, 6, 0), Duration.ofDays(-11)) :: Nil)
+    assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
+      StructField("sum(year-month)", YearMonthIntervalType),
+      StructField("sum(day-time)", DayTimeIntervalType))))
+
+    val error = intercept[SparkException] {
+      checkAnswer(df2.select(sum($"year-month")), Nil)
+    }
+    assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
+
+    val error2 = intercept[SparkException] {
+      checkAnswer(df2.select(sum($"day-time")), Nil)
+    }
+    assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
+  }
 }

 case class B(c: Option[Double])
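
The overflow assertions reflect what a user sees directly; a hypothetical snippet (assumes spark-shell implicits) built on the same data shape as `df2` above:

import java.time.Period
import org.apache.spark.sql.functions.sum

val overflowDF = Seq((Period.ofMonths(Int.MaxValue), 0), (Period.ofMonths(10), 0))
  .toDF("ym", "x")
// Interval sums check for overflow regardless of spark.sql.ansi.enabled.
overflowDF.select(sum($"ym")).collect()
// org.apache.spark.SparkException: ... Caused by: java.lang.ArithmeticException: integer overflow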
