Skip to content

Commit e7f7990

Browse files
yaooqinncloud-fan
authored andcommitted
[SPARK-29688][SQL] Support average for interval type values
### What changes were proposed in this pull request? avg aggregate support interval type values ### Why are the changes needed? Part of SPARK-27764 Feature Parity between PostgreSQL and Spark ### Does this PR introduce any user-facing change? yes, we can do avg on intervals ### How was this patch tested? add ut Closes #26347 from yaooqinn/SPARK-29688. Authored-by: Kent Yao <yaooqinn@hotmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent afc943f commit e7f7990

File tree

4 files changed

+131
-9
lines changed

4 files changed

+131
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ import org.apache.spark.sql.types._
3131
2.0
3232
> SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
3333
1.5
34+
> SELECT _FUNC_(cast(v as interval)) FROM VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);
35+
-3 days -11 hours -59 minutes -59 seconds
3436
""",
3537
since = "1.0.0")
3638
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
@@ -39,10 +41,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
3941

4042
override def children: Seq[Expression] = child :: Nil
4143

42-
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
43-
44-
override def checkInputDataTypes(): TypeCheckResult =
45-
TypeUtils.checkForNumericExpr(child.dataType, "function average")
44+
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
4645

4746
override def nullable: Boolean = true
4847

@@ -52,11 +51,13 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
5251
private lazy val resultType = child.dataType match {
5352
case DecimalType.Fixed(p, s) =>
5453
DecimalType.bounded(p + 4, s + 4)
54+
case interval: CalendarIntervalType => interval
5555
case _ => DoubleType
5656
}
5757

5858
private lazy val sumDataType = child.dataType match {
5959
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
60+
case interval: CalendarIntervalType => interval
6061
case _ => DoubleType
6162
}
6263

@@ -66,7 +67,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
6667
override lazy val aggBufferAttributes = sum :: count :: Nil
6768

6869
override lazy val initialValues = Seq(
69-
/* sum = */ Literal(0).cast(sumDataType),
70+
/* sum = */ Literal.default(sumDataType),
7071
/* count = */ Literal(0L)
7172
)
7273

@@ -79,6 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
7980
override lazy val evaluateExpression = child.dataType match {
8081
case _: DecimalType =>
8182
DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
83+
case CalendarIntervalType =>
84+
DivideInterval(sum.cast(resultType), count.cast(DoubleType))
8285
case _ =>
8386
sum.cast(resultType) / count.cast(resultType)
8487
}
@@ -87,7 +90,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
8790
/* sum = */
8891
Add(
8992
sum,
90-
coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))),
93+
coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
9194
/* count = */ If(child.isNull, count, count + 1L)
9295
)
9396
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
150150
assertError(Min('mapField), "min does not support ordering on type")
151151
assertError(Max('mapField), "max does not support ordering on type")
152152
assertError(Sum('booleanField), "requires (numeric or interval) type")
153-
assertError(Average('booleanField), "function average requires numeric type")
153+
assertError(Average('booleanField), "requires (numeric or interval) type")
154154
}
155155

156156
test("check types for others") {

sql/core/src/test/resources/sql-tests/inputs/group-by.sql

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,4 +192,36 @@ having sv is not null;
192192
SELECT
193193
i,
194194
Sum(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
195-
FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v);
195+
FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v);
196+
197+
-- average with interval type
198+
-- null
199+
select avg(cast(v as interval)) from VALUES (null) t(v);
200+
201+
-- empty set
202+
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0;
203+
204+
-- basic interval avg
205+
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v);
206+
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v);
207+
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v);
208+
select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);
209+
210+
-- group by
211+
select
212+
i,
213+
avg(cast(v as interval))
214+
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
215+
group by i;
216+
217+
-- having
218+
select
219+
avg(cast(v as interval)) as sv
220+
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
221+
having sv is not null;
222+
223+
-- window
224+
SELECT
225+
i,
226+
avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
227+
FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v);

sql/core/src/test/resources/sql-tests/results/group-by.sql.out

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 65
2+
-- Number of queries: 74
33

44

55
-- !query 0
@@ -660,3 +660,90 @@ struct<i:int,sum(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETW
660660
1 3 seconds
661661
2 NULL
662662
2 NULL
663+
664+
665+
-- !query 65
666+
select avg(cast(v as interval)) from VALUES (null) t(v)
667+
-- !query 65 schema
668+
struct<avg(CAST(v AS INTERVAL)):interval>
669+
-- !query 65 output
670+
NULL
671+
672+
673+
-- !query 66
674+
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0
675+
-- !query 66 schema
676+
struct<avg(CAST(v AS INTERVAL)):interval>
677+
-- !query 66 output
678+
NULL
679+
680+
681+
-- !query 67
682+
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v)
683+
-- !query 67 schema
684+
struct<avg(CAST(v AS INTERVAL)):interval>
685+
-- !query 67 output
686+
1.5 seconds
687+
688+
689+
-- !query 68
690+
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v)
691+
-- !query 68 schema
692+
struct<avg(CAST(v AS INTERVAL)):interval>
693+
-- !query 68 output
694+
0.5 seconds
695+
696+
697+
-- !query 69
698+
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v)
699+
-- !query 69 schema
700+
struct<avg(CAST(v AS INTERVAL)):interval>
701+
-- !query 69 output
702+
-1.5 seconds
703+
704+
705+
-- !query 70
706+
select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v)
707+
-- !query 70 schema
708+
struct<avg(CAST(v AS INTERVAL)):interval>
709+
-- !query 70 output
710+
-3 days -11 hours -59 minutes -59 seconds
711+
712+
713+
-- !query 71
714+
select
715+
i,
716+
avg(cast(v as interval))
717+
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
718+
group by i
719+
-- !query 71 schema
720+
struct<i:int,avg(CAST(v AS INTERVAL)):interval>
721+
-- !query 71 output
722+
1 -1 days
723+
2 2 seconds
724+
3 NULL
725+
726+
727+
-- !query 72
728+
select
729+
avg(cast(v as interval)) as sv
730+
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
731+
having sv is not null
732+
-- !query 72 schema
733+
struct<sv:interval>
734+
-- !query 72 output
735+
-15 hours -59 minutes -59.333333 seconds
736+
737+
738+
-- !query 73
739+
SELECT
740+
i,
741+
avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
742+
FROM VALUES (1,'1 seconds'), (1,'2 seconds'), (2,NULL), (2,NULL) t(i,v)
743+
-- !query 73 schema
744+
struct<i:int,avg(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):interval>
745+
-- !query 73 output
746+
1 1.5 seconds
747+
1 2 seconds
748+
2 NULL
749+
2 NULL

0 commit comments

Comments
 (0)