Skip to content

Commit 8dc455b

Browse files
belieferMaxGekk
authored andcommitted
[SPARK-34837][SQL] Support ANSI SQL intervals by the aggregate function avg
### What changes were proposed in this pull request? Extend the `Average` expression to support `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614. Note: the expressions can throw the overflow exception independently from the SQL config `spark.sql.ansi.enabled`. In this way, the modified expressions always behave in the ANSI mode for the intervals. ### Why are the changes needed? Extend `org.apache.spark.sql.catalyst.expressions.aggregate.Average` to support `DayTimeIntervalType` and `YearMonthIntervalType`. ### Does this PR introduce _any_ user-facing change? 'No'. Should not since new types have not been released yet. ### How was this patch tested? Jenkins test Closes #32229 from beliefer/SPARK-34837. Authored-by: gengjiaan <[email protected]> Signed-off-by: Max Gekk <[email protected]>
1 parent 70b606f commit 8dc455b

File tree

5 files changed

+60
-9
lines changed

5 files changed

+60
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
4040

4141
override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("avg")
4242

43-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
43+
override def inputTypes: Seq[AbstractDataType] =
44+
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
4445

4546
override def checkInputDataTypes(): TypeCheckResult =
46-
TypeUtils.checkForNumericExpr(child.dataType, "function average")
47+
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
4748

4849
override def nullable: Boolean = true
4950

@@ -53,11 +54,15 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
5354
private lazy val resultType = child.dataType match {
5455
case DecimalType.Fixed(p, s) =>
5556
DecimalType.bounded(p + 4, s + 4)
57+
case _: YearMonthIntervalType => YearMonthIntervalType
58+
case _: DayTimeIntervalType => DayTimeIntervalType
5659
case _ => DoubleType
5760
}
5861

5962
private lazy val sumDataType = child.dataType match {
6063
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
64+
case _: YearMonthIntervalType => YearMonthIntervalType
65+
case _: DayTimeIntervalType => DayTimeIntervalType
6166
case _ => DoubleType
6267
}
6368

@@ -82,6 +87,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
8287
case _: DecimalType =>
8388
DecimalPrecision.decimalAndDecimal(
8489
Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
90+
case _: YearMonthIntervalType => DivideYMInterval(sum, count)
91+
case _: DayTimeIntervalType => DivideDTInterval(sum, count)
8592
case _ =>
8693
Divide(sum.cast(resultType), count.cast(resultType), failOnError = false)
8794
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.trees.UnaryLike
24+
import org.apache.spark.sql.catalyst.util.TypeUtils
2425
import org.apache.spark.sql.internal.SQLConf
2526
import org.apache.spark.sql.types._
2627

@@ -48,12 +49,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
4849
override def inputTypes: Seq[AbstractDataType] =
4950
Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
5051

51-
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
52-
case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
53-
case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
54-
case other => TypeCheckResult.TypeCheckFailure(
55-
s"function sum requires numeric or interval types, not ${other.catalogString}")
56-
}
52+
override def checkInputDataTypes(): TypeCheckResult =
53+
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
5754

5855
private lazy val resultType = child.dataType match {
5956
case DecimalType.Fixed(precision, scale) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ object TypeUtils {
6161
}
6262
}
6363

64+
def checkForAnsiIntervalOrNumericType(
65+
dt: DataType, funcName: String): TypeCheckResult = dt match {
66+
case YearMonthIntervalType | DayTimeIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
67+
case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
68+
case other => TypeCheckResult.TypeCheckFailure(
69+
s"function $funcName requires numeric or interval types, not ${other.catalogString}")
70+
}
71+
6472
def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
6573
if (exactNumericRequired) {
6674
t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
159159
assertError(Min(Symbol("mapField")), "min does not support ordering on type")
160160
assertError(Max(Symbol("mapField")), "max does not support ordering on type")
161161
assertError(Sum(Symbol("booleanField")), "function sum requires numeric or interval types")
162-
assertError(Average(Symbol("booleanField")), "function average requires numeric type")
162+
assertError(Average(Symbol("booleanField")),
163+
"function average requires numeric or interval types")
163164
}
164165

165166
test("check types for others") {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,6 +1151,44 @@ class DataFrameAggregateSuite extends QueryTest
11511151
}
11521152
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
11531153
}
1154+
1155+
test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
1156+
val df = Seq((1, Period.ofMonths(10), Duration.ofDays(10)),
1157+
(2, Period.ofMonths(1), Duration.ofDays(1)),
1158+
(2, null, null),
1159+
(3, Period.ofMonths(-3), Duration.ofDays(-6)),
1160+
(3, Period.ofMonths(21), Duration.ofDays(-5)))
1161+
.toDF("class", "year-month", "day-time")
1162+
1163+
val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
1164+
(Period.ofMonths(10), Duration.ofDays(10)))
1165+
.toDF("year-month", "day-time")
1166+
1167+
val avgDF = df.select(avg($"year-month"), avg($"day-time"))
1168+
checkAnswer(avgDF, Row(Period.ofMonths(7), Duration.ofDays(0)))
1169+
assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
1170+
assert(avgDF.schema == StructType(Seq(StructField("avg(year-month)", YearMonthIntervalType),
1171+
StructField("avg(day-time)", DayTimeIntervalType))))
1172+
1173+
val avgDF2 = df.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
1174+
checkAnswer(avgDF2, Row(1, Period.ofMonths(10), Duration.ofDays(10)) ::
1175+
Row(2, Period.ofMonths(1), Duration.ofDays(1)) ::
1176+
Row(3, Period.ofMonths(9), Duration.ofDays(-5).plusHours(-12)) ::Nil)
1177+
assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
1178+
assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
1179+
StructField("avg(year-month)", YearMonthIntervalType),
1180+
StructField("avg(day-time)", DayTimeIntervalType))))
1181+
1182+
val error = intercept[SparkException] {
1183+
checkAnswer(df2.select(avg($"year-month")), Nil)
1184+
}
1185+
assert(error.toString contains "java.lang.ArithmeticException: integer overflow")
1186+
1187+
val error2 = intercept[SparkException] {
1188+
checkAnswer(df2.select(avg($"day-time")), Nil)
1189+
}
1190+
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
1191+
}
11541192
}
11551193

11561194
case class B(c: Option[Double])

0 commit comments

Comments
 (0)