Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support Row encoding
  • Loading branch information
MaxGekk committed Sep 3, 2019
commit cee85e12eb4dca3634c6bd4f00f974edb3b84ecf
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ object DeserializerBuildHelper {
returnNullable = false)
}

def createDeserializerForLocalTime(path: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(classOf[java.time.LocalTime]),
"microsToLocalTime",
path :: Nil,
returnNullable = false)
}

def createDeserializerForJavaBigDecimal(
path: Expression,
returnNullable: Boolean): Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
case c: Class[_] if c == classOf[java.time.LocalTime] => (TimeType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
Expand Down Expand Up @@ -235,6 +236,9 @@ object JavaTypeInference {
case c if c == classOf[java.sql.Timestamp] =>
createDeserializerForSqlTimestamp(path)

case c if c == classOf[java.time.LocalTime] =>
createDeserializerForLocalTime(path)

case c if c == classOf[java.lang.String] =>
createDeserializerForString(path, returnNullable = true)

Expand Down Expand Up @@ -390,6 +394,8 @@ object JavaTypeInference {

case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

case c if c == classOf[java.time.LocalTime] => createSerializerForJavaLocalTime(inputObject)

case c if c == classOf[java.math.BigDecimal] =>
createSerializerForJavaBigDecimal(inputObject)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createDeserializerForSqlTimestamp(path)

case t if isSubtype(t, localTypeOf[java.time.LocalTime]) =>
createDeserializerForLocalTime(path)

case t if isSubtype(t, localTypeOf[java.lang.String]) =>
createDeserializerForString(path, returnNullable = false)

Expand Down Expand Up @@ -496,6 +499,9 @@ object ScalaReflection extends ScalaReflection {

case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject)

case t if isSubtype(t, localTypeOf[java.time.LocalTime]) =>
createSerializerForJavaLocalTime(inputObject)

case t if isSubtype(t, localTypeOf[BigDecimal]) =>
createSerializerForScalaBigDecimal(inputObject)

Expand Down Expand Up @@ -671,6 +677,7 @@ object ScalaReflection extends ScalaReflection {
Schema(TimestampType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true)
case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => Schema(TimeType, nullable = true)
case t if isSubtype(t, localTypeOf[BigDecimal]) =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
Expand Down Expand Up @@ -771,6 +778,7 @@ object ScalaReflection extends ScalaReflection {
StringType -> classOf[UTF8String],
DateType -> classOf[DateType.InternalType],
TimestampType -> classOf[TimestampType.InternalType],
TimeType -> classOf[TimeType.InternalType],
BinaryType -> classOf[BinaryType.InternalType],
CalendarIntervalType -> classOf[CalendarInterval]
)
Expand All @@ -784,7 +792,8 @@ object ScalaReflection extends ScalaReflection {
FloatType -> classOf[java.lang.Float],
DoubleType -> classOf[java.lang.Double],
DateType -> classOf[java.lang.Integer],
TimestampType -> classOf[java.lang.Long]
TimestampType -> classOf[java.lang.Long],
TimeType -> classOf[java.lang.Long]
)

def dataTypeJavaClass(dt: DataType): Class[_] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,15 @@ object SerializerBuildHelper {
returnNullable = false)
}

def createSerializerForJavaLocalTime(inputObject: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,
TimeType,
"localTimeToMicros",
inputObject :: Nil,
returnNullable = false)
}

def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import org.apache.spark.sql.types._
* TimestampType -> java.sql.Timestamp if spark.sql.datetime.java8API.enabled is false
* TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
*
* TimeType -> java.time.LocalTime
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq or Array
* MapType -> scala.collection.Map
Expand Down Expand Up @@ -108,6 +109,8 @@ object RowEncoder {
createSerializerForSqlDate(inputObject)
}

case TimeType => createSerializerForJavaLocalTime(inputObject)

case d: DecimalType =>
CheckOverflow(StaticInvoke(
Decimal.getClass,
Expand Down Expand Up @@ -226,6 +229,7 @@ object RowEncoder {
} else {
ObjectType(classOf[java.sql.Date])
}
case TimeType => ObjectType(classOf[java.time.LocalTime])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
case StringType => ObjectType(classOf[java.lang.String])
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
Expand Down Expand Up @@ -281,6 +285,8 @@ object RowEncoder {
createDeserializerForSqlDate(input)
}

case TimeType => createDeserializerForLocalTime(input)

case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)

case StringType => createDeserializerForString(input, returnNullable = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,7 @@ object CodeGenerator extends Logging {
case ByteType => JAVA_BYTE
case ShortType => JAVA_SHORT
case IntegerType | DateType => JAVA_INT
case LongType | TimestampType => JAVA_LONG
case LongType | TimestampType | TimeType => JAVA_LONG
case FloatType => JAVA_FLOAT
case DoubleType => JAVA_DOUBLE
case _: DecimalType => "Decimal"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}

test("encoding/decoding TimeType to/from java.time.LocalTime") {
val schema = new StructType().add("t", TimeType)
val encoder = RowEncoder(schema).resolveAndBind()
val localTime = java.time.LocalTime.parse("20:38:45.123456")
val row = encoder.toRow(Row(localTime))
assert(row.getLong(0) === DateTimeUtils.localTimeToMicros(localTime))
val readback = encoder.fromRow(row)
assert(readback.get(0).equals(localTime))
}

for {
elementType <- Seq(IntegerType, StringType)
containsNull <- Seq(true, false)
Expand Down