ArrowColumnVector.java
@@ -19,12 +19,16 @@

import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.*;
import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
import org.apache.arrow.vector.holders.NullableVarCharHolder;

import org.apache.spark.sql.util.ArrowUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.UTF8String;

import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY;
import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS;

/**
* A column vector backed by Apache Arrow. Currently calendar interval type and map type are not
* supported.
@@ -172,6 +176,10 @@ public ArrowColumnVector(ValueVector vector) {
}
} else if (vector instanceof NullVector) {
accessor = new NullAccessor((NullVector) vector);
} else if (vector instanceof IntervalYearVector) {
accessor = new IntervalYearAccessor((IntervalYearVector) vector);
} else if (vector instanceof IntervalDayVector) {
accessor = new IntervalDayAccessor((IntervalDayVector) vector);
@HyukjinKwon (Member) commented on Nov 17, 2021:

Hm, there's something wrong here. We map Spark's DayTimeIntervalType to java.time.Duration on the Java/Scala side, but here we map it to Arrow's IntervalType, which represents a calendar instance (see also https://github.com/apache/arrow/blob/master/format/Schema.fbs).

I think we should map it to Arrow's DurationType (Python's datetime.timedelta). I am working on SPARK-37277 to support this in PySpark's Arrow conversion, but this has become a blocker for me. I am preparing a PR to change this, but please let me know if you have different thoughts.

A Contributor commented:

good catch!

The Contributor (Author) replied:

I'm not quite sure why DayTimeIntervalType is mapped to Arrow's IntervalType here; I just followed ArrowUtils.scala#L60. From what I've learned about Arrow's types, Interval is SQL-style, and Hive maps INTERVAL_DAY_TIME to Arrow's IntervalType with the IntervalUnit.DAY_TIME unit. If we map DayTimeIntervalType to Arrow's DurationType, then which Arrow type should YearMonthIntervalType map to?

@HyukjinKwon (Member) replied on Nov 18, 2021:

At the very least, Duration cannot be mapped to YearMonthIntervalType. For DayTimeIntervalType, mapping to IntervalType makes sense Arrow-wise, but it makes less sense in Spark SQL because we already map it to Duration.

I am not saying either way is 100% correct, but if I have to pick one of the two, I would pick the one that is coherent from Spark's perspective.

A Member added:

And YearMonthIntervalType is mapped to java.time.Period, which is a calendar instance:

"A date-based amount of time in the ISO-8601 calendar system, such as '2 years, 3 months and 4 days'."

So YearMonthIntervalType seems fine.
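
For context, a minimal sketch (not part of this change) of what the ArrowUtils.scala branches might look like under the Duration proposal, assuming Arrow's ArrowType.Duration at microsecond resolution:

// Hypothetical: model DayTimeIntervalType as an absolute length of time
// (like java.time.Duration) rather than as a calendar interval.
case DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)

// ...and the reverse direction:
case d: ArrowType.Duration if d.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType

YearMonthIntervalType would keep its Interval(YEAR_MONTH) mapping, since a Duration cannot represent calendar months.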

} else {
throw new UnsupportedOperationException();
}
@@ -508,4 +516,37 @@ private static class NullAccessor extends ArrowVectorAccessor {
super(vector);
}
}

private static class IntervalYearAccessor extends ArrowVectorAccessor {

private final IntervalYearVector accessor;

IntervalYearAccessor(IntervalYearVector vector) {
super(vector);
this.accessor = vector;
}

@Override
int getInt(int rowId) {
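// IntervalYearVector stores the year-month interval as a single int month
// count, matching Spark's internal representation of YearMonthIntervalType.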
return accessor.get(rowId);
}
}

private static class IntervalDayAccessor extends ArrowVectorAccessor {

private final IntervalDayVector accessor;
private final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder();

IntervalDayAccessor(IntervalDayVector vector) {
super(vector);
this.accessor = vector;
}

@Override
long getLong(int rowId) {
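// Recombine the vector's (days, milliseconds) pair into one microsecond
// count; Math.multiplyExact/addExact throw ArithmeticException on overflow.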
accessor.get(rowId, intervalDayHolder);
return Math.addExact(Math.multiplyExact(intervalDayHolder.days, MICROS_PER_DAY),
intervalDayHolder.milliseconds * MICROS_PER_MILLIS);
}
}
}
ArrowUtils.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.sql.internal.SQLConf
@@ -54,6 +54,8 @@ private[sql] object ArrowUtils {
new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
}
case NullType => ArrowType.Null.INSTANCE
case YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
case DayTimeIntervalType => new ArrowType.Interval(IntervalUnit.DAY_TIME)
case _ =>
throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
}
@@ -74,6 +76,8 @@
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
case ArrowType.Null.INSTANCE => NullType
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType
case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType
case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

ArrowUtilsSuite.scala
@@ -48,6 +48,8 @@ class ArrowUtilsSuite extends SparkFunSuite {
roundtrip(BinaryType)
roundtrip(DecimalType.SYSTEM_DEFAULT)
roundtrip(DateType)
roundtrip(YearMonthIntervalType)
roundtrip(DayTimeIntervalType)
val tsExMsg = intercept[UnsupportedOperationException] {
roundtrip(TimestampType)
}
ArrowWriter.scala
@@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_MILLIS}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
@@ -74,6 +75,8 @@ object ArrowWriter {
}
new StructWriter(vector, children.toArray)
case (NullType, vector: NullVector) => new NullWriter(vector)
case (YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
case (DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector)
case (dt, _) =>
throw QueryExecutionErrors.unsupportedDataTypeError(dt)
}
@@ -394,3 +397,28 @@ private[arrow] class NullWriter(val valueVector: NullVector) extends ArrowFieldWriter {
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
}
}

private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector)
extends ArrowFieldWriter {
override def setNull(): Unit = {
valueVector.setNull(count)
}

override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
valueVector.setSafe(count, input.getInt(ordinal))
}
}

private[arrow] class IntervalDayWriter(val valueVector: IntervalDayVector)
extends ArrowFieldWriter {
override def setNull(): Unit = {
valueVector.setNull(count)
}

override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
val totalMicroseconds = input.getLong(ordinal)
val days = totalMicroseconds / MICROS_PER_DAY
val millis = (totalMicroseconds % MICROS_PER_DAY) / MICROS_PER_MILLIS
A Member commented:

Hm, do we lose the microseconds part? I think this is another reason to use Duration.

The Contributor (Author) replied:

Yeah, we lose the microseconds part and end up with millisecond precision. This is inconsistent with the conversion from java.time.Duration to DayTimeIntervalType, which drops only the excess precision finer than microseconds.
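
To make the precision loss concrete, a small worked example (illustrative values only):

val totalMicroseconds = 90061000123L  // 1 day, 1 hour, 1 minute, 1 second, 123 microseconds
val days = totalMicroseconds / MICROS_PER_DAY                         // 1
val millis = (totalMicroseconds % MICROS_PER_DAY) / MICROS_PER_MILLIS // 3661000 (1h 1m 1s)
// The trailing 123 microseconds are dropped: IntervalDayVector stores only (days, millis).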

valueVector.set(count, days.toInt, millis.toInt)
}
}
ArrowWriterSuite.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.arrow

import org.apache.arrow.vector.IntervalDayVector

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util._
@@ -54,6 +56,8 @@ class ArrowWriterSuite extends SparkFunSuite {
case BinaryType => reader.getBinary(rowId)
case DateType => reader.getInt(rowId)
case TimestampType => reader.getLong(rowId)
case YearMonthIntervalType => reader.getInt(rowId)
case DayTimeIntervalType => reader.getLong(rowId)
}
assert(value === datum)
}
@@ -73,6 +77,33 @@
check(DateType, Seq(0, 1, 2, null, 4))
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
check(NullType, Seq(null, null, null))
check(YearMonthIntervalType, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue))
check(DayTimeIntervalType, Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L),
(Long.MinValue + 808L)))
}
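
The ±807/808 offsets in the DayTimeIntervalType values above are presumably chosen because the writer truncates to millisecond precision, so only microsecond counts that are whole milliseconds survive the write/read round trip: Long.MaxValue is 9223372036854775807 (it ends in 807, so MaxValue - 807 is millisecond-aligned), and Long.MinValue is -9223372036854775808 (so MinValue + 808 is millisecond-aligned).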

test("long overflow for DayTimeIntervalType")
{
val schema = new StructType().add("value", DayTimeIntervalType, nullable = true)
val writer = ArrowWriter.create(schema, null)
val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
val valueVector = writer.root.getFieldVectors().get(0).asInstanceOf[IntervalDayVector]

valueVector.set(0, 106751992, 0)
valueVector.set(1, 106751991, Int.MaxValue)

// The first value overflows a long in Math.multiplyExact().
val msg = intercept[java.lang.ArithmeticException] {
reader.getLong(0)
}.getMessage
assert(msg.equals("long overflow"))

// The second value overflows a long in Math.addExact().
val msg1 = intercept[java.lang.ArithmeticException] {
reader.getLong(1)
}.getMessage
assert(msg1.equals("long overflow"))
writer.root.close()
}
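
Why these constants overflow: Long.MaxValue / MICROS_PER_DAY = 9223372036854775807 / 86400000000 ≈ 106751991.17, so 106751992 days already overflows in Math.multiplyExact, while 106751991 days leaves only about 1.4e10 microseconds of headroom — far less than Int.MaxValue milliseconds (≈ 2.1e12 microseconds) added in the second case, which therefore overflows in Math.addExact.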

test("get multiple") {
@@ -97,6 +128,8 @@
case DoubleType => reader.getDoubles(0, data.size)
case DateType => reader.getInts(0, data.size)
case TimestampType => reader.getLongs(0, data.size)
case YearMonthIntervalType => reader.getInts(0, data.size)
case DayTimeIntervalType => reader.getLongs(0, data.size)
}
assert(values === data)

@@ -111,6 +144,8 @@
check(DoubleType, (0 until 10).map(_.toDouble))
check(DateType, (0 until 10))
check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles")
check(YearMonthIntervalType, (0 until 10))
check(DayTimeIntervalType, (-10 until 10).map(_ * 1000.toLong))
}

test("array") {