Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.catalyst.expressions

import java.time.ZoneId
import java.time.{Duration, Period, ZoneId}
import java.util.Comparator

import scala.collection.mutable
Expand Down Expand Up @@ -2484,8 +2484,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran

The start and stop expressions must resolve to the same type.
If start and stop expressions resolve to the 'date' or 'timestamp' type
then the step expression must resolve to the 'interval' type, otherwise to the same type
as the start and stop expressions.
then the step expression must resolve to the 'interval' or 'year-month interval' or
'day-time interval' type, otherwise to the same type as the start and stop expressions.
""",
arguments = """
Arguments:
Expand All @@ -2504,6 +2504,8 @@ case class Flatten(child: Expression) extends UnaryExpression with NullIntoleran
[5,4,3,2,1]
> SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month);
[2018-01-01,2018-02-01,2018-03-01]
> SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval '0-1' year to month);
[2018-01-01,2018-02-01,2018-03-01]
""",
group = "array_funcs",
since = "2.4.0"
Expand Down Expand Up @@ -2550,8 +2552,13 @@ case class Sequence(
val typesCorrect =
startType.sameType(stop.dataType) &&
(startType match {
case TimestampType | DateType =>
stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType)
case TimestampType =>
stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) ||
YearMonthIntervalType.acceptsType(stepType) ||
DayTimeIntervalType.acceptsType(stepType)
case DateType =>
stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) ||
YearMonthIntervalType.acceptsType(stepType)
case _: IntegralType =>
stepOpt.isEmpty || stepType.sameType(startType)
case _ => false
Expand All @@ -2561,29 +2568,51 @@ case class Sequence(
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"$prettyName only supports integral, timestamp or date types")
s"""
|$prettyName uses the wrong parameter type. The parameter type must conform to:
|1. The start and stop expressions must resolve to the same type.
|2. If start and stop expressions resolve to the 'date' or 'timestamp' type
|then the step expression must resolve to the 'interval' or
|'${YearMonthIntervalType.typeName}' or '${DayTimeIntervalType.typeName}' type,
|otherwise to the same type as the start and stop expressions.
""".stripMargin)
}
}

def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType)
// Returns true unless the expression's data type is one of the three interval
// types: CalendarIntervalType, YearMonthIntervalType or DayTimeIntervalType.
private def isNotIntervalType(expr: Expression) = {
  val exprType = expr.dataType
  val isInterval =
    exprType == CalendarIntervalType ||
      exprType == YearMonthIntervalType ||
      exprType == DayTimeIntervalType
  !isInterval
}

def coercibleChildren: Seq[Expression] = children.filter(isNotIntervalType)

def castChildrenTo(widerType: DataType): Expression = Sequence(
Cast(start, widerType),
Cast(stop, widerType),
stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
stepOpt.map(step => if (isNotIntervalType(step)) Cast(step, widerType) else step),
timeZoneId)

@transient private lazy val impl: SequenceImpl = dataType.elementType match {
@transient private lazy val impl: InternalSequence = dataType.elementType match {
case iType: IntegralType =>
type T = iType.InternalType
val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
new IntegralSequenceImpl(iType)(ct, iType.integral)

case TimestampType =>
new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId)
if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) {
new TemporalSequenceImpl[Long](LongType, 1, identity, zoneId)
} else if (YearMonthIntervalType.acceptsType(stepOpt.get.dataType)) {
new PeriodSequenceImpl[Long](LongType, 1, identity, zoneId)
} else {
new DurationSequenceImpl[Long](LongType, 1, identity, zoneId)
}

case DateType =>
new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
if (stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepOpt.get.dataType)) {
new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
} else {
new PeriodSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, zoneId)
}
}

override def eval(input: InternalRow): Any = {
Expand Down Expand Up @@ -2666,7 +2695,7 @@ object Sequence {
}
}

private trait SequenceImpl {
private trait InternalSequence {
def eval(start: Any, stop: Any, step: Any): Any

def genCode(
Expand All @@ -2681,7 +2710,7 @@ object Sequence {
}

private class IntegralSequenceImpl[T: ClassTag]
(elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl {
(elemType: IntegralType)(implicit num: Integral[T]) extends InternalSequence {

override val defaultStep: DefaultStep = new DefaultStep(
(elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
Expand All @@ -2695,7 +2724,7 @@ object Sequence {
val stop = input2.asInstanceOf[T]
val step = input3.asInstanceOf[T]

var i: Int = getSequenceLength(start, stop, step)
var i: Int = getSequenceLength(start, stop, step, step)
val arr = new Array[T](i)
while (i > 0) {
i -= 1
Expand All @@ -2713,7 +2742,7 @@ object Sequence {
elemType: String): String = {
val i = ctx.freshName("i")
s"""
|${genSequenceLengthCode(ctx, start, stop, step, i)}
|${genSequenceLengthCode(ctx, start, stop, step, step, i)}
|$arr = new $elemType[$i];
|while ($i > 0) {
| $i--;
Expand All @@ -2723,32 +2752,105 @@ object Sequence {
}
}

private class PeriodSequenceImpl[T: ClassTag]
(dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
(implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure it is a good idea to use timestampAddInterval() in InternalSequenceBase.eval for adding months to dates. I guess DateTimeUtils.dateAddMonths() and DateTimeUtils.timestampAddInterval() can return different results, especially taking into account that dateAddMonths() does not depend on the current time zone.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, the current implementation uses DateTimeUtils.timestampAddInterval, and its behavior seems good.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. Let's use timestampAddInterval since we don't have an example that could demonstrate any issues caused by timestampAddInterval().


override val defaultStep: DefaultStep = new DefaultStep(
(dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
YearMonthIntervalType,
Period.of(0, 1, 0))

val intervalType: DataType = YearMonthIntervalType

// Splits a step value into (months, days, micros). Here the input is cast to Int
// and used as the months component; the days and micros components are zero.
def splitStep(input: Any): (Int, Int, Long) = {
  val months = input.asInstanceOf[Int]
  (months, 0, 0L)
}

/**
 * Builds the generated-code snippet that declares the three step components as
 * final locals. `stepMonths`, `stepDays` and `stepMicros` are the names to use for
 * those locals; `step` is the generated expression holding the step value.
 *
 * The whole step is assigned to the months local, while the days and micros locals
 * are declared as zero. This mirrors what `splitStep` in this class returns.
 * NOTE(review): `step` is assigned to an `int` local, so it is presumably an
 * int-typed expression (a month count) -- confirm against the caller.
 */
def stepSplitCode(
stepMonths: String, stepDays: String, stepMicros: String, step: String): String = {
s"""
|final int $stepMonths = $step;
|final int $stepDays = 0;
|final long $stepMicros = 0L;
""".stripMargin
}
}

// Sequence implementation whose step is of DayTimeIntervalType (see `intervalType`).
// It extends InternalSequenceBase and supplies only the pieces that depend on the
// step representation: the default step, the interval type, and how a step value is
// split into (months, days, micros) both at runtime and in generated code.
private class DurationSequenceImpl[T: ClassTag]
(dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
(implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) {

// Default step, constructed from the element ordering of `dt`, DayTimeIntervalType
// and a one-day java.time.Duration.
// NOTE(review): DefaultStep is defined outside this view; presumably it picks the
// positive or negated value depending on start/stop order -- confirm.
override val defaultStep: DefaultStep = new DefaultStep(
(dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
DayTimeIntervalType,
Duration.ofDays(1))

// Interval type handled by this implementation; the base class uses its typeName
// when building the "sequence step must be a day ..." error message.
val intervalType: DataType = DayTimeIntervalType

// Splits a step value into (months, days, micros): the input is cast to Long and
// used as the third component, while months and days are zero.
// NOTE(review): assumes the day-time step arrives as a Long -- presumably
// microseconds, since it is bound to `stepMicros` in the base class; confirm.
def splitStep(input: Any): (Int, Int, Long) = {
(0, 0, input.asInstanceOf[Long])
}

// Generated-code counterpart of splitStep: declares the three step locals, with
// months and days as zero and the micros local assigned from the `step` expression.
def stepSplitCode(
stepMonths: String, stepDays: String, stepMicros: String, step: String): String = {
s"""
|final int $stepMonths = 0;
|final int $stepDays = 0;
|final long $stepMicros = $step;
""".stripMargin
}
}

private class TemporalSequenceImpl[T: ClassTag]
(dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
(implicit num: Integral[T]) extends SequenceImpl {
(implicit num: Integral[T]) extends InternalSequenceBase(dt, scale, fromLong, zoneId) {

override val defaultStep: DefaultStep = new DefaultStep(
(dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
CalendarIntervalType,
new CalendarInterval(0, 1, 0))

val intervalType: DataType = CalendarIntervalType

// Splits a step value into (months, days, micros) by casting the input to
// CalendarInterval and reading its three fields.
def splitStep(input: Any): (Int, Int, Long) = {
  val interval = input.asInstanceOf[CalendarInterval]
  val months = interval.months
  val days = interval.days
  val micros = interval.microseconds
  (months, days, micros)
}

/**
 * Builds the generated-code snippet that declares the three step components as
 * final locals. `stepMonths`, `stepDays` and `stepMicros` are the names to use for
 * those locals; `step` is the generated expression holding the step value.
 *
 * Each local is read from the corresponding field of `step` (`months`, `days`,
 * `microseconds`). This mirrors what `splitStep` in this class does at runtime.
 */
def stepSplitCode(
stepMonths: String, stepDays: String, stepMicros: String, step: String): String = {
s"""
|final int $stepMonths = $step.months;
|final int $stepDays = $step.days;
|final long $stepMicros = $step.microseconds;
""".stripMargin
}
}

private abstract class InternalSequenceBase[T: ClassTag]
(dt: IntegralType, scale: Long, fromLong: Long => T, zoneId: ZoneId)
(implicit num: Integral[T]) extends InternalSequence {

val defaultStep: DefaultStep

private val backedSequenceImpl = new IntegralSequenceImpl[T](dt)
private val microsPerDay = HOURS_PER_DAY * MICROS_PER_HOUR
// We choose a minimum days(28) in one month to calculate the `intervalStepInMicros`
// in order to make sure the estimated array length is long enough
private val microsPerMonth = 28 * microsPerDay
private val microsPerMonth = 28 * MICROS_PER_DAY

protected val intervalType: DataType

protected def splitStep(input: Any): (Int, Int, Long)

override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
val start = input1.asInstanceOf[T]
val stop = input2.asInstanceOf[T]
val step = input3.asInstanceOf[CalendarInterval]
val stepMonths = step.months
val stepDays = step.days
val stepMicros = step.microseconds
val (stepMonths, stepDays, stepMicros) = splitStep(input3)

if (scale == MICROS_PER_DAY && stepMonths == 0 && stepDays == 0) {
throw new IllegalArgumentException(
"sequence step must be a day interval if start and end values are dates")
s"sequence step must be a day ${intervalType.typeName} if start and end values are dates")
}

if (stepMonths == 0 && stepMicros == 0 && scale == MICROS_PER_DAY) {
Expand All @@ -2763,11 +2865,12 @@ object Sequence {
// To estimate the resulted array length we need to make assumptions
// about a month length in days and a day length in microseconds
val intervalStepInMicros =
stepMicros + stepMonths * microsPerMonth + stepDays * microsPerDay
stepMicros + stepMonths * microsPerMonth + stepDays * MICROS_PER_DAY
val startMicros: Long = num.toLong(start) * scale
val stopMicros: Long = num.toLong(stop) * scale

val maxEstimatedArrayLength =
getSequenceLength(startMicros, stopMicros, intervalStepInMicros)
getSequenceLength(startMicros, stopMicros, input3, intervalStepInMicros)

val stepSign = if (stopMicros >= startMicros) +1 else -1
val exclusiveItem = stopMicros + stepSign
Expand All @@ -2787,6 +2890,9 @@ object Sequence {
}
}

protected def stepSplitCode(
stepMonths: String, stepDays: String, stepMicros: String, step: String): String

override def genCode(
ctx: CodegenContext,
start: String,
Expand All @@ -2811,25 +2917,27 @@ object Sequence {
val sequenceLengthCode =
s"""
|final long $intervalInMicros =
| $stepMicros + $stepMonths * ${microsPerMonth}L + $stepDays * ${microsPerDay}L;
|${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)}
""".stripMargin
| $stepMicros + $stepMonths * ${microsPerMonth}L + $stepDays * ${MICROS_PER_DAY}L;
|${genSequenceLengthCode(
ctx, startMicros, stopMicros, step, intervalInMicros, arrLength)}
""".stripMargin

val check = if (scale == MICROS_PER_DAY) {
s"""
|if ($stepMonths == 0 && $stepDays == 0) {
| throw new IllegalArgumentException(
| "sequence step must be a day interval if start and end values are dates");
| "sequence step must be a day ${intervalType.typeName} " +
| "if start and end values are dates");
|}
""".stripMargin
""".stripMargin
} else {
""
}

val stepSplits = stepSplitCode(stepMonths, stepDays, stepMicros, step)

s"""
|final int $stepMonths = $step.months;
|final int $stepDays = $step.days;
|final long $stepMicros = $step.microseconds;
|$stepSplits
|
|$check
|
Expand Down Expand Up @@ -2866,15 +2974,16 @@ object Sequence {
}
}

private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = {
private def getSequenceLength[U](start: U, stop: U, step: Any, estimatedStep: U)
(implicit num: Integral[U]): Int = {
import num._
require(
(step > num.zero && start <= stop)
|| (step < num.zero && start >= stop)
|| (step == num.zero && start == stop),
(estimatedStep > num.zero && start <= stop)
|| (estimatedStep < num.zero && start >= stop)
|| (estimatedStep == num.zero && start == stop),
s"Illegal sequence boundaries: $start to $stop by $step")

val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong
val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong

require(
len <= MAX_ROUNDED_ARRAY_LENGTH,
Expand All @@ -2888,16 +2997,17 @@ object Sequence {
start: String,
stop: String,
step: String,
estimatedStep: String,
len: String): String = {
val longLen = ctx.freshName("longLen")
s"""
|if (!(($step > 0 && $start <= $stop) ||
| ($step < 0 && $start >= $stop) ||
| ($step == 0 && $start == $stop))) {
|if (!(($estimatedStep > 0 && $start <= $stop) ||
| ($estimatedStep < 0 && $start >= $stop) ||
| ($estimatedStep == 0 && $start == $stop))) {
| throw new IllegalArgumentException(
| "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
|}
|long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step;
|long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
|if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
| throw new IllegalArgumentException(
| "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
Expand Down
Loading