diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java index ea0648a6cb90..d1bb719aca8f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java @@ -83,6 +83,9 @@ public static Object read( if (handleUserDefinedType && dataType instanceof UserDefinedType) { return obj.get(ordinal, ((UserDefinedType)dataType).sqlType()); } + if (dataType instanceof DayTimeIntervalType) { + return obj.getLong(ordinal); + } throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 24045b5a43a6..5e72b19ca5da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -135,6 +135,14 @@ object Encoders { */ def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() + /** + * Creates an encoder that serializes instances of the `java.time.Duration` class + * to the internal representation of nullable Catalyst's DayTimeIntervalType. + * + * @since 3.2.0 + */ + def DURATION: Encoder[java.time.Duration] = ExpressionEncoder() + /** * Creates an encoder for Java Bean of type T. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 907b5877b3ac..8201fd7d8fb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} import java.math.{BigInteger => JavaBigInteger} import java.sql.{Date, Timestamp} -import java.time.{Instant, LocalDate} +import java.time.{Duration, Instant, LocalDate} import java.util.{Map => JavaMap} import javax.annotation.Nullable @@ -74,6 +74,7 @@ object CatalystTypeConverters { case LongType => LongConverter case FloatType => FloatConverter case DoubleType => DoubleConverter + case DayTimeIntervalType => DurationConverter case dataType: DataType => IdentityConverter(dataType) } converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] @@ -400,6 +401,18 @@ object CatalystTypeConverters { override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column) } + private object DurationConverter extends CatalystTypeConverter[Duration, Duration, Any] { + override def toCatalystImpl(scalaValue: Duration): Long = { + IntervalUtils.durationToMicros(scalaValue) + } + override def toScala(catalystValue: Any): Duration = { + if (catalystValue == null) null + else IntervalUtils.microsToDuration(catalystValue.asInstanceOf[Long]) + } + override def toScalaImpl(row: InternalRow, column: Int): Duration = + IntervalUtils.microsToDuration(row.getLong(column)) + } + /** * Creates a converter function that will convert Scala objects to the specified Catalyst type. * Typical use case would be converting a collection of rows that have the same schema. 
You will @@ -465,6 +478,7 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) + case d: Duration => DurationConverter.toCatalyst(d) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index 701e4e3483c0..03243b4610f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.types._ object DeserializerBuildHelper { @@ -143,6 +143,15 @@ object DeserializerBuildHelper { returnNullable = false) } + def createDeserializerForDuration(path: Expression): Expression = { + StaticInvoke( + IntervalUtils.getClass, + ObjectType(classOf[java.time.Duration]), + "microsToDuration", + path :: Nil, + returnNullable = false) + } + /** * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff * and lost the required data type, which may lead to runtime error if the real type doesn't diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index f98b59edd422..00b2d16a2ad6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -133,7 +133,8 @@ object InternalRow { case ByteType => (input, ordinal) => input.getByte(ordinal) case ShortType => (input, ordinal) => input.getShort(ordinal) case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) - case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case LongType | TimestampType | DayTimeIntervalType => + (input, ordinal) => input.getLong(ordinal) case FloatType => (input, ordinal) => input.getFloat(ordinal) case DoubleType => (input, ordinal) => input.getDouble(ordinal) case StringType => (input, ordinal) => input.getUTF8String(ordinal) @@ -168,7 +169,8 @@ object InternalRow { case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte]) case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short]) case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int]) - case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long]) + case LongType | TimestampType | DayTimeIntervalType => + (input, v) => input.setLong(ordinal, v.asInstanceOf[Long]) case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float]) case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double]) case CalendarIntervalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 2248e2eb0259..7f055a1e77bd 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -118,6 +118,7 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) + case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true) case _ if typeToken.isArray => val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet) @@ -249,6 +250,9 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => createDeserializerForSqlTimestamp(path) + case c if c == classOf[java.time.Duration] => + createDeserializerForDuration(path) + case c if c == classOf[java.lang.String] => createDeserializerForString(path, returnNullable = true) @@ -406,6 +410,8 @@ object JavaTypeInference { case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject) + case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject) + case c if c == classOf[java.math.BigDecimal] => createSerializerForJavaBigDecimal(inputObject) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 361c3476f594..bdb2a8ebff82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -240,6 +240,9 @@ object ScalaReflection extends ScalaReflection { case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => createDeserializerForSqlTimestamp(path) + case t if isSubtype(t, localTypeOf[java.time.Duration]) => + createDeserializerForDuration(path) + case t if isSubtype(t, localTypeOf[java.lang.String]) => createDeserializerForString(path, returnNullable = false) @@ -522,6 +525,9 @@ object ScalaReflection extends ScalaReflection { case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject) + case t if isSubtype(t, localTypeOf[java.time.Duration]) => + createSerializerForJavaDuration(inputObject) + case t if isSubtype(t, localTypeOf[BigDecimal]) => createSerializerForScalaBigDecimal(inputObject) @@ -740,6 +746,8 @@ object ScalaReflection extends ScalaReflection { case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true) case t if isSubtype(t, localTypeOf[CalendarInterval]) => Schema(CalendarIntervalType, nullable = true) + case t if isSubtype(t, localTypeOf[java.time.Duration]) => + Schema(DayTimeIntervalType, nullable = true) case t if isSubtype(t, localTypeOf[BigDecimal]) => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => @@ -837,7 +845,8 @@ object ScalaReflection extends ScalaReflection { DateType -> classOf[DateType.InternalType], TimestampType -> classOf[TimestampType.InternalType], BinaryType -> classOf[BinaryType.InternalType], - CalendarIntervalType -> classOf[CalendarInterval] + CalendarIntervalType -> classOf[CalendarInterval], + DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType] ) val typeBoxedJavaMapping = Map[DataType, Class[_]]( @@ -849,7 +858,8 @@ object ScalaReflection extends ScalaReflection { FloatType -> classOf[java.lang.Float], DoubleType -> classOf[java.lang.Double], DateType 
-> classOf[java.lang.Integer], - TimestampType -> classOf[java.lang.Long] + TimestampType -> classOf[java.lang.Long], + DayTimeIntervalType -> classOf[java.lang.Long] ) def dataTypeJavaClass(dt: DataType): Class[_] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala index 0554f0f76708..fcecfbe925c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData} import org.apache.spark.sql.catalyst.expressions.objects._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -104,6 +104,15 @@ object SerializerBuildHelper { returnNullable = false) } + def createSerializerForJavaDuration(inputObject: Expression): Expression = { + StaticInvoke( + IntervalUtils.getClass, + DayTimeIntervalType, + "durationToMicros", + inputObject :: Nil, + returnNullable = false) + } + def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = { CheckOverflow(StaticInvoke( Decimal.getClass, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index e8ec72b7c802..5d559731055f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -297,6 +297,11 @@ package object dsl { /** Creates a new AttributeReference of type timestamp */ def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)() + /** Creates a new AttributeReference of the day-time interval type */ + def dayTimeInterval: AttributeReference = { + AttributeReference(s, DayTimeIntervalType, nullable = true)() + } + /** Creates a new AttributeReference of type binary */ def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index ee6320955c8f..ebda55b60e66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -53,6 +53,8 @@ import org.apache.spark.sql.types._ * TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false * TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true * + * DayTimeIntervalType -> java.time.Duration + * * BinaryType -> byte array * ArrayType -> scala.collection.Seq or Array * MapType -> scala.collection.Map @@ -108,6 +110,8 @@ object RowEncoder { createSerializerForSqlDate(inputObject) } + case DayTimeIntervalType => createSerializerForJavaDuration(inputObject) + case d: DecimalType => CheckOverflow(StaticInvoke( Decimal.getClass, @@ -226,6 +230,7 @@ object RowEncoder { } else { 
ObjectType(classOf[java.sql.Date]) } + case DayTimeIntervalType => ObjectType(classOf[java.time.Duration]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) case StringType => ObjectType(classOf[java.lang.String]) case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) @@ -281,6 +286,8 @@ object RowEncoder { createDeserializerForSqlDate(input) } + case DayTimeIntervalType => createDeserializerForDuration(input) + case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false) case StringType => createDeserializerForString(input, returnNullable = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index cf1c18b1d268..00ac3d64aaeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -160,7 +160,7 @@ object InterpretedUnsafeProjection { case IntegerType | DateType => (v, i) => writer.write(i, v.getInt(i)) - case LongType | TimestampType => + case LongType | TimestampType | DayTimeIntervalType => (v, i) => writer.write(i, v.getLong(i)) case FloatType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala index ad49b38544a1..fd22978544a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala @@ -195,8 +195,8 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match { // We use INT for DATE internally case IntegerType | DateType => new MutableInt - // We use Long for Timestamp internally - case LongType | TimestampType => new MutableLong + // We use Long for Timestamp and DayTimeInterval internally + case LongType | TimestampType | DayTimeIntervalType => new MutableLong case FloatType => new MutableFloat case DoubleType => new MutableDouble case BooleanType => new MutableBoolean diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 84c66b2d7696..67c4adfb3887 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1813,7 +1813,7 @@ object CodeGenerator extends Logging { case ByteType => JAVA_BYTE case ShortType => JAVA_SHORT case IntegerType | DateType => JAVA_INT - case LongType | TimestampType => JAVA_LONG + case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE case _: DecimalType => "Decimal" @@ -1834,7 +1834,7 @@ object CodeGenerator extends Logging { case ByteType => java.lang.Byte.TYPE case ShortType => java.lang.Short.TYPE case IntegerType | DateType => java.lang.Integer.TYPE - case LongType | TimestampType => java.lang.Long.TYPE + case LongType | TimestampType | 
DayTimeIntervalType => java.lang.Long.TYPE case FloatType => java.lang.Float.TYPE case DoubleType => java.lang.Double.TYPE case _: DecimalType => classOf[Decimal] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 53c65a4f9fe5..203e98cfd737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort} import java.math.{BigDecimal => JavaBigDecimal} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.time.{Instant, LocalDate} +import java.time.{Duration, Instant, LocalDate} import java.util import java.util.Objects import javax.xml.bind.DatatypeConverter @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros +import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -76,6 +77,7 @@ object Literal { case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) + case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType) case a: Array[Byte] => Literal(a, BinaryType) case a: collection.mutable.WrappedArray[_] => apply(a.array) case a: Array[_] => @@ -111,6 +113,7 @@ object Literal { case _ if clz == classOf[Date] => DateType case _ if clz == classOf[Instant] => TimestampType case _ if clz == classOf[Timestamp] => TimestampType + case _ if clz == classOf[Duration] => DayTimeIntervalType case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[Array[Byte]] => BinaryType case _ if clz == classOf[Array[Char]] => StringType @@ -167,6 +170,7 @@ object Literal { case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale)) case DateType => create(0, DateType) case TimestampType => create(0L, TimestampType) + case DayTimeIntervalType => create(0L, DayTimeIntervalType) case StringType => Literal("") case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8)) case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0)) @@ -186,7 +190,7 @@ object Literal { case ByteType => v.isInstanceOf[Byte] case ShortType => v.isInstanceOf[Short] case IntegerType | DateType => v.isInstanceOf[Int] - case LongType | TimestampType => v.isInstanceOf[Long] + case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long] case FloatType => v.isInstanceOf[Float] case DoubleType => v.isInstanceOf[Double] case _: DecimalType => v.isInstanceOf[Decimal] @@ -388,7 +392,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { } case ByteType | ShortType => ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) - case TimestampType | LongType => + case TimestampType | LongType | DayTimeIntervalType => toExprCode(s"${value}L") case _ => val constRef = ctx.addReferenceObj("literal", value, javaType) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index f716ca17778b..6be4e9f04306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import java.time.Duration +import java.time.temporal.ChronoUnit import java.util.concurrent.TimeUnit import scala.util.control.NonFatal @@ -762,4 +764,31 @@ object IntervalUtils { new CalendarInterval(totalMonths, totalDays, micros) } + + /** + * Converts this duration to the total length in microseconds. + *

+   * If this duration is too large to fit in a [[Long]] number of microseconds, then an
+   * exception is thrown.
+   * <p>
+ * If this duration has greater than microsecond precision, then the conversion + * will drop any excess precision information as though the amount in nanoseconds + * was subject to integer division by one thousand. + * + * @return The total length of the duration in microseconds + * @throws ArithmeticException If numeric overflow occurs + */ + def durationToMicros(duration: Duration): Long = { + val us = Math.multiplyExact(duration.getSeconds, MICROS_PER_SECOND) + val result = Math.addExact(us, duration.getNano / NANOS_PER_MICROS) + result + } + + /** + * Obtains a [[Duration]] representing a number of microseconds. + * + * @param micros The number of microseconds, positive or negative + * @return A [[Duration]], not null + */ + def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 8a099e8ede6c..05273270de30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -170,7 +170,8 @@ object DataType { private val otherTypes = { Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, - DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType) + DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType, + DayTimeIntervalType) .map(t => t.typeName -> t).toMap } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index f4b08330e4c7..6b66af842fcc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.catalyst -import java.time.{Instant, LocalDate} +import java.time.{Duration, Instant, LocalDate} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -223,4 +223,39 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper { } } } + + test("SPARK-34605: converting java.time.Duration to DayTimeIntervalType") { + Seq( + Duration.ZERO, + Duration.ofNanos(1), + Duration.ofNanos(-1), + Duration.ofSeconds(0, Long.MaxValue), + Duration.ofSeconds(0, Long.MinValue), + Duration.ofDays(106751991), + Duration.ofDays(-106751991)).foreach { input => + val result = CatalystTypeConverters.convertToCatalyst(input) + val expected = IntervalUtils.durationToMicros(input) + assert(result === expected) + } + + val errMsg = intercept[ArithmeticException] { + IntervalUtils.durationToMicros(Duration.ofSeconds(Long.MaxValue, Long.MaxValue)) + }.getMessage + assert(errMsg.contains("long overflow")) + } + + test("SPARK-34605: converting DayTimeIntervalType to java.time.Duration") { + Seq( + 0L, + 1L, + 999999, + -1000000, + Long.MaxValue).foreach { input => + Seq(1L, -1L).foreach { 
sign => + val us = sign * input + val duration = IntervalUtils.microsToDuration(us) + assert(CatalystTypeConverters.createToScalaConverter(DayTimeIntervalType)(us) === duration) + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index d20a9ba3f0f6..9ab336165d16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest -import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -342,6 +342,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { } } + test("SPARK-34605: encoding/decoding DayTimeIntervalType to/from java.time.Duration") { + val schema = new StructType().add("d", DayTimeIntervalType) + val encoder = RowEncoder(schema).resolveAndBind() + val duration = java.time.Duration.ofDays(1) + val row = toRow(encoder, Row(duration)) + assert(row.getLong(0) === IntervalUtils.durationToMicros(duration)) + val readback = fromRow(encoder, row) + assert(readback.get(0).equals(duration)) + } + for { elementType <- Seq(IntegerType, StringType) containsNull <- Seq(true, false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index f26e5fdc5f91..8cba46cabd91 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets -import java.time.{Instant, LocalDate, LocalDateTime, ZoneOffset} +import java.time.{Duration, Instant, LocalDate, LocalDateTime, ZoneOffset} import java.util.TimeZone import scala.reflect.runtime.universe.TypeTag @@ -349,4 +349,22 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { assert(literalStr === "2021-02-03 17:50:03.456") } } + + test("SPARK-34605: construct literals from java.time.Duration") { + Seq( + Duration.ofNanos(0), + Duration.ofSeconds(-1), + Duration.ofNanos(123456000), + Duration.ofDays(106751991), + Duration.ofDays(-106751991)).foreach { duration => + checkEvaluation(Literal(duration), duration) + } + } + + test("SPARK-34605: construct literals from arrays of java.time.Duration") { + val duration0 = Duration.ofDays(2).plusHours(3).plusMinutes(4) + checkEvaluation(Literal(Array(duration0)), Array(duration0)) + val duration1 = Duration.ofHours(-1024) + checkEvaluation(Literal(Array(duration0, duration1)), Array(duration0, duration1)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala index c313b546873a..51fd291bba0d 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.util +import java.time.Duration import java.util.concurrent.TimeUnit import org.apache.spark.SparkFunSuite @@ -379,4 +380,25 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper { intercept[ArithmeticException](multiplyExact(maxMonth, 2)) intercept[ArithmeticException](divideExact(maxDay, 0.5)) } + + test("SPARK-34605: microseconds to duration") { + assert(microsToDuration(0).isZero) + assert(microsToDuration(-1).toNanos === -1000) + assert(microsToDuration(1).toNanos === 1000) + assert(microsToDuration(Long.MaxValue).toDays === 106751991) + assert(microsToDuration(Long.MinValue).toDays === -106751991) + } + + test("SPARK-34605: duration to microseconds") { + assert(durationToMicros(Duration.ZERO) === 0) + assert(durationToMicros(Duration.ofSeconds(-1)) === -1000000) + assert(durationToMicros(Duration.ofNanos(123456)) === 123) + assert(durationToMicros(Duration.ofDays(106751991)) === + (Long.MaxValue / MICROS_PER_DAY) * MICROS_PER_DAY) + + val errMsg = intercept[ArithmeticException] { + durationToMicros(Duration.ofDays(106751991 + 1)) + }.getMessage + assert(errMsg.contains("long overflow")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 1135c8848bc2..bcc4871d3e64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -88,6 +88,9 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 3.0.0 */ implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT + /** @since 3.2.0 */ + implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION + /** @since 3.2.0 */ implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] = ExpressionEncoder() diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 05c4a04b20b1..85ad80e21329 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -21,6 +21,7 @@ import java.math.BigDecimal; import java.sql.Date; import java.sql.Timestamp; +import java.time.Duration; import java.time.Instant; import java.time.LocalDate; import java.util.*; @@ -412,6 +413,14 @@ public void testLocalDateAndInstantEncoders() { Assert.assertEquals(data, ds.collectAsList()); } + @Test + public void testDurationEncoder() { + Encoder encoder = Encoders.DURATION(); + List data = Arrays.asList(Duration.ofDays(0)); + Dataset ds = spark.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + } + public static class KryoSerializable { String value; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2ec4c6918a24..843696e04847 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -2007,6 +2007,11 @@ class DatasetSuite extends QueryTest checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil) } + + test("SPARK-34605: implicit 
encoder for java.time.Duration") { + val duration = java.time.Duration.ofMinutes(10) + assert(spark.range(1).map { _ => duration }.head === duration) + } } case class Bar(a: Int)
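For context, below is a minimal, illustrative sketch of how the mapping introduced by this patch is exercised from user code, assuming a Spark build that already contains these changes. The object name, the local `toMicros`/`fromMicros` helpers, and the `SparkSession` setup are not part of the patch; the helpers merely mirror the semantics of the internal `IntervalUtils.durationToMicros`/`microsToDuration` methods added above (truncate sub-microsecond precision, throw `ArithmeticException` on overflow), while `Encoders.DURATION` and the implicit encoder from `spark.implicits._` are the public entry points added here.

```scala
import java.time.Duration
import java.time.temporal.ChronoUnit

import org.apache.spark.sql.{Encoders, SparkSession}

object DurationRoundTripDemo {
  // Illustrative stand-ins mirroring the internal helpers added in IntervalUtils:
  // a Duration is stored as a single Long of microseconds; sub-microsecond precision
  // is truncated and numeric overflow throws ArithmeticException.
  def toMicros(d: Duration): Long =
    Math.addExact(Math.multiplyExact(d.getSeconds, 1000000L), d.getNano / 1000)

  def fromMicros(us: Long): Duration = Duration.of(us, ChronoUnit.MICROS)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("duration-demo").getOrCreate()
    import spark.implicits._ // brings newDurationEncoder into scope

    val d = Duration.ofDays(2).plusHours(3).plusMinutes(4)

    // Implicit encoder: Duration values are encoded as DayTimeIntervalType columns.
    val ds = spark.range(1).map(_ => d)
    assert(ds.head() == d)

    // Explicit encoder, the same one exposed to the Java API as Encoders.DURATION().
    val ds2 = spark.createDataset(Seq(d))(Encoders.DURATION)
    ds2.printSchema() // the single column carries the day-time interval type

    // Round trip through the microsecond representation, e.g. 123456 ns -> 123 us.
    assert(fromMicros(toMicros(d)) == d)
    assert(toMicros(Duration.ofNanos(123456)) == 123L)

    spark.stop()
  }
}
```

Because the physical representation is a single Long of microseconds, the codegen, `InternalRow`, and projection changes above fall out naturally: `DayTimeIntervalType` simply joins `LongType`/`TimestampType` wherever a long-backed primitive is read, written, or materialized.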