Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ public static Object read(
if (handleUserDefinedType && dataType instanceof UserDefinedType) {
return obj.get(ordinal, ((UserDefinedType)dataType).sqlType());
}
if (dataType instanceof DayTimeIntervalType) {
return obj.getLong(ordinal);
}

throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ object Encoders {
*/
def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()

/**
 * Creates an encoder that serializes instances of the `java.time.Duration` class
 * to the internal representation of Catalyst's nullable `DayTimeIntervalType`.
 *
 * @return an [[Encoder]] for `java.time.Duration` values
 * @since 3.2.0
 */
def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()

/**
* Creates an encoder for Java Bean of type T.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable

Expand Down Expand Up @@ -74,6 +74,7 @@ object CatalystTypeConverters {
case LongType => LongConverter
case FloatType => FloatConverter
case DoubleType => DoubleConverter
case DayTimeIntervalType => DurationConverter
case dataType: DataType => IdentityConverter(dataType)
}
converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
Expand Down Expand Up @@ -400,6 +401,18 @@ object CatalystTypeConverters {
override def toScalaImpl(row: InternalRow, column: Int): Double = row.getDouble(column)
}

/**
 * Converter between external `java.time.Duration` values and the `Long` number of
 * microseconds produced by `IntervalUtils.durationToMicros`.
 */
private object DurationConverter extends CatalystTypeConverter[Duration, Duration, Any] {
  /** Encodes a duration as its total length in microseconds. */
  override def toCatalystImpl(scalaValue: Duration): Long =
    IntervalUtils.durationToMicros(scalaValue)

  /** Decodes a Catalyst value back to a duration; a null input is returned as null. */
  override def toScala(catalystValue: Any): Duration = catalystValue match {
    case null => null
    case micros => IntervalUtils.microsToDuration(micros.asInstanceOf[Long])
  }

  /** Reads the long stored at `column` of `row` and decodes it as a duration. */
  override def toScalaImpl(row: InternalRow, column: Int): Duration = {
    val micros = row.getLong(column)
    IntervalUtils.microsToDuration(micros)
  }
}

/**
* Creates a converter function that will convert Scala objects to the specified Catalyst type.
* Typical use case would be converting a collection of rows that have the same schema. You will
Expand Down Expand Up @@ -465,6 +478,7 @@ object CatalystTypeConverters {
map,
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
case d: Duration => DurationConverter.toCatalyst(d)
case other => other
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.types._

object DeserializerBuildHelper {
Expand Down Expand Up @@ -143,6 +143,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

/**
 * Builds the deserializer expression for `java.time.Duration`: a static invocation of
 * `IntervalUtils.microsToDuration` on `path`, typed as a `java.time.Duration` object and
 * declared as never returning null.
 *
 * @param path the expression that yields the value to deserialize
 * @return the `StaticInvoke` expression performing the conversion
 */
def createDeserializerForDuration(path: Expression): Expression = {
  val durationType = ObjectType(classOf[java.time.Duration])
  StaticInvoke(
    IntervalUtils.getClass,
    durationType,
    "microsToDuration",
    Seq(path),
    returnNullable = false)
}

/**
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lost the required data type, which may lead to runtime error if the real type doesn't
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ object InternalRow {
case ByteType => (input, ordinal) => input.getByte(ordinal)
case ShortType => (input, ordinal) => input.getShort(ordinal)
case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
case LongType | TimestampType | DayTimeIntervalType =>
(input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal)
case DoubleType => (input, ordinal) => input.getDouble(ordinal)
case StringType => (input, ordinal) => input.getUTF8String(ordinal)
Expand Down Expand Up @@ -168,7 +169,8 @@ object InternalRow {
case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
case LongType | TimestampType | DayTimeIntervalType =>
(input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
case CalendarIntervalType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
Expand Down Expand Up @@ -249,6 +250,9 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Timestamp] =>
createDeserializerForSqlTimestamp(path)

case c if c == classOf[java.time.Duration] =>
createDeserializerForDuration(path)

case c if c == classOf[java.lang.String] =>
createDeserializerForString(path, returnNullable = true)

Expand Down Expand Up @@ -406,6 +410,8 @@ object JavaTypeInference {

case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)

case c if c == classOf[java.math.BigDecimal] =>
createSerializerForJavaBigDecimal(inputObject)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createDeserializerForSqlTimestamp(path)

case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createDeserializerForDuration(path)

case t if isSubtype(t, localTypeOf[java.lang.String]) =>
createDeserializerForString(path, returnNullable = false)

Expand Down Expand Up @@ -522,6 +525,9 @@ object ScalaReflection extends ScalaReflection {

case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject)

case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createSerializerForJavaDuration(inputObject)

case t if isSubtype(t, localTypeOf[BigDecimal]) =>
createSerializerForScalaBigDecimal(inputObject)

Expand Down Expand Up @@ -740,6 +746,8 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[CalendarInterval]) =>
Schema(CalendarIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
Schema(DayTimeIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
Expand Down Expand Up @@ -837,7 +845,8 @@ object ScalaReflection extends ScalaReflection {
DateType -> classOf[DateType.InternalType],
TimestampType -> classOf[TimestampType.InternalType],
BinaryType -> classOf[BinaryType.InternalType],
CalendarIntervalType -> classOf[CalendarInterval]
CalendarIntervalType -> classOf[CalendarInterval],
DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
)

val typeBoxedJavaMapping = Map[DataType, Class[_]](
Expand All @@ -849,7 +858,8 @@ object ScalaReflection extends ScalaReflection {
FloatType -> classOf[java.lang.Float],
DoubleType -> classOf[java.lang.Double],
DateType -> classOf[java.lang.Integer],
TimestampType -> classOf[java.lang.Long]
TimestampType -> classOf[java.lang.Long],
DayTimeIntervalType -> classOf[java.lang.Long]
)

def dataTypeJavaClass(dt: DataType): Class[_] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -104,6 +104,15 @@ object SerializerBuildHelper {
returnNullable = false)
}

/**
 * Builds the serializer expression for `java.time.Duration`: a static invocation of
 * `IntervalUtils.durationToMicros` on `inputObject`, typed as `DayTimeIntervalType` and
 * declared as never returning null.
 *
 * @param inputObject the expression that yields the `java.time.Duration` to serialize
 * @return the `StaticInvoke` expression performing the conversion
 */
def createSerializerForJavaDuration(inputObject: Expression): Expression = {
  val arguments = Seq(inputObject)
  StaticInvoke(
    IntervalUtils.getClass,
    DayTimeIntervalType,
    "durationToMicros",
    arguments,
    returnNullable = false)
}

def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,11 @@ package object dsl {
/** Creates a new AttributeReference of type timestamp */
def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()

/** Creates a new nullable AttributeReference whose data type is the day-time interval type */
def dayTimeInterval: AttributeReference =
  AttributeReference(s, DayTimeIntervalType, nullable = true)()

/** Creates a new AttributeReference of type binary */
def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ import org.apache.spark.sql.types._
* TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
*
* DayTimeIntervalType -> java.time.Duration
*
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
* MapType -> scala.collection.Map
Expand Down Expand Up @@ -108,6 +110,8 @@ object RowEncoder {
createSerializerForSqlDate(inputObject)
}

case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)

case d: DecimalType =>
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down Expand Up @@ -226,6 +230,7 @@ object RowEncoder {
} else {
ObjectType(classOf[java.sql.Date])
}
case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
Expand Down Expand Up @@ -281,6 +286,8 @@ object RowEncoder {
createDeserializerForSqlDate(input)
}

case DayTimeIntervalType => createDeserializerForDuration(input)

case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)

case StringType => createDeserializerForString(input, returnNullable = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ object InterpretedUnsafeProjection {
case IntegerType | DateType =>
(v, i) => writer.write(i, v.getInt(i))

case LongType | TimestampType =>
case LongType | TimestampType | DayTimeIntervalType =>
(v, i) => writer.write(i, v.getLong(i))

case FloatType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
// We use INT for DATE internally
case IntegerType | DateType => new MutableInt
// We use Long for Timestamp internally
case LongType | TimestampType => new MutableLong
// We use Long for Timestamp and DayTimeInterval internally
case LongType | TimestampType | DayTimeIntervalType => new MutableLong
case FloatType => new MutableFloat
case DoubleType => new MutableDouble
case BooleanType => new MutableBoolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1813,7 +1813,7 @@ object CodeGenerator extends Logging {
case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT
case IntegerType | DateType => JAVA_INT
case LongType | TimestampType => JAVA_LONG
case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
case _: DecimalType => "Decimal"
Expand All @@ -1834,7 +1834,7 @@ object CodeGenerator extends Logging {
case ByteType => java.lang.Byte.TYPE
case ShortType => java.lang.Short.TYPE
case IntegerType | DateType => java.lang.Integer.TYPE
case LongType | TimestampType => java.lang.Long.TYPE
case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
case FloatType => java.lang.Float.TYPE
case DoubleType => java.lang.Double.TYPE
case _: DecimalType => classOf[Decimal]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
import java.math.{BigDecimal => JavaBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate}
import java.util
import java.util.Objects
import javax.xml.bind.DatatypeConverter
Expand All @@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -76,6 +77,7 @@ object Literal {
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
case a: Array[Byte] => Literal(a, BinaryType)
case a: collection.mutable.WrappedArray[_] => apply(a.array)
case a: Array[_] =>
Expand Down Expand Up @@ -111,6 +113,7 @@ object Literal {
case _ if clz == classOf[Date] => DateType
case _ if clz == classOf[Instant] => TimestampType
case _ if clz == classOf[Timestamp] => TimestampType
case _ if clz == classOf[Duration] => DayTimeIntervalType
case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[Array[Byte]] => BinaryType
case _ if clz == classOf[Array[Char]] => StringType
Expand Down Expand Up @@ -167,6 +170,7 @@ object Literal {
case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
case DateType => create(0, DateType)
case TimestampType => create(0L, TimestampType)
case DayTimeIntervalType => create(0L, DayTimeIntervalType)
case StringType => Literal("")
case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
Expand All @@ -186,7 +190,7 @@ object Literal {
case ByteType => v.isInstanceOf[Byte]
case ShortType => v.isInstanceOf[Short]
case IntegerType | DateType => v.isInstanceOf[Int]
case LongType | TimestampType => v.isInstanceOf[Long]
case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long]
case FloatType => v.isInstanceOf[Float]
case DoubleType => v.isInstanceOf[Double]
case _: DecimalType => v.isInstanceOf[Decimal]
Expand Down Expand Up @@ -388,7 +392,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
}
case ByteType | ShortType =>
ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
case TimestampType | LongType =>
case TimestampType | LongType | DayTimeIntervalType =>
toExprCode(s"${value}L")
case _ =>
val constRef = ctx.addReferenceObj("literal", value, javaType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.util

import java.time.Duration
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit

import scala.util.control.NonFatal
Expand Down Expand Up @@ -762,4 +764,31 @@ object IntervalUtils {

new CalendarInterval(totalMonths, totalDays, micros)
}

/**
 * Converts the given duration to its total length in microseconds.
 * <p>
 * If the total does not fit in a [[Long]] number of microseconds, an exception is thrown.
 * Every duration whose microsecond total fits in a [[Long]] is converted, including the
 * one equal to `Long.MinValue` microseconds.
 * <p>
 * If the duration has greater than microsecond precision, the excess precision is dropped:
 * the non-negative nanosecond-of-second adjustment is divided by `NANOS_PER_MICROS` using
 * integer division, so the total is rounded toward negative infinity.
 *
 * @param duration The duration to convert, not null
 * @return The total length of the duration in microseconds
 * @throws ArithmeticException If numeric overflow occurs
 */
def durationToMicros(duration: Duration): Long = {
  val seconds = duration.getSeconds
  val microsOfSecond = duration.getNano / NANOS_PER_MICROS
  if (seconds < 0) {
    // A negative Duration is stored as floor(seconds) plus a non-negative nano adjustment,
    // so `seconds * MICROS_PER_SECOND` alone can fall below Long.MinValue even though the
    // final sum is representable (e.g. the duration of Long.MinValue microseconds).
    // Shift one second from the seconds term into the adjustment term: the sum is
    // mathematically unchanged, and the intermediate product stays in range whenever the
    // result does.
    val shiftedMicros = Math.multiplyExact(seconds + 1, MICROS_PER_SECOND)
    Math.addExact(shiftedMicros, microsOfSecond - MICROS_PER_SECOND)
  } else {
    Math.addExact(Math.multiplyExact(seconds, MICROS_PER_SECOND), microsOfSecond)
  }
}

/**
 * Builds the [[Duration]] whose length is the given number of microseconds.
 *
 * @param micros The length in microseconds; may be positive, zero or negative
 * @return The corresponding [[Duration]], never null
 */
def microsToDuration(micros: Long): Duration = {
  val unit = ChronoUnit.MICROS
  Duration.ZERO.plus(micros, unit)
}
}
Loading