Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[SPARK-29688][SQL] Support average for interval type values
  • Loading branch information
yaooqinn committed Nov 6, 2019
commit 8ae0ced5750a0ce4b2ff85cfd24188307b7bee1d
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,8 @@ object TypeCoercion {
SubtractTimestamps(l, Cast(r, TimestampType))
case Subtract(l @ DateType(), r @ TimestampType()) =>
SubtractTimestamps(Cast(l, TimestampType), r)

case Divide(l @ CalendarIntervalType(), r) => IntervalDivide(l, Cast(r, DecimalType(29, 9)))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ package object dsl {
def * (other: Expression): Expression = Multiply(expr, other)
def / (other: Expression): Expression = Divide(expr, other)
def div (other: Expression): Expression = IntegralDivide(expr, other)
def intervalDiv (other: Expression): Expression = IntervalDivide(expr, other)
def % (other: Expression): Expression = Remainder(expr, other)
def & (other: Expression): Expression = BitwiseAnd(expr, other)
def | (other: Expression): Expression = BitwiseOr(expr, other)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import org.apache.spark.sql.types._
2.0
> SELECT _FUNC_(col) FROM VALUES (1), (2), (NULL) AS tab(col);
1.5
> SELECT _FUNC_(cast(v as interval)) FROM VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);
interval -3 days -11 hours -59 minutes -59 seconds
""",
since = "1.0.0")
case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
Expand All @@ -39,10 +41,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit

override def children: Seq[Expression] = child :: Nil

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function average")
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

override def nullable: Boolean = true

Expand All @@ -52,11 +51,13 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
case CalendarIntervalType => CalendarIntervalType
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: case interval: CalendarIntervalType => interval?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will change this and the one in the sumDataType

case _ => DoubleType
}

private lazy val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case CalendarIntervalType => CalendarIntervalType
case _ => DoubleType
}

Expand All @@ -66,7 +67,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
override lazy val aggBufferAttributes = sum :: count :: Nil

override lazy val initialValues = Seq(
/* sum = */ Literal(0).cast(sumDataType),
/* sum = */ Literal.default(sumDataType),
/* count = */ Literal(0L)
)

Expand All @@ -79,6 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
override lazy val evaluateExpression = child.dataType match {
case _: DecimalType =>
DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
case CalendarIntervalType =>
sum.cast(resultType).intervalDiv(count.cast(DecimalType.LongDecimal))
case _ =>
sum.cast(resultType) / count.cast(resultType)
}
Expand All @@ -87,7 +90,7 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
/* sum = */
Add(
sum,
coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))),
coalesce(child.cast(sumDataType), Literal.default(sumDataType))),
/* count = */ If(child.isNull, count, count + 1L)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -2164,3 +2164,32 @@ case class SubtractDates(left: Expression, right: Expression)
}
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "expr1 _FUNC_ expr2 - Divide interval value `expr1` by `expr2`. It returns NULL if `expr2` is 0 or NULL.",
examples = """
Examples:
> SELECT interval '1 year 2 month' / 3.0;
interval 4 months 2 weeks 6 days
""",
since = "3.0.0")
// scalastyle:on line.size.limit
case class IntervalDivide(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = CalendarIntervalType

override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType, DecimalType)

override def nullSafeEval(interval: Any, divisor: Any): Any = {
IntervalUtils.divide(interval.asInstanceOf[CalendarInterval],
divisor.asInstanceOf[Decimal])
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (interval, divisor) => {
val iu = IntervalUtils.getClass.getName.stripSuffix("$")
s"$iu.divide($interval, $divisor)"
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ object IntervalUtils {
Decimal(result, 18, 6)
}

def divide(interval: CalendarInterval, divisor: Decimal): CalendarInterval = {
if (divisor == Decimal.ZERO || divisor == null) return null
val months = Decimal(interval.months) / divisor
val milliseconds = (Decimal(interval.microseconds) / divisor +
months.remainder(Decimal.ONE) * Decimal(MICROS_PER_MONTH)).toLong
new CalendarInterval(months.toInt, milliseconds.toLong)
}

/**
* Converts a string to [[CalendarInterval]] case-insensitively.
*
Expand Down
32 changes: 32 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/average.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-- average with interval type

-- null
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where v is null;

-- empty set
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0;

--
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v);
select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v);

--group by
select
i,
avg(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i;

--having
select
avg(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null;

-- window
SELECT
i,
avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v);
89 changes: 89 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/average.sql.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 9


-- !query 0
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where v is null
-- !query 0 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 0 output
NULL


-- !query 1
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v) where 1=0
-- !query 1 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 1 output
NULL


-- !query 2
select avg(cast(v as interval)) from VALUES ('1 seconds'), ('2 seconds'), (null) t(v)
-- !query 2 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 2 output
interval 1 seconds 500 milliseconds


-- !query 3
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('2 seconds'), (null) t(v)
-- !query 3 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 3 output
interval 500 milliseconds


-- !query 4
select avg(cast(v as interval)) from VALUES ('-1 seconds'), ('-2 seconds'), (null) t(v)
-- !query 4 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 4 output
interval -1 seconds -500 milliseconds


-- !query 5
select avg(cast(v as interval)) from VALUES ('-1 weeks'), ('2 seconds'), (null) t(v)
-- !query 5 schema
struct<avg(CAST(v AS INTERVAL)):interval>
-- !query 5 output
interval -3 days -11 hours -59 minutes -59 seconds


-- !query 6
select
i,
avg(cast(v as interval))
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
group by i
-- !query 6 schema
struct<i:int,avg(CAST(v AS INTERVAL)):interval>
-- !query 6 output
1 interval -1 days
2 interval 2 seconds
3 NULL


-- !query 7
select
avg(cast(v as interval)) as sv
from VALUES (1, '-1 weeks'), (2, '2 seconds'), (3, null), (1, '5 days') t(i, v)
having sv is not null
-- !query 7 schema
struct<sv:interval>
-- !query 7 output
interval -15 hours -59 minutes -59 seconds -333 milliseconds -333 microseconds


-- !query 8
SELECT
i,
avg(cast(v as interval)) OVER (ORDER BY i ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING)
FROM VALUES(1,'1 seconds'),(1,'2 seconds'),(2,NULL),(2,NULL) t(i,v)
-- !query 8 schema
struct<i:int,avg(CAST(v AS INTERVAL)) OVER (ORDER BY i ASC NULLS FIRST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING):interval>
-- !query 8 output
1 interval 1 seconds 500 milliseconds
1 interval 2 seconds
2 NULL
2 NULL