Skip to content

Commit 700efa4

Browse files
committed
update
1 parent c2c7021 commit 700efa4

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
536536

537537
private[this] def castToDayTimeInterval(from: DataType): Any => Any = from match {
538538
case StringType =>
539-
buildCast[UTF8String](_, s => IntervalUtils.castStringToDTInterval(s).microseconds)
539+
buildCast[UTF8String](_, s => IntervalUtils.castStringToDTInterval(s))
540540
}
541541

542542
// LongConverter
@@ -1365,7 +1365,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
13651365
case StringType =>
13661366
val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
13671367
(c, evPrim, evNull) =>
1368-
code"$evPrim = $util.castStringToDTInterval($c).microseconds;"
1368+
code"$evPrim = $util.castStringToDTInterval($c);"
13691369
}
13701370

13711371
private[this] def decimalToTimestampCode(d: ExprValue): Block = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,14 @@ object IntervalUtils {
127127
"([+|-])?(\\d+) (\\d{1,2}):(\\d{1,2}):(\\d{1,2})(\\.\\d{1,9})?(')(\\s+DAY\\s+TO\\s+SECOND)$").r
128128
private val daySecondPattern = "^([+|-])?(\\d+) (\\d{1,2}):(\\d{1,2}):(\\d{1,2})(\\.\\d{1,9})?$".r
129129

130-
def castStringToDTInterval(input: UTF8String): CalendarInterval = {
130+
def castStringToDTInterval(input: UTF8String): Long = {
131+
def calendarToMicros(calendar: CalendarInterval): Long = {
132+
getDuration(calendar, TimeUnit.MICROSECONDS)
133+
}
131134
val intervalStr = input.trimAll().toString
132135
val ansiDaySecondPattern =
133136
"([+|-])?(\\d+) (\\d{1,2}):(\\d{1,2}):(\\d{1,2})(\\.\\d{1,9})?".r
134-
intervalStr match {
137+
val calendar = intervalStr match {
135138
case daySecondPattern(_, _, _, _, _, _) => fromDayTimeString(intervalStr, DAY, SECOND)
136139
case daySecondStringPattern(_, prefixSign, _, suffixSign, _, _, _, _, _, _, _) =>
137140
val dtStr =
@@ -148,6 +151,7 @@ object IntervalUtils {
148151
s"or `INTERVAL [+|-]'[+|-]d h:m:s.n' DAY TO SECOND`: ${input.toString}, " +
149152
s"$fallbackNotice")
150153
}
154+
calendarToMicros(calendar)
151155
}
152156

153157
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
3232
import org.apache.spark.sql.catalyst.analysis.TypeCoercionSuite
3333
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}
3434
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
35+
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, IntervalUtils}
3536
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
36-
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
3737
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
3838
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
3939
import org.apache.spark.sql.catalyst.util.IntervalUtils.microsToDuration
@@ -1776,28 +1776,45 @@ class CastSuite extends CastSuiteBase {
17761776
}
17771777

17781778
test("SPARK-35112: Cast string to day-time interval") {
1779+
val interval = IntervalUtils.fromDayTimeString("106751991 04:00:54.775807")
17791780
checkEvaluation(cast(Literal.create("0 0:0:0"), DayTimeIntervalType), 0L)
17801781
checkEvaluation(cast(Literal.create("INTERVAL '0 0:0:0' DAY TO SECOND"),
17811782
DayTimeIntervalType), 0L)
17821783
checkEvaluation(cast(Literal.create("INTERVAL '1 2:03:04' DAY TO SECOND"),
1783-
DayTimeIntervalType), 7384000000L)
1784+
DayTimeIntervalType), 93784000000L)
17841785
checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00' DAY TO SECOND"),
1785-
DayTimeIntervalType), 11040000000L)
1786+
DayTimeIntervalType), 97440000000L)
17861787
checkEvaluation(cast(Literal.create("INTERVAL '1 03:04:00.0000' DAY TO SECOND"),
1787-
DayTimeIntervalType), 11040000000L)
1788-
checkEvaluation(cast(Literal.create("1 2:03:04"), DayTimeIntervalType), 7384000000L)
1788+
DayTimeIntervalType), 97440000000L)
1789+
checkEvaluation(cast(Literal.create("1 2:03:04"), DayTimeIntervalType), 93784000000L)
17891790
checkEvaluation(cast(Literal.create("INTERVAL '-10 2:03:04' DAY TO SECOND"),
1790-
DayTimeIntervalType), -7384000000L)
1791-
checkEvaluation(cast(Literal.create("-10 2:03:04"), DayTimeIntervalType), -7384000000L)
1791+
DayTimeIntervalType), -871384000000L)
1792+
checkEvaluation(cast(Literal.create("-10 2:03:04"), DayTimeIntervalType), -871384000000L)
17921793
checkEvaluation(cast(Literal.create("-106751991 04:00:54.775808"), DayTimeIntervalType),
1793-
-14454775808L)
1794+
Long.MinValue)
17941795
checkEvaluation(cast(Literal.create("106751991 04:00:54.775807"), DayTimeIntervalType),
1795-
14454775807L)
1796+
Long.MaxValue)
1797+
1798+
Seq("-106751991 04:00:54.775808", "106751991 04:00:54.775807").foreach { interval =>
1799+
val ansiInterval = s"INTERVAL '$interval' DAY TO SECOND"
1800+
checkEvaluation(
1801+
cast(cast(Literal.create(interval), DayTimeIntervalType), StringType), ansiInterval)
1802+
checkEvaluation(cast(cast(Literal.create(ansiInterval),
1803+
DayTimeIntervalType), StringType), ansiInterval)
1804+
}
1805+
1806+
Seq("INTERVAL '-106751991 04:00:54.775809' YEAR TO MONTH",
1807+
"INTERVAL '106751991 04:00:54.775808' YEAR TO MONTH").foreach { interval =>
1808+
val e = intercept[IllegalArgumentException] {
1809+
cast(Literal.create(interval), DayTimeIntervalType).eval()
1810+
}.getMessage
1811+
assert(e.contains("Interval string must match day-time format of"))
1812+
}
17961813

17971814
Seq(Byte.MaxValue, Short.MaxValue, Int.MaxValue, Long.MaxValue, Long.MinValue + 1,
1798-
Long.MinValue).foreach { period =>
1799-
val interval = Literal.create(Duration.of(period, ChronoUnit.MICROS), DayTimeIntervalType)
1800-
checkEvaluation(cast(cast(interval, StringType), DayTimeIntervalType), period)
1815+
Long.MinValue).foreach { duration =>
1816+
val interval = Literal.create(Duration.of(duration, ChronoUnit.MICROS), DayTimeIntervalType)
1817+
checkEvaluation(cast(cast(interval, StringType), DayTimeIntervalType), duration)
18011818
}
18021819
}
18031820
}

0 commit comments

Comments
 (0)