
Commit 17601e0

MaxGekk authored and cloud-fan committed
[SPARK-34605][SQL] Support java.time.Duration as an external type of the day-time interval type
### What changes were proposed in this pull request?

In this PR, I propose to extend the Spark SQL API to accept [`java.time.Duration`](https://docs.oracle.com/javase/8/docs/api/java/time/Duration.html) as an external type of the recently added Catalyst type `DayTimeIntervalType` (see #31614). The Java class `java.time.Duration` has semantics similar to the ANSI SQL day-time interval type, which makes it the most suitable external type for `DayTimeIntervalType`. In more detail:

1. Added `DurationConverter`, which converts `java.time.Duration` instances to/from the internal representation of the Catalyst type `DayTimeIntervalType` (a `Long`). The `DurationConverter` object uses new methods of `IntervalUtils`:
   - `durationToMicros()` converts the input duration to its total length in microseconds. If the duration is too large to fit into a `Long`, the method throws an `ArithmeticException`. **Note:** _the input duration has nanosecond precision; the method casts the nanos part to microseconds by dividing by 1000._
   - `microsToDuration()` obtains a `java.time.Duration` representing a given number of microseconds.
2. Supported the new type `DayTimeIntervalType` in `RowEncoder` via the methods `createDeserializerForDuration()` and `createSerializerForJavaDuration()`.
3. Extended the Literal API to construct literals from `java.time.Duration` instances.

### Why are the changes needed?

1. To allow users to parallelize collections of `java.time.Duration`, construct day-time interval columns, and collect such columns back to the driver side.
2. It will allow writing tests in other sub-tasks of SPARK-27790.

### Does this PR introduce _any_ user-facing change?

The PR extends existing functionality, so users can parallelize instances of the `java.time.Duration` class and collect them back:

```Scala
scala> val ds = Seq(java.time.Duration.ofDays(10)).toDS
ds: org.apache.spark.sql.Dataset[java.time.Duration] = [value: daytimeinterval]

scala> ds.collect
res0: Array[java.time.Duration] = Array(PT240H)
```

### How was this patch tested?

- Added a few tests to `CatalystTypeConvertersSuite` to check conversion from/to `java.time.Duration`.
- Checked row encoding with new tests in `RowEncoderSuite`.
- Literals of `DayTimeIntervalType` are tested in `LiteralExpressionSuite`.
- Collecting is checked by `DatasetSuite` and `JavaDatasetSuite`.

Closes #31729 from MaxGekk/java-time-duration.

Authored-by: Max Gekk <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
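The microsecond conversion contract described in item 1 can be sketched in plain Scala. This is only an illustration of the documented behavior (overflow detected via the exact-arithmetic `java.lang.Math` helpers, nanos truncated by integer division), not the actual `IntervalUtils` code:

```scala
import java.time.Duration
import java.time.temporal.ChronoUnit

// Sketch: total length of a Duration in microseconds.
// Throws ArithmeticException if the result overflows Long; the nanosecond
// part is truncated to microseconds by dividing by 1000.
def durationToMicros(duration: Duration): Long = {
  val micros = Math.multiplyExact(duration.getSeconds, 1000000L)
  Math.addExact(micros, duration.getNano / 1000L)
}

// Sketch: the inverse mapping from microseconds back to a Duration.
def microsToDuration(micros: Long): Duration =
  Duration.of(micros, ChronoUnit.MICROS)
```

For example, `durationToMicros(Duration.ofDays(10))` yields `864000000000L`, and `microsToDuration` maps that back to `PT240H`, matching the `collect` output shown above.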
1 parent: e7e0161 · commit: 17601e0

File tree

23 files changed: +229 -20 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java

Lines changed: 3 additions & 0 deletions
@@ -83,6 +83,9 @@ public static Object read(
     if (handleUserDefinedType && dataType instanceof UserDefinedType) {
       return obj.get(ordinal, ((UserDefinedType)dataType).sqlType());
     }
+    if (dataType instanceof DayTimeIntervalType) {
+      return obj.getLong(ordinal);
+    }
 
     throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala

Lines changed: 8 additions & 0 deletions
@@ -135,6 +135,14 @@ object Encoders {
    */
   def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
 
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Duration` class
+   * to the internal representation of nullable Catalyst's DayTimeIntervalType.
+   *
+   * @since 3.2.0
+   */
+  def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
+
   /**
    * Creates an encoder for Java Bean of type T.
    *
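As a usage sketch (assuming a Spark build that includes this commit; the session setup here is illustrative), the new encoder can be passed explicitly to `createDataset`:

```scala
import org.apache.spark.sql.{Encoders, SparkSession}

// Illustrative setup; any live SparkSession works.
val spark = SparkSession.builder().master("local[*]").appName("duration-encoder").getOrCreate()

// Create a typed Dataset of java.time.Duration with the new encoder.
val ds = spark.createDataset(Seq(java.time.Duration.ofDays(10)))(Encoders.DURATION)
ds.printSchema() // value: daytimeinterval
```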

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 15 additions & 1 deletion
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
 
@@ -74,6 +74,7 @@ object CatalystTypeConverters {
       case LongType => LongConverter
       case FloatType => FloatConverter
       case DoubleType => DoubleConverter
+      case DayTimeIntervalType => DurationConverter
       case dataType: DataType => IdentityConverter(dataType)
     }
     converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
@@ -400,6 +401,18 @@ object CatalystTypeConverters {
     override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column)
   }
 
+  private object DurationConverter extends CatalystTypeConverter[Duration, Duration, Any] {
+    override def toCatalystImpl(scalaValue: Duration): Long = {
+      IntervalUtils.durationToMicros(scalaValue)
+    }
+    override def toScala(catalystValue: Any): Duration = {
+      if (catalystValue == null) null
+      else IntervalUtils.microsToDuration(catalystValue.asInstanceOf[Long])
+    }
+    override def toScalaImpl(row: InternalRow, column: Int): Duration =
+      IntervalUtils.microsToDuration(row.getLong(column))
+  }
+
   /**
    * Creates a converter function that will convert Scala objects to the specified Catalyst type.
    * Typical use case would be converting a collection of rows that have the same schema. You will
@@ -465,6 +478,7 @@ object CatalystTypeConverters {
         map,
         (key: Any) => convertToCatalyst(key),
         (value: Any) => convertToCatalyst(value))
+    case d: Duration => DurationConverter.toCatalyst(d)
     case other => other
   }
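Given the hunks above, `DurationConverter` simply delegates to the new `IntervalUtils` methods, so the external/internal round trip can be checked directly:

```scala
import java.time.Duration
import org.apache.spark.sql.catalyst.util.IntervalUtils

// Duration -> Long microseconds (the Catalyst internal value) and back.
assert(IntervalUtils.durationToMicros(Duration.ofSeconds(1)) == 1000000L)
assert(IntervalUtils.microsToDuration(1000000L) == Duration.ofSeconds(1))
```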

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala

Lines changed: 10 additions & 1 deletion
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.types._
 
 object DeserializerBuildHelper {
@@ -143,6 +143,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }
 
+  def createDeserializerForDuration(path: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      ObjectType(classOf[java.time.Duration]),
+      "microsToDuration",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   /**
    * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
    * and lost the required data type, which may lead to runtime error if the real type doesn't
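For illustration, the helper wraps a `Long`-valued path expression in a `StaticInvoke` of `IntervalUtils.microsToDuration`. A sketch using a hypothetical bound reference as the path (internal expression API):

```scala
import org.apache.spark.sql.catalyst.DeserializerBuildHelper
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.DayTimeIntervalType

// Hypothetical path: the Long-backed day-time interval value at ordinal 0.
val path = BoundReference(0, DayTimeIntervalType, nullable = false)

// Builds a StaticInvoke whose result type is ObjectType(classOf[java.time.Duration]).
val deserializer = DeserializerBuildHelper.createDeserializerForDuration(path)
```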

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala

Lines changed: 4 additions & 2 deletions
@@ -133,7 +133,8 @@ object InternalRow {
     case ByteType => (input, ordinal) => input.getByte(ordinal)
     case ShortType => (input, ordinal) => input.getShort(ordinal)
     case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
-    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
+    case LongType | TimestampType | DayTimeIntervalType =>
+      (input, ordinal) => input.getLong(ordinal)
     case FloatType => (input, ordinal) => input.getFloat(ordinal)
     case DoubleType => (input, ordinal) => input.getDouble(ordinal)
     case StringType => (input, ordinal) => input.getUTF8String(ordinal)
@@ -168,7 +169,8 @@ object InternalRow {
     case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
     case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
     case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
-    case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
+    case LongType | TimestampType | DayTimeIntervalType =>
+      (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
     case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
     case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
     case CalendarIntervalType =>
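Since day-time intervals are stored as `Long` microseconds, the generic accessor/writer pair reuses the `Long` path. A sketch against the internal row API (the `GenericInternalRow` here is illustrative):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.DayTimeIntervalType

val row = new GenericInternalRow(1)
val write = InternalRow.getWriter(0, DayTimeIntervalType)
val read = InternalRow.getAccessor(DayTimeIntervalType)

write(row, 86400000000L) // one day, in microseconds
assert(read(row, 0) == 86400000000L)
```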

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 6 additions & 0 deletions
@@ -118,6 +118,7 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
       case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+      case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)
 
       case _ if typeToken.isArray =>
         val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
@@ -249,6 +250,9 @@ object JavaTypeInference {
       case c if c == classOf[java.sql.Timestamp] =>
         createDeserializerForSqlTimestamp(path)
 
+      case c if c == classOf[java.time.Duration] =>
+        createDeserializerForDuration(path)
+
       case c if c == classOf[java.lang.String] =>
         createDeserializerForString(path, returnNullable = true)
 
@@ -406,6 +410,8 @@ object JavaTypeInference {
 
       case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)
 
+      case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)
+
       case c if c == classOf[java.math.BigDecimal] =>
         createSerializerForJavaBigDecimal(inputObject)
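A sketch of the inference result for the Java-side mapping (internal API, shown for illustration):

```scala
import org.apache.spark.sql.catalyst.JavaTypeInference
import org.apache.spark.sql.types.DayTimeIntervalType

// java.time.Duration now infers to a nullable day-time interval.
val (dataType, nullable) = JavaTypeInference.inferDataType(classOf[java.time.Duration])
assert(dataType == DayTimeIntervalType && nullable)
```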

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 12 additions & 2 deletions
@@ -240,6 +240,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
         createDeserializerForSqlTimestamp(path)
 
+      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+        createDeserializerForDuration(path)
+
       case t if isSubtype(t, localTypeOf[java.lang.String]) =>
         createDeserializerForString(path, returnNullable = false)
 
@@ -522,6 +525,9 @@ object ScalaReflection extends ScalaReflection {
 
       case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject)
 
+      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+        createSerializerForJavaDuration(inputObject)
+
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         createSerializerForScalaBigDecimal(inputObject)
 
@@ -740,6 +746,8 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
       case t if isSubtype(t, localTypeOf[CalendarInterval]) =>
         Schema(CalendarIntervalType, nullable = true)
+      case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
+        Schema(DayTimeIntervalType, nullable = true)
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
       case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
@@ -837,7 +845,8 @@ object ScalaReflection extends ScalaReflection {
     DateType -> classOf[DateType.InternalType],
     TimestampType -> classOf[TimestampType.InternalType],
     BinaryType -> classOf[BinaryType.InternalType],
-    CalendarIntervalType -> classOf[CalendarInterval]
+    CalendarIntervalType -> classOf[CalendarInterval],
+    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
   )
 
   val typeBoxedJavaMapping = Map[DataType, Class[_]](
@@ -849,7 +858,8 @@ object ScalaReflection extends ScalaReflection {
     FloatType -> classOf[java.lang.Float],
     DoubleType -> classOf[java.lang.Double],
     DateType -> classOf[java.lang.Integer],
-    TimestampType -> classOf[java.lang.Long]
+    TimestampType -> classOf[java.lang.Long],
+    DayTimeIntervalType -> classOf[java.lang.Long]
   )
 
   def dataTypeJavaClass(dt: DataType): Class[_] = {
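The new `Schema` case can be observed through `ScalaReflection.schemaFor` (a sketch of the expected result):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.DayTimeIntervalType

// Scala reflection derives the day-time interval schema for java.time.Duration.
val schema = ScalaReflection.schemaFor[java.time.Duration]
assert(schema.dataType == DayTimeIntervalType && schema.nullable)
```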

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala

Lines changed: 10 additions & 1 deletion
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -104,6 +104,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
+  def createSerializerForJavaDuration(inputObject: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      DayTimeIntervalType,
+      "durationToMicros",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
     CheckOverflow(StaticInvoke(
       Decimal.getClass,
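Mirroring the deserializer helper, this builds a `StaticInvoke` of `IntervalUtils.durationToMicros`. A sketch with a hypothetical bound reference as input:

```scala
import org.apache.spark.sql.catalyst.SerializerBuildHelper
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.types.ObjectType

// Hypothetical input: a java.time.Duration object at ordinal 0.
val input = BoundReference(0, ObjectType(classOf[java.time.Duration]), nullable = false)

// Builds a StaticInvoke whose result type is DayTimeIntervalType (Long-backed).
val serializer = SerializerBuildHelper.createSerializerForJavaDuration(input)
```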

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 5 additions & 0 deletions
@@ -297,6 +297,11 @@ package object dsl {
     /** Creates a new AttributeReference of type timestamp */
     def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()
 
+    /** Creates a new AttributeReference of the day-time interval type */
+    def dayTimeInterval: AttributeReference = {
+      AttributeReference(s, DayTimeIntervalType, nullable = true)()
+    }
+
     /** Creates a new AttributeReference of type binary */
     def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()
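A sketch of the new DSL attribute (used by catalyst tests):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.DayTimeIntervalType

// Symbol("i").dayTimeInterval creates a nullable day-time interval attribute.
val attr = Symbol("i").dayTimeInterval
assert(attr.dataType == DayTimeIntervalType && attr.nullable)
```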

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 7 additions & 0 deletions
@@ -53,6 +53,8 @@ import org.apache.spark.sql.types._
  *   TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
  *   TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
  *
+ *   DayTimeIntervalType -> java.time.Duration
+ *
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
  *   MapType -> scala.collection.Map
@@ -108,6 +110,8 @@ object RowEncoder {
         createSerializerForSqlDate(inputObject)
       }
 
+    case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
+
     case d: DecimalType =>
       CheckOverflow(StaticInvoke(
         Decimal.getClass,
@@ -226,6 +230,7 @@ object RowEncoder {
       } else {
         ObjectType(classOf[java.sql.Date])
       }
+    case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
     case StringType => ObjectType(classOf[java.lang.String])
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -281,6 +286,8 @@ object RowEncoder {
         createDeserializerForSqlDate(input)
       }
 
+    case DayTimeIntervalType => createDeserializerForDuration(input)
+
     case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)
 
     case StringType => createDeserializerForString(input, returnNullable = false)
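With `RowEncoder` mapping `DayTimeIntervalType` to `java.time.Duration`, Row-based DataFrames can carry day-time interval columns. A usage sketch (assumes a live `SparkSession` and the case-object form of `DayTimeIntervalType` as of this commit):

```scala
import java.time.Duration
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DayTimeIntervalType, StructType}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val schema = new StructType().add("d", DayTimeIntervalType)

// Rows carry java.time.Duration externally; internally the column is Long micros.
val df = spark.createDataFrame(java.util.Arrays.asList(Row(Duration.ofDays(1))), schema)
df.head().getAs[Duration]("d") // PT24H
```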
