diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
index d1bb719aca8f..90f340b51c3e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
@@ -86,6 +86,9 @@ public static Object read(
     if (dataType instanceof DayTimeIntervalType) {
       return obj.getLong(ordinal);
     }
+    if (dataType instanceof YearMonthIntervalType) {
+      return obj.getInt(ordinal);
+    }
     throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index 5e72b19ca5da..d50829578e6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -143,6 +143,14 @@ object Encoders {
    */
  def DURATION: Encoder[java.time.Duration] = ExpressionEncoder()
 
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Period` class
+   * to the internal representation of nullable Catalyst's YearMonthIntervalType.
+   *
+   * @since 3.2.0
+   */
+  def PERIOD: Encoder[java.time.Period] = ExpressionEncoder()
+
   /**
    * Creates an encoder for Java Bean of type T.
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 8201fd7d8fb5..b55d1b725f56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
 
@@ -75,6 +75,7 @@ object CatalystTypeConverters {
       case FloatType => FloatConverter
       case DoubleType => DoubleConverter
       case DayTimeIntervalType => DurationConverter
+      case YearMonthIntervalType => PeriodConverter
       case dataType: DataType => IdentityConverter(dataType)
     }
     converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
@@ -413,6 +414,18 @@ object CatalystTypeConverters {
       IntervalUtils.microsToDuration(row.getLong(column))
   }
 
+  private object PeriodConverter extends CatalystTypeConverter[Period, Period, Any] {
+    override def toCatalystImpl(scalaValue: Period): Int = {
+      IntervalUtils.periodToMonths(scalaValue)
+    }
+    override def toScala(catalystValue: Any): Period = {
+      if (catalystValue == null) null
+      else IntervalUtils.monthsToPeriod(catalystValue.asInstanceOf[Int])
+    }
+    override def toScalaImpl(row: InternalRow, column: Int): Period =
+      IntervalUtils.monthsToPeriod(row.getInt(column))
+  }
+
   /**
    * Creates a converter function that will convert Scala objects to the specified Catalyst type.
    * Typical use case would be converting a collection of rows that have the same schema.
@@ -479,6 +492,7 @@ object CatalystTypeConverters {
         (key: Any) => convertToCatalyst(key),
         (value: Any) => convertToCatalyst(value))
     case d: Duration => DurationConverter.toCatalyst(d)
+    case p: Period => PeriodConverter.toCatalyst(p)
     case other => other
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 03243b4610f6..eaa7c17bfd31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -152,6 +152,15 @@ object DeserializerBuildHelper {
       returnNullable = false)
   }
 
+  def createDeserializerForPeriod(path: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      ObjectType(classOf[java.time.Period]),
+      "monthsToPeriod",
+      path :: Nil,
+      returnNullable = false)
+  }
+
   /**
    * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
    * and lost the required data type, which may lead to runtime error if the real type doesn't
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 00b2d16a2ad6..fd74f60c0c47 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -132,7 +132,8 @@ object InternalRow {
     case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
     case ByteType => (input, ordinal) => input.getByte(ordinal)
     case ShortType => (input, ordinal) => input.getShort(ordinal)
-    case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
+    case IntegerType | DateType | YearMonthIntervalType =>
+      (input, ordinal) => input.getInt(ordinal)
     case LongType | TimestampType | DayTimeIntervalType =>
       (input, ordinal) => input.getLong(ordinal)
     case FloatType => (input, ordinal) => input.getFloat(ordinal)
@@ -168,7 +169,8 @@ object InternalRow {
     case BooleanType => (input, v) => input.setBoolean(ordinal, v.asInstanceOf[Boolean])
     case ByteType => (input, v) => input.setByte(ordinal, v.asInstanceOf[Byte])
     case ShortType => (input, v) => input.setShort(ordinal, v.asInstanceOf[Short])
-    case IntegerType | DateType => (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
+    case IntegerType | DateType | YearMonthIntervalType =>
+      (input, v) => input.setInt(ordinal, v.asInstanceOf[Int])
     case LongType | TimestampType | DayTimeIntervalType =>
       (input, v) => input.setLong(ordinal, v.asInstanceOf[Long])
     case FloatType => (input, v) => input.setFloat(ordinal, v.asInstanceOf[Float])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 7f055a1e77bd..541b78336ba0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -119,6 +119,7 @@ object JavaTypeInference {
       case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.time.Duration] => (DayTimeIntervalType, true)
+      case c: Class[_] if c == classOf[java.time.Period] => (YearMonthIntervalType, true)
 
       case _ if typeToken.isArray =>
         val (dataType, nullable) = inferDataType(typeToken.getComponentType, seenTypeSet)
@@ -253,6 +254,9 @@ object JavaTypeInference {
       case c if c == classOf[java.time.Duration] =>
         createDeserializerForDuration(path)
 
+      case c if c == classOf[java.time.Period] =>
+        createDeserializerForPeriod(path)
+
       case c if c == classOf[java.lang.String] =>
         createDeserializerForString(path, returnNullable = true)
 
@@ -412,6 +416,8 @@ object JavaTypeInference {
       case c if c == classOf[java.time.Duration] =>
         createSerializerForJavaDuration(inputObject)
 
+      case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)
+
       case c if c == classOf[java.math.BigDecimal] =>
         createSerializerForJavaBigDecimal(inputObject)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bdb2a8ebff82..c258cdfa767a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -243,6 +243,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
         createDeserializerForDuration(path)
 
+      case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+        createDeserializerForPeriod(path)
+
       case t if isSubtype(t, localTypeOf[java.lang.String]) =>
         createDeserializerForString(path, returnNullable = false)
 
@@ -528,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
       case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
         createSerializerForJavaDuration(inputObject)
 
+      case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+        createSerializerForJavaPeriod(inputObject)
+
       case t if isSubtype(t, localTypeOf[BigDecimal]) =>
         createSerializerForScalaBigDecimal(inputObject)
 
@@ -748,6 +754,8 @@ object ScalaReflection extends ScalaReflection {
       Schema(CalendarIntervalType, nullable = true)
     case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
       Schema(DayTimeIntervalType, nullable = true)
+    case t if isSubtype(t, localTypeOf[java.time.Period]) =>
+      Schema(YearMonthIntervalType, nullable = true)
     case t if isSubtype(t, localTypeOf[BigDecimal]) =>
       Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
     case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
@@ -846,7 +854,8 @@ object ScalaReflection extends ScalaReflection {
     TimestampType -> classOf[TimestampType.InternalType],
     BinaryType -> classOf[BinaryType.InternalType],
     CalendarIntervalType -> classOf[CalendarInterval],
-    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType]
+    DayTimeIntervalType -> classOf[DayTimeIntervalType.InternalType],
+    YearMonthIntervalType -> classOf[YearMonthIntervalType.InternalType]
   )
 
   val typeBoxedJavaMapping = Map[DataType, Class[_]](
@@ -859,7 +868,8 @@ object ScalaReflection extends ScalaReflection {
     DoubleType -> classOf[java.lang.Double],
     DateType -> classOf[java.lang.Integer],
     TimestampType -> classOf[java.lang.Long],
-    DayTimeIntervalType -> classOf[java.lang.Long]
+    DayTimeIntervalType -> classOf[java.lang.Long],
+    YearMonthIntervalType -> classOf[java.lang.Integer]
   )
 
   def dataTypeJavaClass(dt: DataType): Class[_] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index fcecfbe925c8..f80fab573c9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -113,6 +113,15 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
+  def createSerializerForJavaPeriod(inputObject: Expression): Expression = {
+    StaticInvoke(
+      IntervalUtils.getClass,
+      YearMonthIntervalType,
+      "periodToMonths",
+      inputObject :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
     CheckOverflow(StaticInvoke(
       Decimal.getClass,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5d559731055f..626ece33f157 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -302,6 +302,11 @@ package object dsl {
       AttributeReference(s, DayTimeIntervalType, nullable = true)()
     }
 
+    /** Creates a new AttributeReference of the year-month interval type */
+    def yearMonthInterval: AttributeReference = {
+      AttributeReference(s, YearMonthIntervalType, nullable = true)()
+    }
+
     /** Creates a new AttributeReference of type binary */
     def binary: AttributeReference = AttributeReference(s, BinaryType, nullable = true)()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index ebda55b60e66..b67f70754e89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -54,6 +54,7 @@ import org.apache.spark.sql.types._
  *   TimestampType -> java.time.Instant if spark.sql.datetime.java8API.enabled is true
  *
  *   DayTimeIntervalType -> java.time.Duration
+ *   YearMonthIntervalType -> java.time.Period
  *
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
@@ -112,6 +113,8 @@ object RowEncoder {
 
     case DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
 
+    case YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)
+
    case d: DecimalType =>
      CheckOverflow(StaticInvoke(
        Decimal.getClass,
@@ -231,6 +234,7 @@ object RowEncoder {
         ObjectType(classOf[java.sql.Date])
       }
     case DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
+    case YearMonthIntervalType => ObjectType(classOf[java.time.Period])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
     case StringType => ObjectType(classOf[java.lang.String])
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
@@ -288,6 +292,8 @@ object RowEncoder {
 
     case DayTimeIntervalType => createDeserializerForDuration(input)
 
+    case YearMonthIntervalType => createDeserializerForPeriod(input)
+
     case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)
 
     case StringType => createDeserializerForString(input, returnNullable = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 00ac3d64aaeb..908b73abadfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -157,7 +157,7 @@ object InterpretedUnsafeProjection {
       case ShortType =>
         (v, i) => writer.write(i, v.getShort(i))
 
-      case IntegerType | DateType =>
+      case IntegerType | DateType | YearMonthIntervalType =>
         (v, i) => writer.write(i, v.getInt(i))
 
       case LongType | TimestampType | DayTimeIntervalType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
index fd22978544a7..0f2619246899 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala
@@ -193,8 +193,8 @@ final class MutableAny extends MutableValue {
 final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGenericInternalRow {
 
   private[this] def dataTypeToMutableValue(dataType: DataType): MutableValue = dataType match {
-    // We use INT for DATE internally
-    case IntegerType | DateType => new MutableInt
+    // We use INT for DATE and YearMonthIntervalType internally
+    case IntegerType | DateType | YearMonthIntervalType => new MutableInt
     // We use Long for Timestamp and DayTimeInterval internally
     case LongType | TimestampType | DayTimeIntervalType => new MutableLong
     case FloatType => new MutableFloat
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 67c4adfb3887..45ee1934777f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1812,7 +1812,7 @@ object CodeGenerator extends Logging {
     case BooleanType => JAVA_BOOLEAN
     case ByteType => JAVA_BYTE
     case ShortType => JAVA_SHORT
-    case IntegerType | DateType => JAVA_INT
+    case IntegerType | DateType | YearMonthIntervalType => JAVA_INT
     case LongType | TimestampType | DayTimeIntervalType => JAVA_LONG
     case FloatType => JAVA_FLOAT
     case DoubleType => JAVA_DOUBLE
@@ -1833,7 +1833,7 @@ object CodeGenerator extends Logging {
     case BooleanType => java.lang.Boolean.TYPE
     case ByteType => java.lang.Byte.TYPE
     case ShortType => java.lang.Short.TYPE
-    case IntegerType | DateType => java.lang.Integer.TYPE
+    case IntegerType | DateType | YearMonthIntervalType => java.lang.Integer.TYPE
     case LongType | TimestampType | DayTimeIntervalType => java.lang.Long.TYPE
     case FloatType => java.lang.Float.TYPE
     case DoubleType => java.lang.Double.TYPE
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 203e98cfd737..2ea73e83c743 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -28,7 +28,7 @@ import java.lang.{Short => JavaShort}
 import java.math.{BigDecimal => JavaBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
 import java.util
 import java.util.Objects
 import javax.xml.bind.DatatypeConverter
@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Scala
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.instantToMicros
-import org.apache.spark.sql.catalyst.util.IntervalUtils.durationToMicros
+import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, periodToMonths}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -78,6 +78,7 @@ object Literal {
     case ld: LocalDate => Literal(ld.toEpochDay.toInt, DateType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
     case d: Duration => Literal(durationToMicros(d), DayTimeIntervalType)
+    case p: Period => Literal(periodToMonths(p), YearMonthIntervalType)
     case a: Array[Byte] => Literal(a, BinaryType)
     case a: collection.mutable.WrappedArray[_] => apply(a.array)
     case a: Array[_] =>
@@ -114,6 +115,7 @@ object Literal {
     case _ if clz == classOf[Instant] => TimestampType
     case _ if clz == classOf[Timestamp] => TimestampType
     case _ if clz == classOf[Duration] => DayTimeIntervalType
+    case _ if clz == classOf[Period] => YearMonthIntervalType
     case _ if clz == classOf[JavaBigDecimal] => DecimalType.SYSTEM_DEFAULT
     case _ if clz == classOf[Array[Byte]] => BinaryType
     case _ if clz == classOf[Array[Char]] => StringType
@@ -171,6 +173,7 @@ object Literal {
     case DateType => create(0, DateType)
     case TimestampType => create(0L, TimestampType)
     case DayTimeIntervalType => create(0L, DayTimeIntervalType)
+    case YearMonthIntervalType => create(0, YearMonthIntervalType)
     case StringType => Literal("")
     case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
     case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
@@ -189,7 +192,7 @@ object Literal {
       case BooleanType => v.isInstanceOf[Boolean]
       case ByteType => v.isInstanceOf[Byte]
       case ShortType => v.isInstanceOf[Short]
-      case IntegerType | DateType => v.isInstanceOf[Int]
+      case IntegerType | DateType | YearMonthIntervalType => v.isInstanceOf[Int]
       case LongType | TimestampType | DayTimeIntervalType => v.isInstanceOf[Long]
       case FloatType => v.isInstanceOf[Float]
       case DoubleType => v.isInstanceOf[Double]
@@ -366,7 +369,7 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
       ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
     }
     dataType match {
-      case BooleanType | IntegerType | DateType =>
+      case BooleanType | IntegerType | DateType | YearMonthIntervalType =>
         toExprCode(value.toString)
       case FloatType =>
         value.asInstanceOf[Float] match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 6be4e9f04306..06ab4b603f02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import java.time.Duration
+import java.time.{Duration, Period}
 import java.time.temporal.ChronoUnit
 import java.util.concurrent.TimeUnit
 
@@ -791,4 +791,35 @@ object IntervalUtils {
    * @return A [[Duration]], not null
    */
   def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS)
+
+  /**
+   * Gets the total number of months in this period.
+   *
+   * This returns the total number of months in the period by multiplying the
+   * number of years by 12 and adding the number of months.
+   *
+   * @return The total number of months in the period, may be negative
+   * @throws ArithmeticException If numeric overflow occurs
+   */
+  def periodToMonths(period: Period): Int = {
+    val monthsInYears = Math.multiplyExact(period.getYears, MONTHS_PER_YEAR)
+    Math.addExact(monthsInYears, period.getMonths)
+  }
+
+  /**
+   * Obtains a [[Period]] representing a number of months. The days unit will be zero, and the
+   * years and months units will be normalized.
+   *
+   * The months unit is adjusted to have an absolute value < 12, with the years unit being
+   * adjusted to compensate. For example, the method returns "2 years and 3 months" for the
+   * 27 input months.
+   *
+   * The sign of the years and months units will be the same after normalization.
+   * For example, -13 months will be converted to "-1 year and -1 month".
+   *
+   * @param months The number of months, positive or negative
+   * @return The period of months, not null
+   */
+  def monthsToPeriod(months: Int): Period = Period.ofMonths(months).normalized()
 }
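For readers skimming the patch, the two helpers added above define the whole mapping between `java.time.Period` and the `Int` month count that backs `YearMonthIntervalType`. A quick sketch of their semantics (values taken from the test suites further down; only `spark-catalyst` on the classpath is assumed):

```scala
import java.time.Period
import org.apache.spark.sql.catalyst.util.IntervalUtils.{monthsToPeriod, periodToMonths}

// Years are folded into months with overflow checks; the days unit is ignored.
assert(periodToMonths(Period.of(2, 3, 25)) == 27)
// Normalization keeps |months| < 12, with years and months sharing the same sign.
assert(monthsToPeriod(27) == Period.of(2, 3, 0))
assert(monthsToPeriod(-13) == Period.of(-1, -1, 0))
```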
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 05273270de30..5c5742c812e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -171,7 +171,7 @@ object DataType {
private val otherTypes = {
Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType,
DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType,
- DayTimeIntervalType)
+ DayTimeIntervalType, YearMonthIntervalType)
.map(t => t.typeName -> t).toMap
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
index 6b66af842fcc..0dbae707a4a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst
-import java.time.{Duration, Instant, LocalDate}
+import java.time.{Duration, Instant, LocalDate, Period}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
@@ -258,4 +258,40 @@ class CatalystTypeConvertersSuite extends SparkFunSuite with SQLHelper {
}
}
}
+
+ test("SPARK-34615: converting java.time.Period to YearMonthIntervalType") {
+ Seq(
+ Period.ZERO,
+ Period.ofMonths(1),
+ Period.ofMonths(-1),
+ Period.ofMonths(Int.MaxValue).normalized(),
+ Period.ofMonths(Int.MinValue).normalized(),
+ Period.ofYears(106751991),
+ Period.ofYears(-106751991)).foreach { input =>
+ val result = CatalystTypeConverters.convertToCatalyst(input)
+ val expected = IntervalUtils.periodToMonths(input)
+ assert(result === expected)
+ }
+
+ val errMsg = intercept[ArithmeticException] {
+ IntervalUtils.periodToMonths(Period.of(Int.MaxValue, Int.MaxValue, Int.MaxValue))
+ }.getMessage
+ assert(errMsg.contains("integer overflow"))
+ }
+
+ test("SPARK-34615: converting YearMonthIntervalType to java.time.Period") {
+ Seq(
+ 0,
+ 1,
+ 999999,
+ 1000000,
+ Int.MaxValue).foreach { input =>
+ Seq(1, -1).foreach { sign =>
+ val months = sign * input
+ val period = IntervalUtils.monthsToPeriod(months)
+ assert(
+ CatalystTypeConverters.createToScalaConverter(YearMonthIntervalType)(months) === period)
+ }
+ }
+ }
}
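To make the converter's contract concrete, here is a minimal round trip through `CatalystTypeConverters`, mirroring what the suite above exercises (a plain JVM process with `spark-catalyst` on the classpath is assumed):

```scala
import java.time.Period
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.YearMonthIntervalType

// External Period -> Catalyst representation: the total month count as an Int.
val months = CatalystTypeConverters.convertToCatalyst(Period.ofYears(1).plusMonths(2))
assert(months == 14)
// Catalyst Int -> external Period, normalized on the way out.
val period = CatalystTypeConverters.createToScalaConverter(YearMonthIntervalType)(14)
assert(period == Period.of(1, 2, 0))
```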
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 9ab336165d16..6c22c14870d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -352,6 +352,16 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
assert(readback.get(0).equals(duration))
}
+ test("SPARK-34615: encoding/decoding YearMonthIntervalType to/from java.time.Period") {
+ val schema = new StructType().add("p", YearMonthIntervalType)
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val period = java.time.Period.ofMonths(1)
+ val row = toRow(encoder, Row(period))
+ assert(row.getInt(0) === IntervalUtils.periodToMonths(period))
+ val readback = fromRow(encoder, row)
+ assert(readback.get(0).equals(period))
+ }
+
for {
elementType <- Seq(IntegerType, StringType)
containsNull <- Seq(true, false)
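As a sketch of what the new `RowEncoder` branch does at runtime: the `toRow`/`fromRow` helpers used in the test above are test-only, so this version goes through the public `createSerializer`/`createDeserializer` API instead (assuming a Spark 3.2-era catalyst on the classpath):

```scala
import java.time.Period
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{StructType, YearMonthIntervalType}

val encoder = RowEncoder(new StructType().add("p", YearMonthIntervalType)).resolveAndBind()
// The serializer stores the Period as its total month count in an Int column ...
val internalRow = encoder.createSerializer()(Row(Period.ofMonths(14)))
assert(internalRow.getInt(0) == 14)
// ... and the deserializer reads it back as a normalized Period.
assert(encoder.createDeserializer()(internalRow).getAs[Period](0) == Period.of(1, 2, 0))
```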
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 8cba46cabd91..f8766f3fd27a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, ZoneOffset}
+import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffset}
import java.util.TimeZone
import scala.reflect.runtime.universe.TypeTag
@@ -367,4 +367,22 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
val duration1 = Duration.ofHours(-1024)
checkEvaluation(Literal(Array(duration0, duration1)), Array(duration0, duration1))
}
+
+ test("SPARK-34615: construct literals from java.time.Period") {
+ Seq(
+ Period.ofYears(0),
+ Period.of(-1, 11, 0),
+ Period.of(1, -11, 0),
+ Period.ofMonths(Int.MaxValue),
+ Period.ofMonths(Int.MinValue)).foreach { period =>
+ checkEvaluation(Literal(period), period)
+ }
+ }
+
+ test("SPARK-34615: construct literals from arrays of java.time.Period") {
+ val period0 = Period.ofYears(123).withMonths(456)
+ checkEvaluation(Literal(Array(period0)), Array(period0))
+ val period1 = Period.ofMonths(-1024)
+ checkEvaluation(Literal(Array(period0, period1)), Array(period0, period1))
+ }
}
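A short illustration of the new `Literal` behavior, grounded in the cases added above (again, only `spark-catalyst` is assumed):

```scala
import java.time.Period
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.YearMonthIntervalType

val lit = Literal(Period.of(2, 3, 0))
assert(lit.dataType == YearMonthIntervalType)
assert(lit.value == 27) // stored as the total month count
// The default (zero) literal for the new type:
assert(Literal.default(YearMonthIntervalType).value == 0)
```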
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
index 51fd291bba0d..df2656fb7aa8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import java.time.Duration
+import java.time.{Duration, Period}
import java.util.concurrent.TimeUnit
import org.apache.spark.SparkFunSuite
@@ -401,4 +401,28 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
}.getMessage
assert(errMsg.contains("long overflow"))
}
+
+ test("SPARK-34615: period to months") {
+ assert(periodToMonths(Period.ZERO) === 0)
+ assert(periodToMonths(Period.of(0, -1, 0)) === -1)
+ assert(periodToMonths(Period.of(-1, 0, 10)) === -12) // ignore days
+ assert(periodToMonths(Period.of(178956970, 7, 0)) === Int.MaxValue)
+ assert(periodToMonths(Period.of(-178956970, -8, 123)) === Int.MinValue)
+ assert(periodToMonths(Period.of(0, Int.MaxValue, Int.MaxValue)) === Int.MaxValue)
+
+ val errMsg = intercept[ArithmeticException] {
+ periodToMonths(Period.of(Int.MaxValue, 0, 0))
+ }.getMessage
+ assert(errMsg.contains("integer overflow"))
+ }
+
+ test("SPARK-34615: months to period") {
+ assert(monthsToPeriod(0) === Period.ZERO)
+ assert(monthsToPeriod(-11) === Period.of(0, -11, 0))
+ assert(monthsToPeriod(11) === Period.of(0, 11, 0))
+ assert(monthsToPeriod(27) === Period.of(2, 3, 0))
+ assert(monthsToPeriod(-13) === Period.of(-1, -1, 0))
+ assert(monthsToPeriod(Int.MaxValue) === Period.ofYears(178956970).withMonths(7))
+ assert(monthsToPeriod(Int.MinValue) === Period.ofYears(-178956970).withMonths(-8))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bcc4871d3e64..90188cadfd3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -91,6 +91,9 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 3.2.0 */
implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION
+ /** @since 3.2.0 */
+ implicit def newPeriodEncoder: Encoder[java.time.Period] = Encoders.PERIOD
+
/** @since 3.2.0 */
implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] =
ExpressionEncoder()
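With the implicit in scope, `Dataset[java.time.Period]` works end to end; a minimal sketch, assuming an active `SparkSession` named `spark`:

```scala
import java.time.Period
import spark.implicits._

val ds = Seq(Period.ofYears(1), Period.ofMonths(-13)).toDS()
// Values come back normalized, e.g. Period.ofMonths(-13) reads back as P-1Y-1M.
ds.collect()
```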
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 85ad80e21329..93566e030e21 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -24,6 +24,7 @@
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDate;
+import java.time.Period;
import java.util.*;
import javax.annotation.Nonnull;
@@ -421,6 +422,14 @@ public void testDurationEncoder() {
Assert.assertEquals(data, ds.collectAsList());
}
+ @Test
+ public void testPeriodEncoder() {
+    Encoder<Period> encoder = Encoders.PERIOD();
+    List<Period> data = Arrays.asList(Period.of(10, 3, 0));
+    Dataset<Period> ds = spark.createDataset(data, encoder);
+    Assert.assertEquals(data, ds.collectAsList());
+  }
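For completeness, the Scala counterpart of the Java test above, using the new `Encoders.PERIOD` directly (assuming an active `SparkSession` named `spark`):

```scala
import java.time.Period
import org.apache.spark.sql.{Dataset, Encoders}

val data = Seq(Period.of(10, 3, 0))
val ds: Dataset[Period] = spark.createDataset(data)(Encoders.PERIOD)
assert(ds.collect().toSeq == data)
```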