Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ object Encoders {
*/
def INSTANT: Encoder[java.time.Instant] = ExpressionEncoder()

/**
 * An encoder for `java.time.Duration` values. Instances are serialized to the internal
 * representation of Catalyst's nullable CalendarIntervalType.
 *
 * @since 3.0.0
 */
def DURATION: Encoder[java.time.Duration] = ExpressionEncoder[java.time.Duration]()

/**
* An encoder for arrays of bytes.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable

Expand All @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

/**
* Functions to convert Scala types to Catalyst types and vice versa.
Expand Down Expand Up @@ -74,6 +74,7 @@ object CatalystTypeConverters {
case LongType => LongConverter
case FloatType => FloatConverter
case DoubleType => DoubleConverter
case CalendarIntervalType => DurationConverter
case dataType: DataType => IdentityConverter(dataType)
}
converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
Expand Down Expand Up @@ -341,6 +342,16 @@ object CatalystTypeConverters {
DateTimeUtils.microsToInstant(row.getLong(column))
}

/**
 * Converter between external `java.time.Duration` values and Catalyst
 * [[CalendarInterval]] values; both directions delegate to [[DateTimeUtils]].
 */
private object DurationConverter extends CatalystTypeConverter[Duration, Duration, Any] {

  /** Converts an external duration to its Catalyst interval form. */
  override def toCatalystImpl(scalaValue: Duration): CalendarInterval = {
    DateTimeUtils.durationToInterval(scalaValue)
  }

  /** Converts a Catalyst value back to a duration; a `null` input yields `null`. */
  override def toScala(catalystValue: Any): Duration = catalystValue match {
    case null => null
    case nonNull => DateTimeUtils.intervalToDuration(nonNull.asInstanceOf[CalendarInterval])
  }

  /** Reads the interval stored at `column` of `row` and converts it to a duration. */
  override def toScalaImpl(row: InternalRow, column: Int): Duration = {
    val interval = row.getInterval(column)
    DateTimeUtils.intervalToDuration(interval)
  }
}

private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {

Expand Down Expand Up @@ -462,6 +473,7 @@ object CatalystTypeConverters {
map,
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
case d: Duration => DurationConverter.toCatalyst(d)
case other => other
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

/**
 * Builds the deserializer expression for `java.time.Duration`: a static call to
 * `DateTimeUtils.intervalToDuration` applied to `path`.
 *
 * @param path expression producing the Catalyst value to deserialize
 * @return a [[StaticInvoke]] whose result type is `java.time.Duration`
 */
def createDeserializerForDuration(path: Expression): Expression = {
  val resultType = ObjectType(classOf[java.time.Duration])
  val arguments = Seq(path)
  StaticInvoke(
    DateTimeUtils.getClass,
    resultType,
    "intervalToDuration",
    arguments,
    returnNullable = false)
}

def createDeserializerForJavaBigDecimal(
path: Expression,
returnNullable: Boolean): Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

/**
* An abstract class for row used internally in Spark SQL, which only contains the columns as
Expand Down Expand Up @@ -58,6 +58,8 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
*/
def setDecimal(i: Int, value: Decimal, precision: Int): Unit = update(i, value)

/**
 * Stores the given [[CalendarInterval]] at ordinal `i`.
 * This implementation forwards to `update(i, value)`.
 */
def setInterval(i: Int, value: CalendarInterval): Unit = update(i, value)

/**
* Make a copy of the current [[InternalRow]] object.
*/
Expand Down Expand Up @@ -177,6 +179,8 @@ object InternalRow {
case _: StructType => (input, v) => input.update(ordinal, v.asInstanceOf[InternalRow].copy())
case _: ArrayType => (input, v) => input.update(ordinal, v.asInstanceOf[ArrayData].copy())
case _: MapType => (input, v) => input.update(ordinal, v.asInstanceOf[MapData].copy())
case _: CalendarIntervalType =>
(input, v) => input.setInterval(ordinal, v.asInstanceOf[CalendarInterval])
case _ => (input, v) => input.update(ordinal, v)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
case c: Class[_] if c == classOf[java.time.Duration] => (CalendarIntervalType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
Expand Down Expand Up @@ -235,6 +236,9 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Timestamp] =>
createDeserializerForSqlTimestamp(path)

case c if c == classOf[java.time.Duration] =>
createDeserializerForDuration(path)

case c if c == classOf[java.lang.String] =>
createDeserializerForString(path, returnNullable = true)

Expand Down Expand Up @@ -390,6 +394,8 @@ object JavaTypeInference {

case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject)

case c if c == classOf[java.math.BigDecimal] =>
createSerializerForJavaBigDecimal(inputObject)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ object ScalaReflection extends ScalaReflection {
*/
def isNativeType(dt: DataType): Boolean = dt match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType | CalendarIntervalType => true
FloatType | DoubleType | BinaryType => true
case _ => false
}

Expand Down Expand Up @@ -230,6 +230,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createDeserializerForSqlTimestamp(path)

case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createDeserializerForDuration(path)

case t if isSubtype(t, localTypeOf[java.lang.String]) =>
createDeserializerForString(path, returnNullable = false)

Expand Down Expand Up @@ -496,6 +499,9 @@ object ScalaReflection extends ScalaReflection {

case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject)

case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
createSerializerForJavaDuration(inputObject)

case t if isSubtype(t, localTypeOf[BigDecimal]) =>
createSerializerForScalaBigDecimal(inputObject)

Expand Down Expand Up @@ -671,6 +677,8 @@ object ScalaReflection extends ScalaReflection {
Schema(TimestampType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
Schema(CalendarIntervalType, nullable = true)
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ object SerializerBuildHelper {
returnNullable = false)
}

/**
 * Builds the serializer expression for `java.time.Duration`: a static call to
 * `DateTimeUtils.durationToInterval` applied to `inputObject`.
 *
 * @param inputObject expression producing the external duration object
 * @return a [[StaticInvoke]] whose result type is `CalendarIntervalType`
 */
def createSerializerForJavaDuration(inputObject: Expression): Expression = {
  val arguments = Seq(inputObject)
  StaticInvoke(
    DateTimeUtils.getClass,
    CalendarIntervalType,
    "durationToInterval",
    arguments,
    returnNullable = false)
}

def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import org.apache.spark.sql.types._
* ArrayType -> scala.collection.Seq or Array
* MapType -> scala.collection.Map
* StructType -> org.apache.spark.sql.Row
* CalendarIntervalType -> java.time.Duration
* }}}
*/
object RowEncoder {
Expand Down Expand Up @@ -108,6 +109,8 @@ object RowEncoder {
createSerializerForSqlDate(inputObject)
}

case CalendarIntervalType => createSerializerForJavaDuration(inputObject)

case d: DecimalType =>
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down Expand Up @@ -226,6 +229,7 @@ object RowEncoder {
} else {
ObjectType(classOf[java.sql.Date])
}
case CalendarIntervalType => ObjectType(classOf[java.time.Duration])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
Expand Down Expand Up @@ -281,6 +285,8 @@ object RowEncoder {
createDeserializerForSqlDate(input)
}

case CalendarIntervalType => createDeserializerForDuration(input)

case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)

case StringType => createDeserializerForString(input, returnNullable = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import java.lang.{Short => JavaShort}
import java.math.{BigDecimal => JavaBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate}
import java.util
import java.util.Objects
import javax.xml.bind.DatatypeConverter
Expand All @@ -42,7 +42,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{durationToInterval, instantToMicros}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._
Expand Down Expand Up @@ -71,6 +71,7 @@ object Literal {
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case d: Duration => Literal(durationToInterval(d), CalendarIntervalType)
case a: Array[Byte] => Literal(a, BinaryType)
case a: collection.mutable.WrappedArray[_] => apply(a.array)
case a: Array[_] =>
Expand Down Expand Up @@ -120,6 +121,7 @@ object Literal {
case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[CalendarInterval] => CalendarIntervalType
case _ if clz == classOf[Duration] => CalendarIntervalType

case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import java.util.concurrent.TimeUnit._
import scala.util.control.NonFatal

import org.apache.spark.sql.types.Decimal
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

/**
* Helper functions for converting between internal and external date and time representations.
Expand Down Expand Up @@ -959,4 +959,14 @@ object DateTimeUtils {
None
}
}

/**
 * Converts a `java.time.Duration` to a [[CalendarInterval]] that has zero months and
 * carries the whole duration in its microseconds field. The nanosecond part below one
 * microsecond is dropped by the integer division.
 *
 * @param duration the duration to convert
 * @return an interval with `months = 0` and the duration expressed in microseconds
 * @throws ArithmeticException if the duration does not fit into a `Long` number of
 *                             microseconds
 */
def durationToInterval(duration: Duration): CalendarInterval = {
  // Use exact arithmetic: plain `*` and `+` would silently wrap around for durations
  // beyond roughly +/-292,000 years and produce a bogus interval instead of an error.
  val secondsAsMicros = Math.multiplyExact(duration.getSeconds, MICROS_PER_SECOND)
  val micros = Math.addExact(secondsAsMicros, duration.getNano / NANOS_PER_MICROS)
  new CalendarInterval(0, micros)
}

/**
 * Converts a [[CalendarInterval]] to a `java.time.Duration`. The microseconds field is
 * converted exactly, and each month contributes a fixed `SECONDS_PER_MONTH` seconds.
 *
 * @param interval the interval to convert
 * @return the duration equal to the interval's microseconds plus its months expressed
 *         in seconds
 */
def intervalToDuration(interval: CalendarInterval): Duration = {
  // Build the duration from microseconds directly. Multiplying the microseconds by
  // NANOS_PER_MICROS first overflows Long once the interval exceeds ~292 years.
  val microsDuration = Duration.of(interval.microseconds, java.time.temporal.ChronoUnit.MICROS)
  // Widen months to Long before multiplying so the product cannot wrap in Int arithmetic.
  microsDuration.plusSeconds(interval.months.toLong * SECONDS_PER_MONTH)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst

import java.time.{Instant, LocalDate}
import java.time.{Duration, Instant, LocalDate}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
Expand All @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {

Expand Down Expand Up @@ -216,4 +216,30 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
}
}
}

// Checks that convertToCatalyst turns a java.time.Duration into the same interval
// that DateTimeUtils.durationToInterval produces.
test("converting java.time.Duration to CalendarIntervalType") {
  val isoDurations = Seq(
    "P0D",
    "PT0.000001S",
    "PT-0.000001S",
    "P1DT2H3M4.000001S",
    "P-1DT2H3M4.000001S")
  for (text <- isoDurations) {
    val duration = Duration.parse(text)
    val converted = CatalystTypeConverters.convertToCatalyst(duration)
    assert(converted === DateTimeUtils.durationToInterval(duration))
  }
}

// Checks that the to-Scala converter for CalendarIntervalType returns the same duration
// that DateTimeUtils.intervalToDuration produces.
test("converting CalendarIntervalType to java.time.Duration") {
  val intervals = Seq(
    "interval 0 days",
    "interval 1 month",
    "interval 1 month 1 microsecond",
    "interval -1 month -1 microsecond",
    "interval 10000 years -1 microsecond").map(text => CalendarInterval.fromString(text))
  for (interval <- intervals) {
    val expected = DateTimeUtils.intervalToDuration(interval)
    val toScala = CatalystTypeConverters.createToScalaConverter(CalendarIntervalType)
    assert(toScala(interval) === expected)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}

// Round-trips a java.time.Duration through a RowEncoder over a single
// CalendarIntervalType column and checks both the internal and the read-back value.
test("encoding/decoding CalendarIntervalType to/from java.time.Duration") {
  val duration = java.time.Duration.parse("P2DT3H4M")
  val encoder = RowEncoder(new StructType().add("i", CalendarIntervalType)).resolveAndBind()

  // External -> internal: the stored interval must match the direct conversion.
  val internalRow = encoder.toRow(Row(duration))
  assert(internalRow.getInterval(0) === DateTimeUtils.durationToInterval(duration))

  // Internal -> external: decoding must give back the original duration.
  val decoded = encoder.fromRow(internalRow)
  assert(decoded.get(0).equals(duration))
}

for {
elementType <- Seq(IntegerType, StringType)
containsNull <- Seq(true, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import java.nio.charset.StandardCharsets
import java.time.{Instant, LocalDate, LocalDateTime, ZoneOffset}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, ZoneOffset}
import java.util.TimeZone

import scala.reflect.runtime.universe.TypeTag
Expand Down Expand Up @@ -76,6 +76,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.default(TimestampType), Instant.ofEpochSecond(0))
}
checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0L))
checkEvaluation(Literal.default(CalendarIntervalType), Duration.ofSeconds(0))
checkEvaluation(Literal.default(ArrayType(StringType)), Array())
checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
Expand Down Expand Up @@ -316,4 +317,22 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(literalStr === expected)
}
}

// Literals built from java.time.Duration must evaluate to the same duration,
// including sub-second, negative and very large values.
test("construct literals from java.time.Duration") {
  val durations = Seq(
    Duration.ofSeconds(0),
    Duration.ofSeconds(1, 999999000),
    Duration.ofSeconds(-1, -999999000),
    Duration.ofDays(365 * 10000),
    Duration.ofDays(-365 * 10000))
  for (d <- durations) {
    checkEvaluation(Literal(d), d)
  }
}

// Literals built from arrays of java.time.Duration must evaluate to the same arrays,
// for both a single-element and a two-element array.
test("construct literals from arrays of java.time.Duration") {
  val tenMinutes = Duration.ofMinutes(10)
  val single = Array(tenMinutes)
  checkEvaluation(Literal(single), single)

  val threeHours = Duration.ofHours(3)
  val pair = Array(tenMinutes, threeHours)
  checkEvaluation(Literal(pair), pair)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 3.0.0 */
implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT

/**
 * Implicit encoder for `java.time.Duration`, backed by `Encoders.DURATION`.
 *
 * @since 3.0.0
 */
implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION

// Boxed primitives

/** @since 2.0.0 */
Expand Down
Loading