Skip to content
Prev Previous commit
Next Next commit
Merge branch 'master' into SPARK-35112
  • Loading branch information
AngersZhuuuu committed Apr 30, 2021
commit 6ff0522bd5e7ff712735c3de8f99bd05a69f81e0
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ object Cast {

case (StringType, CalendarIntervalType) => true
case (StringType, DayTimeIntervalType) => true
case (StringType, YearMonthIntervalType) => true

case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
Expand Down Expand Up @@ -540,6 +541,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
buildCast[UTF8String](_, s => IntervalUtils.castStringToDTInterval(s))
}

private[this] def castToYearMonthInterval(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => IntervalUtils.castStringToYMInterval(s))
}

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType if ansiEnabled =>
Expand Down Expand Up @@ -845,6 +851,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case TimestampType => castToTimestamp(from)
case CalendarIntervalType => castToInterval(from)
case DayTimeIntervalType => castToDayTimeInterval(from)
case YearMonthIntervalType => castToYearMonthInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
Expand Down Expand Up @@ -904,6 +911,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case TimestampType => castToTimestampCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from)
case DayTimeIntervalType => castToDayTimeIntervalCode(from)
case YearMonthIntervalType => castToYearMonthIntervalCode(from)
case BooleanType => castToBooleanCode(from)
case ByteType => castToByteCode(from, ctx)
case ShortType => castToShortCode(from, ctx)
Expand Down Expand Up @@ -1369,6 +1377,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
code"$evPrim = $util.castStringToDTInterval($c);"
}

private[this] def castToYearMonthIntervalCode(from: DataType): CastFunction = from match {
case StringType =>
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
(c, evPrim, _) => code"$evPrim = $util.castStringToYMInterval($c);"
}

private[this] def decimalToTimestampCode(d: ExprValue): Block = {
val block = inline"new java.math.BigDecimal($MICROS_PER_SECOND)"
code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()"
Expand Down Expand Up @@ -1931,6 +1945,7 @@ object AnsiCast {

case (StringType, _: CalendarIntervalType) => true
case (StringType, DayTimeIntervalType) => true
case (StringType, YearMonthIntervalType) => true

case (StringType, DateType) => true
case (TimestampType, DateType) => true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ object IntervalUtils {
calendarToMicros(calendar)
}

def toYMInterval(yearStr: String, monthStr: String, sign: Int): Int = {
try {
val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR)
val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11))
Math.toIntExact(totalMonths)
} catch {
case NonFatal(e) =>
throw new IllegalArgumentException(
s"Error parsing interval year-month string: ${e.getMessage}", e)
}
}

/**
* Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn and [-]HH:mm:ss.nnnnnnnnn
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,48 @@ class CastSuite extends CastSuiteBase {
checkEvaluation(cast(cast(interval, StringType), DayTimeIntervalType), duration)
}
}

test("SPARK-35111: Cast string to year-month interval") {
checkEvaluation(cast(Literal.create("INTERVAL '1-0' YEAR TO MONTH"),
YearMonthIntervalType), 12)
checkEvaluation(cast(Literal.create("INTERVAL '-1-0' YEAR TO MONTH"),
YearMonthIntervalType), -12)
checkEvaluation(cast(Literal.create("INTERVAL -'-1-0' YEAR TO MONTH"),
YearMonthIntervalType), 12)
checkEvaluation(cast(Literal.create("INTERVAL +'-1-0' YEAR TO MONTH"),
YearMonthIntervalType), -12)
checkEvaluation(cast(Literal.create("INTERVAL +'+1-0' YEAR TO MONTH"),
YearMonthIntervalType), 12)
checkEvaluation(cast(Literal.create("INTERVAL +'1-0' YEAR TO MONTH"),
YearMonthIntervalType), 12)
checkEvaluation(cast(Literal.create(" interval +'1-0' YEAR TO MONTH "),
YearMonthIntervalType), 12)
checkEvaluation(cast(Literal.create(" -1-0 "), YearMonthIntervalType), -12)
checkEvaluation(cast(Literal.create("-1-0"), YearMonthIntervalType), -12)
checkEvaluation(cast(Literal.create(null, StringType), YearMonthIntervalType), null)

Seq("0-0", "10-1", "-178956970-7", "178956970-7", "-178956970-8").foreach { interval =>
val ansiInterval = s"INTERVAL '$interval' YEAR TO MONTH"
checkEvaluation(
cast(cast(Literal.create(interval), YearMonthIntervalType), StringType), ansiInterval)
checkEvaluation(cast(cast(Literal.create(ansiInterval),
YearMonthIntervalType), StringType), ansiInterval)
}

Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH")
.foreach { interval =>
val e = intercept[IllegalArgumentException] {
cast(Literal.create(interval), YearMonthIntervalType).eval()
}.getMessage
assert(e.contains("Error parsing interval year-month string: integer overflow"))
}

Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue)
.foreach { period =>
val interval = Literal.create(Period.ofMonths(period), YearMonthIntervalType)
checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType), period)
}
}
}

/**
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.