diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
index d1bb719aca8f..90f340b51c3e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
@@ -86,6 +86,9 @@ public static Object read(
     if (dataType instanceof DayTimeIntervalType) {
       return obj.getLong(ordinal);
     }
+    if (dataType instanceof YearMonthIntervalType) {
+      return obj.getInt(ordinal);
+    }

     throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index 5e72b19ca5da..d50829578e6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -143,6 +143,14 @@ object Encoders {
    */
  def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()

+  /**
+   * Creates an encoder that serializes instances of the `java.time.Period` class
+   * to the internal representation of Catalyst's nullable YearMonthIntervalType.
+   *
+   * @since 3.2.0
+   */
+  def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()
+
  /**
   * Creates an encoder for Java Bean of type T.
   *
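Usage sketch (not part of the patch): how the new explicit encoder is expected to be used, assuming a live SparkSession named `spark`.

    import java.time.Period
    import org.apache.spark.sql.{Dataset, Encoders}

    // A Period round-trips through its total month count: 1 * 12 + 2 = 14 months.
    val ds: Dataset[Period] = spark.createDataset(Seq(Period.of(1, 2, 0)))(Encoders.PERIOD)
    assert(ds.head() == Period.of(1, 2, 0))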
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 8201fd7d8fb5..b55d1b725f56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable

@@ -75,6 +75,7 @@ object CatalystTypeConverters {
       case FloatType => FloatConverter
       case DoubleType => DoubleConverter
       case DayTimeIntervalType => DurationConverter
+      case YearMonthIntervalType => PeriodConverter
       case dataType: DataType => IdentityConverter(dataType)
     }
     converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
@@ -413,6 +414,18 @@ object CatalystTypeConverters {
       IntervalUtils.microsToDuration(row.getLong(column))
   }

+  private object PeriodConverter extends CatalystTypeConverter[Period, Period, Any] {
+    override def toCatalystImpl(scalaValue: Period): Int = {
+      IntervalUtils.periodToMonths(scalaValue)
+    }
+    override def toScala(catalystValue: Any): Period = {
+      if (catalystValue == null) null
+      else IntervalUtils.monthsToPeriod(catalystValue.asInstanceOf[Int])
+    }
+    override def toScalaImpl(row: InternalRow, column: Int): Period =
+      IntervalUtils.monthsToPeriod(row.getInt(column))
+  }
+
   /**
    * Creates a converter function that will convert Scala objects to the specified Catalyst type.
    * Typical use case would be converting a collection of rows that have the same schema. You will
@@ -479,6 +492,7 @@ object CatalystTypeConverters {
       (key: Any) => convertToCatalyst(key),
       (value: Any) => convertToCatalyst(value))
     case d: Duration => DurationConverter.toCatalyst(d)
+    case p: Period => PeriodConverter.toCatalyst(p)
     case other => other
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 03243b4610f6..eaa7c17bfd31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -152,6 +152,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }

+  def createDeserializerForPeriod(path: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      ObjectType(classOf[java.time.Period]),
+      "monthsToPeriod",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   /**
    * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
    * and lost the required data type, which may lead to runtime error if the real type doesn't
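Conversion sketch (not part of the patch): the round trip that `PeriodConverter` enables, using the same public helpers the test suite further down exercises.

    import java.time.Period
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.YearMonthIntervalType

    // External -> Catalyst: the Period collapses to its total month count.
    assert(CatalystTypeConverters.convertToCatalyst(Period.of(2, 3, 0)) == 27)

    // Catalyst -> external: the Int expands back to a normalized Period.
    val toScala = CatalystTypeConverters.createToScalaConverter(YearMonthIntervalType)
    assert(toScala(27) == Period.of(2, 3, 0))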
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 00b2d16a2ad6..fd74f60c0c47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -132,7 +132,8 @@ object InternalRow {
     case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
     case ByteType => (input, ordinal) => input.getByte(ordinal)
     case ShortType => (input, ordinal) => input.getShort(ordinal)
-    case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
+    case IntegerType | DateType | YearMonthIntervalType =>
+      (input, ordinal) => input.getInt(ordinal)
     case LongType | TimestampType | DayTimeIntervalType =>
       (input, ordinal) => input.getLong(ordinal)
     case FloatType => (input, ordinal) => input.getFloat(ordinal)
@@ -168,7 +169,8 @@ object InternalRow {
     case BooleanType => (input, v) => input.setBoolean(ordinal, v.asInstanceOf[Boolean])
     case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
     case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
-    case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
+    case IntegerType | DateType | YearMonthIntervalType =>
+      (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
     case LongType | TimestampType | DayTimeIntervalType =>
       (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
     case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 7f055a1e77bd..541b78336ba0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -119,6 +119,7 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)
+      case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType, true)

       case _ if typeToken.isArray =>
         val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
@@ -253,6 +254,9 @@ object JavaTypeInference {
       case c if c == classOf[java.time.Duration] =>
         createDeserializerForDuration(path)

+      case c if c == classOf[java.time.Period] =>
+        createDeserializerForPeriod(path)
+
       case c if c == classOf[java.lang.String] =>
         createDeserializerForString(path, returnNullable = true)

@@ -412,6 +416,8 @@ object JavaTypeInference {
       case c if c == classOf[java.time.Duration] =>
         createSerializerForJavaDuration(inputObject)

+      case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)
+
       case c if c == classOf[java.math.BigDecimal] =>
         createSerializerForJavaBigDecimal(inputObject)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bdb2a8ebff82..c258cdfa767a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -243,6 +243,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
         createDeserializerForDuration(path)

+      case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+        createDeserializerForPeriod(path)
+
       case t if isSubtype(t, localTypeOf[java.lang.String]) =>
         createDeserializerForString(path, returnNullable = false)

@@ -528,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
         createSerializerForJavaDuration(inputObject)

+      case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+        createSerializerForJavaPeriod(inputObject)
+
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         createSerializerForScalaBigDecimal(inputObject)

@@ -748,6 +754,8 @@ object ScalaReflection extends ScalaReflection {
       Schema(CalendarIntervalType, nullable = true)
     case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
       Schema(DayTimeIntervalType, nullable = true)
+    case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+      Schema(YearMonthIntervalType, nullable = true)
     case t if isSubtype(t, localTypeOf[BigDecimal]) =>
       Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
     case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
@@ -846,7 +854,8 @@ object ScalaReflection extends ScalaReflection {
     TimestampType -> classOf[TimestampType.InternalType],
     BinaryType -> classOf[BinaryType.InternalType],
     CalendarIntervalType -> classOf[CalendarInterval],
-    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
+    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType],
+    YearMonthIntervalType -> classOf[YearMonthIntervalType.InternalType]
   )

   val typeBoxedJavaMapping = Map[DataType, Class[_]](
@@ -859,7 +868,8 @@ object ScalaReflection extends ScalaReflection {
     DoubleType -> classOf[java.lang.Double],
     DateType -> classOf[java.lang.Integer],
     TimestampType -> classOf[java.lang.Long],
-    DayTimeIntervalType -> classOf[java.lang.Long]
+    DayTimeIntervalType -> classOf[java.lang.Long],
+    YearMonthIntervalType -> classOf[java.lang.Integer]
   )

   def dataTypeJavaClass(dt: DataType): Class[_] = {
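Inference sketch (not part of the patch): with both reflection paths updated, a `Period` field in a product type should now map to YearMonthIntervalType; the `Visit` case class is illustrative only.

    import java.time.Period
    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.types.YearMonthIntervalType

    case class Visit(id: Long, stay: Period)

    // ScalaReflection maps the Period field to a nullable YearMonthIntervalType column.
    val schema = Encoders.product[Visit].schema
    assert(schema("stay").dataType == YearMonthIntervalType)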
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index fcecfbe925c8..f80fab573c9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -113,6 +113,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }

+  def createSerializerForJavaPeriod(inputObject: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      YearMonthIntervalType,
+      "periodToMonths",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
     CheckOverflow(StaticInvoke(
       Decimal.getClass,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5d559731055f..626ece33f157 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -302,6 +302,11 @@ package object dsl {
       AttributeReference(s, DayTimeIntervalType, nullable = true)()
     }

+    /** Creates a new AttributeReference of the year-month interval type */
+    def yearMonthInterval: AttributeReference = {
+      AttributeReference(s, YearMonthIntervalType, nullable = true)()
+    }
+
     /** Creates a new AttributeReference of type binary */
     def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index ebda55b60e66..b67f70754e89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -54,6 +54,7 @@ import org.apache.spark.sql.types._
 *   TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
 *
 *   DayTimeIntervalType -> java.time.Duration
+ *   YearMonthIntervalType -> java.time.Period
 *
 *   BinaryType -> byte array
 *   ArrayType -> scala.collection.Seq or Array
@@ -112,6 +113,8 @@ object RowEncoder {

     case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)

+    case YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)
+
     case d: DecimalType =>
       CheckOverflow(StaticInvoke(
         Decimal.getClass,
@@ -231,6 +234,7 @@ object RowEncoder {
         ObjectType(classOf[java.sql.Date])
       }
     case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
+    case YearMonthIntervalType => ObjectType(classOf[java.time.Period])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
     case StringType => ObjectType(classOf[java.lang.String])
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -288,6 +292,8 @@ object RowEncoder {

     case DayTimeIntervalType => createDeserializerForDuration(input)

+    case YearMonthIntervalType => createDeserializerForPeriod(input)
+
     case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)

     case StringType => createDeserializerForString(input, returnNullable = false)
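Row-encoder sketch (not part of the patch): an external Row can now carry a `java.time.Period` under a YearMonthIntervalType schema, e.g. through the public `createDataFrame` API (assumes a SparkSession `spark`).

    import java.time.Period
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StructType, YearMonthIntervalType}

    val schema = new StructType().add("p", YearMonthIntervalType)
    val rows = java.util.Arrays.asList(Row(Period.ofYears(2)))
    // Serialized as the int 24; read back as a normalized Period.
    val df = spark.createDataFrame(rows, schema)
    assert(df.head().getAs[Period]("p") == Period.ofYears(2))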
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 00ac3d64aaeb..908b73abadfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -157,7 +157,7 @@ object InterpretedUnsafeProjection {
       case ShortType =>
         (v, i) => writer.write(i, v.getShort(i))

-      case IntegerType | DateType =>
+      case IntegerType | DateType | YearMonthIntervalType =>
         (v, i) => writer.write(i, v.getInt(i))

       case LongType | TimestampType | DayTimeIntervalType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
index fd22978544a7..0f2619246899 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
@@ -193,8 +193,8 @@ final class MutableAny extends MutableValue {
 final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGenericInternalRow {

   private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
-    // We use INT for DATE internally
-    case IntegerType | DateType => new MutableInt
+    // We use INT for DATE and YearMonthIntervalType internally
+    case IntegerType | DateType | YearMonthIntervalType => new MutableInt
     // We use Long for Timestamp and DayTimeInterval internally
     case LongType | TimestampType | DayTimeIntervalType => new MutableLong
     case FloatType => new MutableFloat
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 67c4adfb3887..45ee1934777f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1812,7 +1812,7 @@ object CodeGenerator extends Logging {
     case BooleanType => JAVA_BOOLEAN
     case ByteType => JAVA_BYTE
     case ShortType => JAVA_SHORT
-    case IntegerType | DateType => JAVA_INT
+    case IntegerType | DateType | YearMonthIntervalType => JAVA_INT
     case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
     case FloatType => JAVA_FLOAT
     case DoubleType => JAVA_DOUBLE
@@ -1833,7 +1833,7 @@ object CodeGenerator extends Logging {
     case BooleanType => java.lang.Boolean.TYPE
     case ByteType => java.lang.Byte.TYPE
     case ShortType => java.lang.Short.TYPE
-    case IntegerType | DateType => java.lang.Integer.TYPE
+    case IntegerType | DateType | YearMonthIntervalType => java.lang.Integer.TYPE
     case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
     case FloatType => java.lang.Float.TYPE
     case DoubleType => java.lang.Double.TYPE
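Storage sketch (not part of the patch): physically the new type is just another 4-byte int column, as the projection and codegen changes above suggest; the catalyst classes used here are internal, so this is illustrative only.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types.{DataType, YearMonthIntervalType}

    val project = UnsafeProjection.create(Array[DataType](YearMonthIntervalType))
    // 14 months (1 year 2 months) occupies a single int slot in the unsafe row.
    val row = project(InternalRow(14))
    assert(row.getInt(0) == 14)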
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 203e98cfd737..2ea73e83c743 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 import java.util
 import java.util.Objects
 import javax.xml.bind.DatatypeConverter
@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
-import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
+import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, periodToMonths}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -78,6 +78,7 @@ object Literal {
     case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
     case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
+    case p: Period => Literal(periodToMonths(p), YearMonthIntervalType)
     case a: Array[Byte] => Literal(a, BinaryType)
     case a: collection.mutable.WrappedArray[_] => apply(a.array)
     case a: Array[_] =>
@@ -114,6 +115,7 @@ object Literal {
     case _ if clz == classOf[Instant] => TimestampType
     case _ if clz == classOf[Timestamp] => TimestampType
     case _ if clz == classOf[Duration] => DayTimeIntervalType
+    case _ if clz == classOf[Period] => YearMonthIntervalType
     case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
     case _ if clz == classOf[Array[Byte]] => BinaryType
     case _ if clz == classOf[Array[Char]] => StringType
@@ -171,6 +173,7 @@ object Literal {
     case DateType => create(0, DateType)
     case TimestampType => create(0L, TimestampType)
     case DayTimeIntervalType => create(0L, DayTimeIntervalType)
+    case YearMonthIntervalType => create(0, YearMonthIntervalType)
     case StringType => Literal("")
     case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
     case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
@@ -189,7 +192,7 @@ object Literal {
     case BooleanType => v.isInstanceOf[Boolean]
     case ByteType => v.isInstanceOf[Byte]
     case ShortType => v.isInstanceOf[Short]
-    case IntegerType | DateType => v.isInstanceOf[Int]
+    case IntegerType | DateType | YearMonthIntervalType => v.isInstanceOf[Int]
     case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long]
     case FloatType => v.isInstanceOf[Float]
     case DoubleType => v.isInstanceOf[Double]
@@ -366,7 +369,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
       ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
     }
     dataType match {
-      case BooleanType | IntegerType | DateType =>
+      case BooleanType | IntegerType | DateType | YearMonthIntervalType =>
         toExprCode(value.toString)
       case FloatType =>
         value.asInstanceOf[Float] match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 6be4e9f04306..06ab4b603f02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.util

-import java.time.Duration
+import java.time.{Duration, Period}
 import java.time.temporal.ChronoUnit
 import java.util.concurrent.TimeUnit

@@ -791,4 +791,35 @@ object IntervalUtils {
   * @return A [[Duration]], not null
   */
  def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS)
+
+  /**
+   * Gets the total number of months in this period.
+   * <p>
+   * This returns the total number of months in the period by multiplying the
+   * number of years by 12 and adding the number of months.
+   *
+   * @return The total number of months in the period, may be negative
+   * @throws ArithmeticException If numeric overflow occurs
+   */
+  def periodToMonths(period: Period): Int = {
+    val monthsInYears = Math.multiplyExact(period.getYears, MONTHS_PER_YEAR)
+    Math.addExact(monthsInYears, period.getMonths)
+  }
+
+  /**
+   * Obtains a [[Period]] representing a number of months. The days unit will be zero, and the years
+   * and months units will be normalized.
+   *
+   * <p>
+   * The months unit is adjusted to have an absolute value < 12, with the years unit being adjusted
+   * to compensate. For example, the method returns "2 years and 3 months" for an input of 27 months.
+   * <p>
+   * The sign of the years and months units will be the same after normalization.
+   * For example, -13 months will be converted to "-1 year and -1 month".
+   *
+   * @param months The number of months, positive or negative
+   * @return The period of months, not null
+   */
+  def monthsToPeriod(months: Int): Period = Period.ofMonths(months).normalized()
 }
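Worked example (not part of the patch): the two helpers are inverses of each other up to normalization, and `periodToMonths` deliberately ignores the days unit.

    import java.time.Period
    import org.apache.spark.sql.catalyst.util.IntervalUtils.{monthsToPeriod, periodToMonths}

    // 2 years * 12 + 3 months = 27; the 25 days are dropped.
    assert(periodToMonths(Period.of(2, 3, 25)) == 27)

    // Going back normalizes: 27 -> "2 years 3 months",
    // -13 -> "-1 year -1 month" (years and months keep the same sign).
    assert(monthsToPeriod(27) == Period.of(2, 3, 0))
    assert(monthsToPeriod(-13) == Period.of(-1, -1, 0))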
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 05273270de30..5c5742c812e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -171,7 +171,7 @@ object DataType {
   private val otherTypes = {
     Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
       DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType,
-      DayTimeIntervalType)
+      DayTimeIntervalType, YearMonthIntervalType)
       .map(t => t.typeName -> t).toMap
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 6b66af842fcc..0dbae707a4a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst

-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
@@ -258,4 +258,40 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
       }
     }
   }
+
+  test("SPARK-34615: converting java.time.Period to YearMonthIntervalType") {
+    Seq(
+      Period.ZERO,
+      Period.ofMonths(1),
+      Period.ofMonths(-1),
+      Period.ofMonths(Int.MaxValue).normalized(),
+      Period.ofMonths(Int.MinValue).normalized(),
+      Period.ofYears(106751991),
+      Period.ofYears(-106751991)).foreach { input =>
+      val result = CatalystTypeConverters.convertToCatalyst(input)
+      val expected = IntervalUtils.periodToMonths(input)
+      assert(result === expected)
+    }
+
+    val errMsg = intercept[ArithmeticException] {
+      IntervalUtils.periodToMonths(Period.of(Int.MaxValue, Int.MaxValue, Int.MaxValue))
+    }.getMessage
+    assert(errMsg.contains("integer overflow"))
+  }
+
+  test("SPARK-34615: converting YearMonthIntervalType to java.time.Period") {
+    Seq(
+      0,
+      1,
+      999999,
+      1000000,
+      Int.MaxValue).foreach { input =>
+      Seq(1, -1).foreach { sign =>
+        val months = sign * input
+        val period = IntervalUtils.monthsToPeriod(months)
+        assert(
+          CatalystTypeConverters.createToScalaConverter(YearMonthIntervalType)(months) === period)
+      }
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 9ab336165d16..6c22c14870d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -352,6 +352,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     assert(readback.get(0).equals(duration))
   }

+  test("SPARK-34615: encoding/decoding YearMonthIntervalType to/from java.time.Period") {
+    val schema = new StructType().add("p", YearMonthIntervalType)
+    val encoder = RowEncoder(schema).resolveAndBind()
+    val period = java.time.Period.ofMonths(1)
+    val row = toRow(encoder, Row(period))
+    assert(row.getInt(0) === IntervalUtils.periodToMonths(period))
+    val readback = fromRow(encoder, row)
+    assert(readback.get(0).equals(period))
+  }
+
   for {
     elementType <- Seq(IntegerType, StringType)
     containsNull <- Seq(true, false)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 8cba46cabd91..f8766f3fd27a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions

 import java.nio.charset.StandardCharsets
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, ZoneOffset}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset}
 import java.util.TimeZone

 import scala.reflect.runtime.universe.TypeTag
@@ -367,4 +367,22 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     val duration1 = Duration.ofHours(-1024)
     checkEvaluation(Literal(Array(duration0, duration1)), Array(duration0, duration1))
   }
+
+  test("SPARK-34615: construct literals from java.time.Period") {
+    Seq(
+      Period.ofYears(0),
+      Period.of(-1, 11, 0),
+      Period.of(1, -11, 0),
+      Period.ofMonths(Int.MaxValue),
+      Period.ofMonths(Int.MinValue)).foreach { period =>
+      checkEvaluation(Literal(period), period)
+    }
+  }
+
+  test("SPARK-34615: construct literals from arrays of java.time.Period") {
+    val period0 = Period.ofYears(123).withMonths(456)
+    checkEvaluation(Literal(Array(period0)), Array(period0))
+    val period1 = Period.ofMonths(-1024)
+    checkEvaluation(Literal(Array(period0, period1)), Array(period0, period1))
+  }
 }
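Literal sketch (not part of the patch): `Literal.apply` folds a Period down to its month count, so the two constructions below denote the same literal.

    import java.time.Period
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.types.YearMonthIntervalType

    // Literal(p) is stored as Literal(periodToMonths(p), YearMonthIntervalType).
    assert(Literal(Period.of(2, 3, 0)) == Literal(27, YearMonthIntervalType))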
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
index 51fd291bba0d..df2656fb7aa8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.util

-import java.time.Duration
+import java.time.{Duration, Period}
 import java.util.concurrent.TimeUnit

 import org.apache.spark.SparkFunSuite
@@ -401,4 +401,28 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
     }.getMessage
     assert(errMsg.contains("long overflow"))
   }
+
+  test("SPARK-34615: period to months") {
+    assert(periodToMonths(Period.ZERO) === 0)
+    assert(periodToMonths(Period.of(0, -1, 0)) === -1)
+    assert(periodToMonths(Period.of(-1, 0, 10)) === -12) // ignore days
+    assert(periodToMonths(Period.of(178956970, 7, 0)) === Int.MaxValue)
+    assert(periodToMonths(Period.of(-178956970, -8, 123)) === Int.MinValue)
+    assert(periodToMonths(Period.of(0, Int.MaxValue, Int.MaxValue)) === Int.MaxValue)
+
+    val errMsg = intercept[ArithmeticException] {
+      periodToMonths(Period.of(Int.MaxValue, 0, 0))
+    }.getMessage
+    assert(errMsg.contains("integer overflow"))
+  }
+
+  test("SPARK-34615: months to period") {
+    assert(monthsToPeriod(0) === Period.ZERO)
+    assert(monthsToPeriod(-11) === Period.of(0, -11, 0))
+    assert(monthsToPeriod(11) === Period.of(0, 11, 0))
+    assert(monthsToPeriod(27) === Period.of(2, 3, 0))
+    assert(monthsToPeriod(-13) === Period.of(-1, -1, 0))
+    assert(monthsToPeriod(Int.MaxValue) === Period.ofYears(178956970).withMonths(7))
+    assert(monthsToPeriod(Int.MinValue) === Period.ofYears(-178956970).withMonths(-8))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bcc4871d3e64..90188cadfd3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -91,6 +91,9 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
   /** @since 3.2.0 */
   implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION

+  /** @since 3.2.0 */
+  implicit def newPeriodEncoder: Encoder[java.time.Period] = Encoders.PERIOD
+
   /** @since 3.2.0 */
   implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] =
     ExpressionEncoder()
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 85ad80e21329..93566e030e21 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -24,6 +24,7 @@
 import java.time.Duration;
 import java.time.Instant;
 import java.time.LocalDate;
+import java.time.Period;
 import java.util.*;
 import javax.annotation.Nonnull;

@@ -421,6 +422,14 @@ public void testDurationEncoder() {
     Assert.assertEquals(data, ds.collectAsList());
   }

+  @Test
+  public void testPeriodEncoder() {
+    Encoder<Period> encoder = Encoders.PERIOD();
+    List<Period> data = Arrays.asList(Period.ofYears(10));
+    Dataset<Period> ds = spark.createDataset(data, encoder);
+    Assert.assertEquals(data, ds.collectAsList());
+  }
+
   public static class KryoSerializable {
     String value;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 843696e04847..a98bb060636e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2012,6 +2012,11 @@ class DatasetSuite extends QueryTest
     val duration = java.time.Duration.ofMinutes(10)
     assert(spark.range(1).map { _ => duration }.head === duration)
   }
+
+  test("SPARK-34615: implicit encoder for java.time.Period") {
+    val period = java.time.Period.ofYears(9999).withMonths(11)
+    assert(spark.range(1).map { _ => period }.head === period)
+  }
 }

 case class Bar(a: Int)
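End-to-end sketch (not part of the patch): with the implicit encoder in scope, Period works like any other Dataset element type (assumes a SparkSession `spark`).

    import java.time.Period
    import spark.implicits._

    val ds = Seq(Period.ofYears(1), Period.of(0, 5, 0)).toDS()
    // Values are carried as month counts and come back normalized.
    assert(ds.map(_.plusMonths(1)).collect().toSet ==
      Set(Period.of(1, 1, 0), Period.of(0, 6, 0)))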