From 44d56b11eb588843ffb52f460457e59939dc8708 Mon Sep 17 00:00:00 2001
From: PengLei <18066542445@189.cn>
Date: Sun, 25 Apr 2021 18:05:05 +0800
Subject: [PATCH] Add SPARK-35139 demo

---
 .../sql/vectorized/ArrowColumnVector.java  | 43 +++++++++++++++++++
 .../apache/spark/sql/util/ArrowUtils.scala |  6 ++-
 .../sql/execution/arrow/ArrowWriter.scala  | 28 +++++++++++-
 .../execution/arrow/ArrowWriterSuite.scala |  8 ++++
 4 files changed, 82 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index f149c9bb0c6f..f20bf7a5d299 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -19,12 +19,16 @@
 import org.apache.arrow.vector.*;
 import org.apache.arrow.vector.complex.*;
+import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
 import org.apache.arrow.vector.holders.NullableVarCharHolder;
 
 import org.apache.spark.sql.util.ArrowUtils;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.types.UTF8String;
 
+import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY;
+import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS;
+
 /**
  * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not
  * supported.
  */
@@ -172,6 +176,10 @@ public ArrowColumnVector(ValueVector vector) {
       }
     } else if (vector instanceof NullVector) {
       accessor = new NullAccessor((NullVector) vector);
+    } else if (vector instanceof IntervalYearVector) {
+      accessor = new IntervalYearAccessor((IntervalYearVector) vector);
+    } else if (vector instanceof IntervalDayVector) {
+      accessor = new IntervalDayAccessor((IntervalDayVector) vector);
     } else {
       throw new UnsupportedOperationException();
     }
@@ -508,4 +516,39 @@ private static class NullAccessor extends ArrowVectorAccessor {
       super(vector);
     }
   }
+
+  private static class IntervalYearAccessor extends ArrowVectorAccessor {
+
+    private final IntervalYearVector accessor;
+
+    IntervalYearAccessor(IntervalYearVector vector) {
+      super(vector);
+      this.accessor = vector;
+    }
+
+    @Override
+    int getInt(int rowId) {
+      // The vector stores the year-month interval as a number of months.
+      return accessor.get(rowId);
+    }
+  }
+
+  private static class IntervalDayAccessor extends ArrowVectorAccessor {
+
+    private final IntervalDayVector accessor;
+    private final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder();
+
+    IntervalDayAccessor(IntervalDayVector vector) {
+      super(vector);
+      this.accessor = vector;
+    }
+
+    @Override
+    long getLong(int rowId) {
+      // Rebuild Spark's microsecond total from Arrow's (days, milliseconds) pair.
+      accessor.get(rowId, intervalDayHolder);
+      return intervalDayHolder.days * MICROS_PER_DAY
+          + (long) intervalDayHolder.milliseconds * MICROS_PER_MILLIS;
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 5d5da795a5b0..860dde83d6b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.complex.MapVector
-import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit, Types}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 
 import org.apache.spark.sql.internal.SQLConf
@@ -54,6 +54,8 @@ private[sql] object ArrowUtils {
         new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
       }
     case NullType => ArrowType.Null.INSTANCE
+    case YearMonthIntervalType => Types.MinorType.INTERVALYEAR.getType
+    case DayTimeIntervalType => Types.MinorType.INTERVALDAY.getType
     case _ =>
       throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
   }
@@ -74,6 +76,8 @@ private[sql] object ArrowUtils {
     case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
     case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
     case ArrowType.Null.INSTANCE => NullType
+    case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType
+    case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType
     case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index d3a02f245197..c8deb6247d14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -18,12 +18,11 @@
 package org.apache.spark.sql.execution.arrow
 
 import scala.collection.JavaConverters._
-
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.complex._
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_MILLIS}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
@@ -74,6 +73,8 @@ object ArrowWriter {
         }
         new StructWriter(vector, children.toArray)
       case (NullType, vector: NullVector) => new NullWriter(vector)
+      case (YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
+      case (DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector)
       case (dt, _) =>
         throw QueryExecutionErrors.unsupportedDataTypeError(dt)
   }
@@ -394,3 +395,26 @@ private[arrow] class NullWriter(val valueVector: NullVector) extends ArrowFieldW
   override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
   }
 }
+
+private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector) extends ArrowFieldWriter {
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getInt(ordinal))
+  }
+}
+
+private[arrow] class IntervalDayWriter(val valueVector: IntervalDayVector) extends ArrowFieldWriter {
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val totalMicroseconds = input.getLong(ordinal)
+    val days = totalMicroseconds / MICROS_PER_DAY
+    val millis = (totalMicroseconds - days * MICROS_PER_DAY) / MICROS_PER_MILLIS
+    valueVector.set(count, days.toInt, millis.toInt)
+  }
+}
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
index f58d3246f5f5..9aecc64db3ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala
@@ -54,6 +54,8 @@ class ArrowWriterSuite extends SparkFunSuite {
         case BinaryType => reader.getBinary(rowId)
         case DateType => reader.getInt(rowId)
         case TimestampType => reader.getLong(rowId)
+        case YearMonthIntervalType => reader.getInt(rowId)
+        case DayTimeIntervalType => reader.getLong(rowId)
       }
       assert(value === datum)
     }
@@ -73,6 +75,8 @@ class ArrowWriterSuite extends SparkFunSuite {
     check(DateType, Seq(0, 1, 2, null, 4))
     check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
     check(NullType, Seq(null, null, null))
+    check(YearMonthIntervalType, Seq(null, 0, 1, -1, scala.Int.MaxValue, scala.Int.MinValue))
+    check(DayTimeIntervalType, Seq(null, 0L, 1000L, -1000L, scala.Long.MaxValue - 807L, scala.Long.MinValue + 808L))
   }
 
   test("get multiple") {
@@ -97,6 +101,8 @@ class ArrowWriterSuite extends SparkFunSuite {
         case DoubleType => reader.getDoubles(0, data.size)
         case DateType => reader.getInts(0, data.size)
         case TimestampType => reader.getLongs(0, data.size)
+        case YearMonthIntervalType => reader.getInts(0, data.size)
+        case DayTimeIntervalType => reader.getLongs(0, data.size)
       }
 
       assert(values === data)
@@ -111,6 +117,8 @@ class ArrowWriterSuite extends SparkFunSuite {
     check(DoubleType, (0 until 10).map(_.toDouble))
     check(DateType, (0 until 10))
     check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles")
+    check(YearMonthIntervalType, (0 until 10))
+    check(DayTimeIntervalType, (-10 until 10).map(_ * 1000.toLong))
   }
 
   test("array") {
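
-- 
A note on the day-time conversion exercised above: Spark stores DayTimeIntervalType
as a single long of microseconds, while Arrow's IntervalDayVector stores a
(days, milliseconds) pair, so IntervalDayWriter.setValue and
IntervalDayAccessor.getLong are inverse conversions. The standalone sketch below
demonstrates that round trip; it assumes MICROS_PER_DAY = 86400000000L and
MICROS_PER_MILLIS = 1000L (the values behind the DateTimeConstants imports in
the patch), and the object and helper names are hypothetical, for illustration
only:

    object IntervalDayRoundTrip {
      // Assumption: these match org.apache.spark.sql.catalyst.util.DateTimeConstants.
      val MicrosPerDay: Long = 86400000000L
      val MicrosPerMillis: Long = 1000L

      // Mirrors IntervalDayWriter.setValue: split a microsecond total into the
      // (days, milliseconds) pair that Arrow's IntervalDayVector stores.
      def microsToDayMillis(micros: Long): (Int, Int) = {
        val days = micros / MicrosPerDay
        val millis = (micros - days * MicrosPerDay) / MicrosPerMillis
        (days.toInt, millis.toInt)
      }

      // Mirrors IntervalDayAccessor.getLong: rebuild the microsecond total.
      def dayMillisToMicros(days: Int, millis: Int): Long =
        days * MicrosPerDay + millis.toLong * MicrosPerMillis

      def main(args: Array[String]): Unit = {
        // Extreme values match the ArrowWriterSuite test data above.
        val samples = Seq(0L, 1000L, -1000L, Long.MaxValue - 807L, Long.MinValue + 808L)
        samples.foreach { micros =>
          val (days, millis) = microsToDayMillis(micros)
          assert(dayMillisToMicros(days, millis) == micros, s"round trip failed for $micros")
        }
        println(s"round-tripped ${samples.size} values")
      }
    }

Because Scala's integer division and the remainder subtraction both truncate
toward zero, days and millis carry the same sign for negative intervals, and the
reconstruction is exact whenever the microsecond total is a multiple of
MICROS_PER_MILLIS. Sub-millisecond remainders are dropped by the writer, which
is presumably why the suite's extreme values are Long.MaxValue - 807L and
Long.MinValue + 808L: both are multiples of 1000 microseconds.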