Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ public static Object read(
if (handleUserDefinedType && dataType instanceof UserDefinedType) {
return obj.get(ordinal, ((UserDefinedType)dataType).sqlType());
}
if (dataType instanceof TimeType) {
return obj.getLong(ordinal);
}

throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ object Encoders {
*/
def INSTANT: Encoder[java.time.Instant] = ExpressionEncoder()

/**
* Creates an encoder that serializes instances of the `java.time.LocalTime` class
* to the internal representation of nullable Catalyst's TimeType.
*
* @since 3.0.0
*/
def LOCALTIME: Encoder[java.time.LocalTime] = ExpressionEncoder()

/**
* An encoder for arrays of bytes.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Instant, LocalDate, LocalTime}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable

Expand Down Expand Up @@ -66,6 +66,7 @@ object CatalystTypeConverters {
case DateType => DateConverter
case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
case TimestampType => TimestampConverter
case TimeType => LocalTimeConverter
case dt: DecimalType => new DecimalConverter(dt)
case BooleanType => BooleanConverter
case ByteType => ByteConverter
Expand Down Expand Up @@ -341,6 +342,18 @@ object CatalystTypeConverters {
DateTimeUtils.microsToInstant(row.getLong(column))
}

private object LocalTimeConverter extends CatalystTypeConverter[LocalTime, LocalTime, Any] {
override def toCatalystImpl(scalaValue: LocalTime): Long = {
DateTimeUtils.localTimeToMicros(scalaValue)
}
override def toScala(catalystValue: Any): LocalTime = {
if (catalystValue == null) null
else DateTimeUtils.microsToLocalTime(catalystValue.asInstanceOf[Long])
}
override def toScalaImpl(row: InternalRow, column: Int): LocalTime =
DateTimeUtils.microsToLocalTime(row.getLong(column))
}

private class DecimalConverter(dataType: DecimalType)
extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {

Expand Down Expand Up @@ -452,6 +465,7 @@ object CatalystTypeConverters {
case ld: LocalDate => LocalDateConverter.toCatalyst(ld)
case t: Timestamp => TimestampConverter.toCatalyst(t)
case i: Instant => InstantConverter.toCatalyst(i)
case t: LocalTime => LocalTimeConverter.toCatalyst(t)
case d: BigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
case d: JavaBigDecimal => new DecimalConverter(DecimalType(d.precision, d.scale)).toCatalyst(d)
case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

def createDeserializerForLocalTime(path: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.time.LocalTime]),
"microsToLocalTime",
path :: Nil,
returnNullable = false)
}

def createDeserializerForJavaBigDecimal(
path: Expression,
returnNullable: Boolean): Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ object InternalRow {
case ByteType => (input, ordinal) => input.getByte(ordinal)
case ShortType => (input, ordinal) => input.getShort(ordinal)
case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
case LongType | TimestampType | TimeType => (input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal)
case DoubleType => (input, ordinal) => input.getDouble(ordinal)
case StringType => (input, ordinal) => input.getUTF8String(ordinal)
Expand Down Expand Up @@ -166,7 +166,8 @@ object InternalRow {
case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
case LongType | TimestampType => (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
case LongType | TimestampType | TimeType =>
(input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
case DoubleType => (input, v) => input.setDouble(ordinal, v.asInstanceOf[Double])
case DecimalType.Fixed(precision, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
case c: Class[_] if c == classOf[java.time.LocalTime] => (TimeType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
Expand Down Expand Up @@ -235,6 +236,9 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Timestamp] =>
createDeserializerForSqlTimestamp(path)

case c if c == classOf[java.time.LocalTime] =>
createDeserializerForLocalTime(path)

case c if c == classOf[java.lang.String] =>
createDeserializerForString(path, returnNullable = true)

Expand Down Expand Up @@ -390,6 +394,8 @@ object JavaTypeInference {

case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

case c if c == classOf[java.time.LocalTime] => createSerializerForJavaLocalTime(inputObject)

case c if c == classOf[java.math.BigDecimal] =>
createSerializerForJavaBigDecimal(inputObject)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createDeserializerForSqlTimestamp(path)

case t if isSubtype(t, localTypeOf[java.time.LocalTime]) =>
createDeserializerForLocalTime(path)

case t if isSubtype(t, localTypeOf[java.lang.String]) =>
createDeserializerForString(path, returnNullable = false)

Expand Down Expand Up @@ -496,6 +499,9 @@ object ScalaReflection extends ScalaReflection {

case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject)

case t if isSubtype(t, localTypeOf[java.time.LocalTime]) =>
createSerializerForJavaLocalTime(inputObject)

case t if isSubtype(t, localTypeOf[BigDecimal]) =>
createSerializerForScalaBigDecimal(inputObject)

Expand Down Expand Up @@ -671,6 +677,7 @@ object ScalaReflection extends ScalaReflection {
Schema(TimestampType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => Schema(TimeType, nullable = true)
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
Expand Down Expand Up @@ -771,6 +778,7 @@ object ScalaReflection extends ScalaReflection {
StringType -> classOf[UTF8String],
DateType -> classOf[DateType.InternalType],
TimestampType -> classOf[TimestampType.InternalType],
TimeType -> classOf[TimeType.InternalType],
BinaryType -> classOf[BinaryType.InternalType],
CalendarIntervalType -> classOf[CalendarInterval]
)
Expand All @@ -784,7 +792,8 @@ object ScalaReflection extends ScalaReflection {
FloatType -> classOf[java.lang.Float],
DoubleType -> classOf[java.lang.Double],
DateType -> classOf[java.lang.Integer],
TimestampType -> classOf[java.lang.Long]
TimestampType -> classOf[java.lang.Long],
TimeType -> classOf[java.lang.Long]
)

def dataTypeJavaClass(dt: DataType): Class[_] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ object SerializerBuildHelper {
returnNullable = false)
}

def createSerializerForJavaLocalTime(inputObject: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
TimeType,
"localTimeToMicros",
inputObject :: Nil,
returnNullable = false)
}

def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ package object dsl {
/** Creates a new AttributeReference of type timestamp */
def timestamp: AttributeReference = AttributeReference(s, TimestampType, nullable = true)()

/** Creates a new AttributeReference of type date */
def time: AttributeReference = AttributeReference(s, TimeType, nullable = true)()

/** Creates a new AttributeReference of type binary */
def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import org.apache.spark.sql.types._
* TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
*
* TimeType -> java.time.LocalTime
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
* MapType -> scala.collection.Map
Expand Down Expand Up @@ -108,6 +109,8 @@ object RowEncoder {
createSerializerForSqlDate(inputObject)
}

case TimeType => createSerializerForJavaLocalTime(inputObject)

case d: DecimalType =>
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down Expand Up @@ -226,6 +229,7 @@ object RowEncoder {
} else {
ObjectType(classOf[java.sql.Date])
}
case TimeType => ObjectType(classOf[java.time.LocalTime])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
Expand Down Expand Up @@ -281,6 +285,8 @@ object RowEncoder {
createDeserializerForSqlDate(input)
}

case TimeType => createDeserializerForLocalTime(input)

case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)

case StringType => createDeserializerForString(input, returnNullable = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ object InterpretedUnsafeProjection {
case IntegerType | DateType =>
(v, i) => writer.write(i, v.getInt(i))

case LongType | TimestampType =>
case LongType | TimestampType | TimeType =>
(v, i) => writer.write(i, v.getLong(i))

case FloatType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
case ShortType => new MutableShort
// We use INT for DATE internally
case IntegerType | DateType => new MutableInt
// We use Long for Timestamp internally
case LongType | TimestampType => new MutableLong
// We use Long for Timestamp and Time internally
case LongType | TimestampType | TimeType => new MutableLong
case FloatType => new MutableFloat
case DoubleType => new MutableDouble
case _ => new MutableAny
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,7 @@ object CodeGenerator extends Logging {
case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT
case IntegerType | DateType => JAVA_INT
case LongType | TimestampType => JAVA_LONG
case LongType | TimestampType | TimeType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
case _: DecimalType => "Decimal"
Expand All @@ -1651,7 +1651,7 @@ object CodeGenerator extends Logging {
case ByteType => java.lang.Byte.TYPE
case ShortType => java.lang.Short.TYPE
case IntegerType | DateType => java.lang.Integer.TYPE
case LongType | TimestampType => java.lang.Long.TYPE
case LongType | TimestampType | TimeType => java.lang.Long.TYPE
case FloatType => java.lang.Float.TYPE
case DoubleType => java.lang.Double.TYPE
case _: DecimalType => classOf[Decimal]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import java.lang.{Short => JavaShort}
import java.math.{BigDecimal => JavaBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}
import java.time.{Instant, LocalDate, LocalTime}
import java.util
import java.util.Objects
import javax.xml.bind.DatatypeConverter
Expand All @@ -42,7 +42,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localTimeToMicros}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._
Expand Down Expand Up @@ -71,6 +71,7 @@ object Literal {
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case lt: LocalTime => Literal(localTimeToMicros(lt), TimeType)
case a: Array[Byte] => Literal(a, BinaryType)
case a: collection.mutable.WrappedArray[_] => apply(a.array)
case a: Array[_] =>
Expand Down Expand Up @@ -105,6 +106,7 @@ object Literal {
case _ if clz == classOf[Date] => DateType
case _ if clz == classOf[Instant] => TimestampType
case _ if clz == classOf[Timestamp] => TimestampType
case _ if clz == classOf[LocalTime] => TimeType
case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[Array[Byte]] => BinaryType
case _ if clz == classOf[JavaShort] => ShortType
Expand Down Expand Up @@ -160,6 +162,7 @@ object Literal {
case dt: DecimalType => Literal(Decimal(0, dt.precision, dt.scale))
case DateType => create(0, DateType)
case TimestampType => create(0L, TimestampType)
case TimeType => create(0L, TimeType)
case StringType => Literal("")
case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
case CalendarIntervalType => Literal(new CalendarInterval(0, 0))
Expand All @@ -179,7 +182,7 @@ object Literal {
case ByteType => v.isInstanceOf[Byte]
case ShortType => v.isInstanceOf[Short]
case IntegerType | DateType => v.isInstanceOf[Int]
case LongType | TimestampType => v.isInstanceOf[Long]
case LongType | TimestampType | TimeType => v.isInstanceOf[Long]
case FloatType => v.isInstanceOf[Float]
case DoubleType => v.isInstanceOf[Double]
case _: DecimalType => v.isInstanceOf[Decimal]
Expand Down Expand Up @@ -336,7 +339,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
}
case ByteType | ShortType =>
ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
case TimestampType | LongType =>
case TimestampType | LongType | TimeType =>
toExprCode(s"${value}L")
case _ =>
val constRef = ctx.addReferenceObj("literal", value, javaType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.util
import java.sql.{Date, Timestamp}
import java.time._
import java.time.temporal.{ChronoField, ChronoUnit, IsoFields}
import java.time.temporal.ChronoField.MICRO_OF_DAY
import java.util.{Locale, TimeZone}
import java.util.concurrent.TimeUnit._

Expand Down Expand Up @@ -363,6 +364,10 @@ object DateTimeUtils {

def daysToLocalDate(days: Int): LocalDate = LocalDate.ofEpochDay(days)

def localTimeToMicros(localTime: LocalTime): Long = localTime.getLong(MICRO_OF_DAY)

def microsToLocalTime(us: Long): LocalTime = LocalTime.ofNanoOfDay(us * NANOS_PER_MICROS)

/**
* Trim and parse a given UTF8 date string to a corresponding [[Int]] value.
* The return type is [[Option]] in order to distinguish between 0 and null. The following
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ object DataType {

private val nonDecimalNameToType = {
Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType)
DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType, TimeType)
.map(t => t.typeName -> t).toMap
}

Expand Down
Loading