diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 125e796a98c2..bfd05e833e79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.catalyst.expressions -import java.time.ZoneId +import java.time.{Duration, Period, ZoneId} import java.util.Comparator import scala.collection.mutable @@ -2484,8 +2484,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran The start and stop expressions must resolve to the same type. If start and stop expressions resolve to the 'date' or 'timestamp' type - then the step expression must resolve to the 'interval' type, otherwise to the same type - as the start and stop expressions. + then the step expression must resolve to the 'interval' or 'year-month interval' or + 'day-time interval' type, otherwise to the same type as the start and stop expressions. """, arguments = """ Arguments: @@ -2504,6 +2504,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran [5,4,3,2,1] > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month); [2018-01-01,2018-02-01,2018-03-01] + > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval '0-1' year to month); + [2018-01-01,2018-02-01,2018-03-01] """, group = "array_funcs", since = "2.4.0" @@ -2550,8 +2552,13 @@ case class Sequence( val typesCorrect = startType.sameType(stop.dataType) && (startType match { - case TimestampType | DateType => - stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) + case TimestampType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) || + YearMonthIntervalType.acceptsType(stepType) || + DayTimeIntervalType.acceptsType(stepType) + case DateType => + stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) || + YearMonthIntervalType.acceptsType(stepType) case _: IntegralType => stepOpt.isEmpty || stepType.sameType(startType) case _ => false @@ -2561,29 +2568,51 @@ case class Sequence( TypeCheckResult.TypeCheckSuccess } else { TypeCheckResult.TypeCheckFailure( - s"$prettyName only supports integral, timestamp or date types") + s""" + |$prettyName uses the wrong parameter type. The parameter type must conform to: + |1. The start and stop expressions must resolve to the same type. + |2. If start and stop expressions resolve to the 'date' or 'timestamp' type + |then the step expression must resolve to the 'interval' or + |'${YearMonthIntervalType.typeName}' or '${DayTimeIntervalType.typeName}' type, + |otherwise to the same type as the start and stop expressions. + """.stripMargin) } } - def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType) + private def isNotIntervalType(expr: Expression) = expr.dataType match { + case CalendarIntervalType | YearMonthIntervalType | DayTimeIntervalType => false + case _ => true + } + + def coercibleChildren: Seq[Expression] = children.filter(isNotIntervalType) def castChildrenTo(widerType: DataType): Expression = Sequence( Cast(start, widerType), Cast(stop, widerType), - stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step), + stepOpt.map(step => if (isNotIntervalType(step)) Cast(step, widerType) else step), timeZoneId) - @transient private lazy val impl: SequenceImpl = dataType.elementType match { + @transient private lazy val impl: InternalSequence = dataType.elementType match { case iType: IntegralType => type T = iType.InternalType val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe)) new IntegralSequenceImpl(iType)(ct, iType.integral) case TimestampType => - new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId) + if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) { + new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId) + } else if (YearMonthIntervalType.acceptsType(stepOpt.get.dataType)) { + new PeriodSequenceImpl[Long](LongType, 1, identity, zoneId) + } else { + new DurationSequenceImpl[Long](LongType, 1, identity, zoneId) + } case DateType => - new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId) + if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) { + new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId) + } else { + new PeriodSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId) + } } override def eval(input: InternalRow): Any = { @@ -2666,7 +2695,7 @@ object Sequence { } } - private trait SequenceImpl { + private trait InternalSequence { def eval(start: Any, stop: Any, step: Any): Any def genCode( @@ -2681,7 +2710,7 @@ object Sequence { } private class IntegralSequenceImpl[T: ClassTag] - (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl { + (elemType: IntegralType)(implicit num: Integral[T]) extends InternalSequence { override val defaultStep: DefaultStep = new DefaultStep( (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], @@ -2695,7 +2724,7 @@ object Sequence { val stop = input2.asInstanceOf[T] val step = input3.asInstanceOf[T] - var i: Int = getSequenceLength(start, stop, step) + var i: Int = getSequenceLength(start, stop, step, step) val arr = new Array[T](i) while (i > 0) { i -= 1 @@ -2713,7 +2742,7 @@ object Sequence { elemType: String): String = { val i = ctx.freshName("i") s""" - |${genSequenceLengthCode(ctx, start, stop, step, i)} + |${genSequenceLengthCode(ctx, start, stop, step, step, i)} |$arr = new $elemType[$i]; |while ($i > 0) { | $i--; @@ -2723,32 +2752,105 @@ object Sequence { } } + private class PeriodSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { + + override val defaultStep: DefaultStep = new DefaultStep( + (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + YearMonthIntervalType, + Period.of(0, 1, 0)) + + val intervalType: DataType = YearMonthIntervalType + + def splitStep(input: Any): (Int, Int, Long) = { + (input.asInstanceOf[Int], 0, 0) + } + + def stepSplitCode( + stepMonths: String, stepDays: String, stepMicros: String, step: String): String = { + s""" + |final int $stepMonths = $step; + |final int $stepDays = 0; + |final long $stepMicros = 0L; + """.stripMargin + } + } + + private class DurationSequenceImpl[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { + + override val defaultStep: DefaultStep = new DefaultStep( + (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], + DayTimeIntervalType, + Duration.ofDays(1)) + + val intervalType: DataType = DayTimeIntervalType + + def splitStep(input: Any): (Int, Int, Long) = { + (0, 0, input.asInstanceOf[Long]) + } + + def stepSplitCode( + stepMonths: String, stepDays: String, stepMicros: String, step: String): String = { + s""" + |final int $stepMonths = 0; + |final int $stepDays = 0; + |final long $stepMicros = $step; + """.stripMargin + } + } + private class TemporalSequenceImpl[T: ClassTag] (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) - (implicit num: Integral[T]) extends SequenceImpl { + (implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) { override val defaultStep: DefaultStep = new DefaultStep( (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn], CalendarIntervalType, new CalendarInterval(0, 1, 0)) + val intervalType: DataType = CalendarIntervalType + + def splitStep(input: Any): (Int, Int, Long) = { + val step = input.asInstanceOf[CalendarInterval] + (step.months, step.days, step.microseconds) + } + + def stepSplitCode( + stepMonths: String, stepDays: String, stepMicros: String, step: String): String = { + s""" + |final int $stepMonths = $step.months; + |final int $stepDays = $step.days; + |final long $stepMicros = $step.microseconds; + """.stripMargin + } + } + + private abstract class InternalSequenceBase[T: ClassTag] + (dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId) + (implicit num: Integral[T]) extends InternalSequence { + + val defaultStep: DefaultStep + private val backedSequenceImpl = new IntegralSequenceImpl[T](dt) - private val microsPerDay = HOURS_PER_DAY * MICROS_PER_HOUR // We choose a minimum days(28) in one month to calculate the `intervalStepInMicros` // in order to make sure the estimated array length is long enough - private val microsPerMonth = 28 * microsPerDay + private val microsPerMonth = 28 * MICROS_PER_DAY + + protected val intervalType: DataType + + protected def splitStep(input: Any): (Int, Int, Long) override def eval(input1: Any, input2: Any, input3: Any): Array[T] = { val start = input1.asInstanceOf[T] val stop = input2.asInstanceOf[T] - val step = input3.asInstanceOf[CalendarInterval] - val stepMonths = step.months - val stepDays = step.days - val stepMicros = step.microseconds + val (stepMonths, stepDays, stepMicros) = splitStep(input3) if (scale == MICROS_PER_DAY && stepMonths == 0 && stepDays == 0) { throw new IllegalArgumentException( - "sequence step must be a day interval if start and end values are dates") + s"sequence step must be a day ${intervalType.typeName} if start and end values are dates") } if (stepMonths == 0 && stepMicros == 0 && scale == MICROS_PER_DAY) { @@ -2763,11 +2865,12 @@ object Sequence { // To estimate the resulted array length we need to make assumptions // about a month length in days and a day length in microseconds val intervalStepInMicros = - stepMicros + stepMonths * microsPerMonth + stepDays * microsPerDay + stepMicros + stepMonths * microsPerMonth + stepDays * MICROS_PER_DAY val startMicros: Long = num.toLong(start) * scale val stopMicros: Long = num.toLong(stop) * scale + val maxEstimatedArrayLength = - getSequenceLength(startMicros, stopMicros, intervalStepInMicros) + getSequenceLength(startMicros, stopMicros, input3, intervalStepInMicros) val stepSign = if (stopMicros >= startMicros) +1 else -1 val exclusiveItem = stopMicros + stepSign @@ -2787,6 +2890,9 @@ object Sequence { } } + protected def stepSplitCode( + stepMonths: String, stepDays: String, stepMicros: String, step: String): String + override def genCode( ctx: CodegenContext, start: String, @@ -2811,25 +2917,27 @@ object Sequence { val sequenceLengthCode = s""" |final long $intervalInMicros = - | $stepMicros + $stepMonths * ${microsPerMonth}L + $stepDays * ${microsPerDay}L; - |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)} - """.stripMargin + | $stepMicros + $stepMonths * ${microsPerMonth}L + $stepDays * ${MICROS_PER_DAY}L; + |${genSequenceLengthCode( + ctx, startMicros, stopMicros, step, intervalInMicros, arrLength)} + """.stripMargin val check = if (scale == MICROS_PER_DAY) { s""" |if ($stepMonths == 0 && $stepDays == 0) { | throw new IllegalArgumentException( - | "sequence step must be a day interval if start and end values are dates"); + | "sequence step must be a day ${intervalType.typeName} " + + | "if start and end values are dates"); |} - """.stripMargin + """.stripMargin } else { "" } + val stepSplits = stepSplitCode(stepMonths, stepDays, stepMicros, step) + s""" - |final int $stepMonths = $step.months; - |final int $stepDays = $step.days; - |final long $stepMicros = $step.microseconds; + |$stepSplits | |$check | @@ -2866,15 +2974,16 @@ object Sequence { } } - private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = { + private def getSequenceLength[U](start: U, stop: U, step: Any, estimatedStep: U) + (implicit num: Integral[U]): Int = { import num._ require( - (step > num.zero && start <= stop) - || (step < num.zero && start >= stop) - || (step == num.zero && start == stop), + (estimatedStep > num.zero && start <= stop) + || (estimatedStep < num.zero && start >= stop) + || (estimatedStep == num.zero && start == stop), s"Illegal sequence boundaries: $start to $stop by $step") - val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong + val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong require( len <= MAX_ROUNDED_ARRAY_LENGTH, @@ -2888,16 +2997,17 @@ object Sequence { start: String, stop: String, step: String, + estimatedStep: String, len: String): String = { val longLen = ctx.freshName("longLen") s""" - |if (!(($step > 0 && $start <= $stop) || - | ($step < 0 && $start >= $stop) || - | ($step == 0 && $start == $stop))) { + |if (!(($estimatedStep > 0 && $start <= $stop) || + | ($estimatedStep < 0 && $start >= $stop) || + | ($estimatedStep == 0 && $start == $stop))) { | throw new IllegalArgumentException( | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step); |} - |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step; + |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep; |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) { | throw new IllegalArgumentException( | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH"); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 095894b9fffa..aec8725d515d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import java.time.{Duration, Period} import java.util.TimeZone import scala.language.implicitConversions @@ -28,7 +29,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC import org.apache.spark.sql.catalyst.util.IntervalUtils._ import org.apache.spark.sql.internal.SQLConf @@ -932,7 +932,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Date.valueOf("1970-02-01")), Literal(negateExact(stringToInterval("interval 1 month")))), EmptyRow, - s"sequence boundaries: 0 to 2678400000000 by -${28 * MICROS_PER_DAY}") + s"sequence boundaries: 0 to 2678400000000 by -1 months") // SPARK-32133: Sequence step must be a day interval if start and end values are dates checkExceptionInExpression[IllegalArgumentException](Sequence( @@ -943,6 +943,178 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper } } + test("SPARK-35088: Accept ANSI intervals by the Sequence expression") { + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Duration.ofHours(12))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-02 00:00:01")), + Literal(Duration.ofHours(12))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-02 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Duration.ofHours(-12))), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-02 00:00:00")), + Literal(Timestamp.valueOf("2017-12-31 23:59:59")), + Literal(Duration.ofHours(-12))), + Seq( + Timestamp.valueOf("2018-01-02 00:00:00"), + Timestamp.valueOf("2018-01-01 12:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Period.ofMonths(1))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-03-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-03-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Period.ofMonths(-1))), + Seq( + Timestamp.valueOf("2018-03-01 00:00:00"), + Timestamp.valueOf("2018-02-01 00:00:00"), + Timestamp.valueOf("2018-01-01 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-31 00:00:00")), + Literal(Timestamp.valueOf("2018-04-30 00:00:00")), + Literal(Period.ofMonths(1))), + Seq( + Timestamp.valueOf("2018-01-31 00:00:00"), + Timestamp.valueOf("2018-02-28 00:00:00"), + Timestamp.valueOf("2018-03-31 00:00:00"), + Timestamp.valueOf("2018-04-30 00:00:00"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2023-01-01 00:00:00")), + Literal(Period.of(1, 5, 0))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2022-04-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2022-04-01 00:00:00")), + Literal(Timestamp.valueOf("2017-01-01 00:00:00")), + Literal(Period.of(-1, -5, 0))), + Seq( + Timestamp.valueOf("2022-04-01 00:00:00.000"), + Timestamp.valueOf("2020-11-01 00:00:00.000"), + Timestamp.valueOf("2019-06-01 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-04 00:00:00")), + Literal(Duration.ofDays(1))), + Seq( + Timestamp.valueOf("2018-01-01 00:00:00.000"), + Timestamp.valueOf("2018-01-02 00:00:00.000"), + Timestamp.valueOf("2018-01-03 00:00:00.000"), + Timestamp.valueOf("2018-01-04 00:00:00.000"))) + + checkEvaluation(new Sequence( + Literal(Timestamp.valueOf("2018-01-04 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Duration.ofDays(-1))), + Seq( + Timestamp.valueOf("2018-01-04 00:00:00.000"), + Timestamp.valueOf("2018-01-03 00:00:00.000"), + Timestamp.valueOf("2018-01-02 00:00:00.000"), + Timestamp.valueOf("2018-01-01 00:00:00.000"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Timestamp.valueOf("2018-01-04 00:00:00")), + Literal(Period.ofDays(1))), + EmptyRow, s"sequence boundaries: 1514793600000000 to 1515052800000000 by 0") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Timestamp.valueOf("2018-01-04 00:00:00")), + Literal(Timestamp.valueOf("2018-01-01 00:00:00")), + Literal(Period.ofDays(-1))), + EmptyRow, s"sequence boundaries: 1515052800000000 to 1514793600000000 by 0") + + DateTimeTestUtils.withDefaultTimeZone(UTC) { + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-03-01")), + Literal(Period.ofMonths(1))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2018-02-01"), + Date.valueOf("2018-03-01"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-31")), + Literal(Date.valueOf("2018-04-30")), + Literal(Period.ofMonths(1))), + Seq( + Date.valueOf("2018-01-31"), + Date.valueOf("2018-02-28"), + Date.valueOf("2018-03-31"), + Date.valueOf("2018-04-30"))) + + checkEvaluation(new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2023-01-01")), + Literal(Period.of(1, 5, 0))), + Seq( + Date.valueOf("2018-01-01"), + Date.valueOf("2019-06-01"), + Date.valueOf("2020-11-01"), + Date.valueOf("2022-04-01"))) + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("2018-01-01")), + Literal(Date.valueOf("2018-01-05")), + Literal(Period.ofDays(2))), + EmptyRow, + "sequence step must be a day year-month interval if start and end values are dates") + + checkExceptionInExpression[IllegalArgumentException]( + new Sequence( + Literal(Date.valueOf("1970-01-01")), + Literal(Date.valueOf("1970-02-01")), + Literal(Period.ofMonths(-1))), + EmptyRow, + s"sequence boundaries: 0 to 2678400000000 by -1") + + assert(Sequence( + Cast(Literal("2011-03-01"), DateType), + Cast(Literal("2011-04-01"), DateType), + Option(Literal(Duration.ofHours(1)))).checkInputDataTypes().isFailure) + } + } + test("Sequence with default step") { // +/- 1 for integral type checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3))