diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 30317c9e9138..0fd2201213af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -73,6 +73,7 @@ object Cast { case (TimestampType, DateType) => true case (StringType, CalendarIntervalType) => true + case (StringType, YearMonthIntervalType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true @@ -534,6 +535,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s)) } + private[this] def castToYearMonthInterval(from: DataType): Any => Any = from match { + case StringType => + buildCast[UTF8String](_, s => IntervalUtils.castStringToYMInterval(s)) + } + // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType if ansiEnabled => @@ -838,6 +844,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case decimal: DecimalType => castToDecimal(from, decimal) case TimestampType => castToTimestamp(from) case CalendarIntervalType => castToInterval(from) + case YearMonthIntervalType => castToYearMonthInterval(from) case BooleanType => castToBoolean(from) case ByteType => castToByte(from) case ShortType => castToShort(from) @@ -896,6 +903,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case decimal: DecimalType => castToDecimalCode(from, decimal, ctx) case TimestampType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) + case YearMonthIntervalType => castToYearMonthIntervalCode(from) case BooleanType => castToBooleanCode(from) case ByteType => castToByteCode(from, ctx) case ShortType => castToShortCode(from, ctx) @@ -1354,6 +1362,12 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } + private[this] def castToYearMonthIntervalCode(from: DataType): CastFunction = from match { + case StringType => + val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") + (c, evPrim, _) => code"$evPrim = $util.castStringToYMInterval($c);" + } + private[this] def decimalToTimestampCode(d: ExprValue): Block = { val block = inline"new java.math.BigDecimal($MICROS_PER_SECOND)" code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()" @@ -1915,6 +1929,7 @@ object AnsiCast { case (DateType, TimestampType) => true case (StringType, _: CalendarIntervalType) => true + case (StringType, YearMonthIntervalType) => true case (StringType, DateType) => true case (TimestampType, DateType) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 03b1835c7cd0..d7d18ff4fdd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -92,6 +92,25 @@ object IntervalUtils { } private val yearMonthPattern = "^([+|-])?(\\d+)-(\\d+)$".r + private val yearMonthStringPattern = + "(?i)^(INTERVAL\\s+)([+|-])?(')([+|-])?(\\d+)-(\\d+)(')(\\s+YEAR\\s+TO\\s+MONTH)$".r + + def castStringToYMInterval(input: UTF8String): Int = { + input.trimAll().toString match { + case yearMonthPattern("-", year, month) => toYMInterval(year, month, -1) + case yearMonthPattern(_, year, month) => toYMInterval(year, month, 1) + case yearMonthStringPattern(_, firstSign, _, secondSign, year, month, _, _) => + (firstSign, secondSign) match { + case ("-", "-") => toYMInterval(year, month, 1) + case ("-", _) => toYMInterval(year, month, -1) + case (_, "-") => toYMInterval(year, month, -1) + case (_, _) => toYMInterval(year, month, 1) + } + case _ => throw new IllegalArgumentException( + s"Interval string does not match year-month format of `[+|-]y-m` " + + s"or `INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH`: ${input.toString}") + } + } /** * Parse YearMonth string in form: [+|-]YYYY-MM @@ -100,28 +119,29 @@ object IntervalUtils { */ def fromYearMonthString(input: String): CalendarInterval = { require(input != null, "Interval year-month string must be not null") - def toInterval(yearStr: String, monthStr: String, sign: Int): CalendarInterval = { - try { - val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR) - val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11)) - new CalendarInterval(Math.toIntExact(totalMonths), 0, 0) - } catch { - case NonFatal(e) => - throw new IllegalArgumentException( - s"Error parsing interval year-month string: ${e.getMessage}", e) - } - } input.trim match { case yearMonthPattern("-", yearStr, monthStr) => - toInterval(yearStr, monthStr, -1) + new CalendarInterval(toYMInterval(yearStr, monthStr, -1), 0, 0) case yearMonthPattern(_, yearStr, monthStr) => - toInterval(yearStr, monthStr, 1) + new CalendarInterval(toYMInterval(yearStr, monthStr, 1), 0, 0) case _ => throw new IllegalArgumentException( s"Interval string does not match year-month format of 'y-m': $input") } } + def toYMInterval(yearStr: String, monthStr: String, sign: Int): Int = { + try { + val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR) + val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11)) + Math.toIntExact(totalMonths) + } catch { + case NonFatal(e) => + throw new IllegalArgumentException( + s"Error parsing interval year-month string: ${e.getMessage}", e) + } + } + /** * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn and [-]HH:mm:ss.nnnnnnnnn * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 0554d073d1ab..e6874c618018 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -1774,6 +1774,48 @@ class CastSuite extends CastSuiteBase { assert(e3.contains("Casting 2147483648 to int causes overflow")) } } + + test("SPARK-35111: Cast string to year-month interval") { + checkEvaluation(cast(Literal.create("INTERVAL '1-0' YEAR TO MONTH"), + YearMonthIntervalType), 12) + checkEvaluation(cast(Literal.create("INTERVAL '-1-0' YEAR TO MONTH"), + YearMonthIntervalType), -12) + checkEvaluation(cast(Literal.create("INTERVAL -'-1-0' YEAR TO MONTH"), + YearMonthIntervalType), 12) + checkEvaluation(cast(Literal.create("INTERVAL +'-1-0' YEAR TO MONTH"), + YearMonthIntervalType), -12) + checkEvaluation(cast(Literal.create("INTERVAL +'+1-0' YEAR TO MONTH"), + YearMonthIntervalType), 12) + checkEvaluation(cast(Literal.create("INTERVAL +'1-0' YEAR TO MONTH"), + YearMonthIntervalType), 12) + checkEvaluation(cast(Literal.create(" interval +'1-0' YEAR TO MONTH "), + YearMonthIntervalType), 12) + checkEvaluation(cast(Literal.create(" -1-0 "), YearMonthIntervalType), -12) + checkEvaluation(cast(Literal.create("-1-0"), YearMonthIntervalType), -12) + checkEvaluation(cast(Literal.create(null, StringType), YearMonthIntervalType), null) + + Seq("0-0", "10-1", "-178956970-7", "178956970-7", "-178956970-8").foreach { interval => + val ansiInterval = s"INTERVAL '$interval' YEAR TO MONTH" + checkEvaluation( + cast(cast(Literal.create(interval), YearMonthIntervalType), StringType), ansiInterval) + checkEvaluation(cast(cast(Literal.create(ansiInterval), + YearMonthIntervalType), StringType), ansiInterval) + } + + Seq("INTERVAL '-178956970-9' YEAR TO MONTH", "INTERVAL '178956970-8' YEAR TO MONTH") + .foreach { interval => + val e = intercept[IllegalArgumentException] { + cast(Literal.create(interval), YearMonthIntervalType).eval() + }.getMessage + assert(e.contains("Error parsing interval year-month string: integer overflow")) + } + + Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Int.MinValue + 1, Int.MinValue) + .foreach { period => + val interval = Literal.create(Period.ofMonths(period), YearMonthIntervalType) + checkEvaluation(cast(cast(interval, StringType), YearMonthIntervalType), period) + } + } } /**