@@ -103,13 +103,15 @@ public static int calculateBitSetWidthInBytes(int numFields) {
          IntegerType,
          LongType,
          FloatType,
-         DoubleType
+         DoubleType,
+         DateType
        })));

    // We support get() on a superset of the types for which we support set():
    final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
      Arrays.asList(new DataType[]{
-       StringType
+       StringType,

Contributor: remove the last comma

+       TimestampType

Contributor: Since timestamps are now represented as longs, we can support updates to timestamps, so we can move this into the settableFieldTypes list.

Member Author: OK. I will update this later.

Contributor: ping


      }));
    _readableFieldTypes.addAll(settableFieldTypes);
    readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
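
In other words, the suggestion in the thread above is to move TimestampType out of the extra readable types and into the settableFieldTypes array alongside DateType: since timestamps are now long-backed, they can be updated in place like any other fixed-width field, and the readable set (a superset built from the settable set) would still contain them.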
@@ -331,8 +333,6 @@ public String getString(int i) {
    return getUTF8String(i).toString();
  }

-
-
  @Override
  public InternalRow copy() {
    throw new UnsupportedOperationException();
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.sql.catalyst.util.DateUtils
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -120,6 +122,8 @@ private object UnsafeColumnWriter {
      case FloatType => FloatUnsafeColumnWriter
      case DoubleType => DoubleUnsafeColumnWriter
      case StringType => StringUnsafeColumnWriter
+     case DateType => DateUnsafeColumnWriter

Contributor: Is it possible that we just use IntUnsafeColumnWriter for DateType? Same for TimestampType.

Member Author: Yes. Updated now.


+     case TimestampType => TimestampUnsafeColumnWriter
      case t =>
        throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
    }
@@ -137,6 +141,8 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+private object DateUnsafeColumnWriter extends DateUnsafeColumnWriter
+private object TimestampUnsafeColumnWriter extends TimestampUnsafeColumnWriter

private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
  // Primitives don't write to the variable-length region:
@@ -258,3 +264,33 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
  }
}

+private class DateUnsafeColumnWriter private() extends UnsafeColumnWriter {
+  // Dates are stored inline as ints in the fixed-width region,
+  // so they need no variable-length space.
+  def getSize(source: InternalRow, column: Int): Int = {
+    0
+  }
+
+  override def write(
+      source: InternalRow,
+      target: UnsafeRow,
+      column: Int,
+      appendCursor: Int): Int = {
+    target.setInt(column, source.getInt(column))
+    0
+  }
+}
+
+private class TimestampUnsafeColumnWriter private() extends UnsafeColumnWriter {
+  // Timestamps are stored inline as longs, so they likewise need
+  // no variable-length space.
+  def getSize(source: InternalRow, column: Int): Int = {
+    0
+  }
+
+  override def write(
+      source: InternalRow,
+      target: UnsafeRow,
+      column: Int,
+      appendCursor: Int): Int = {
+    target.setLong(column, source.getLong(column))
+    0
+  }
+}
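
As the review thread above suggests, these two writers duplicate what the existing int and long writers already do. A minimal sketch of that simplification (the case arms below are illustrative, not part of this diff, and assume IntUnsafeColumnWriter and LongUnsafeColumnWriter call setInt/setLong exactly as the bodies above do):

  // In UnsafeColumnWriter.forType, dispatch dates and timestamps to the
  // existing primitive writers, since both are fixed-width internally:
  case DateType => IntUnsafeColumnWriter
  case TimestampType => LongUnsafeColumnWriter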
@@ -37,6 +37,7 @@ trait MutableRow extends InternalRow {
  def setByte(ordinal: Int, value: Byte)
  def setFloat(ordinal: Int, value: Float)
  def setString(ordinal: Int, value: String)
+  // TODO(davies): add setDecimal()
}
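
The TODO above presumably calls for a setter mirroring the existing ones. A hypothetical signature, not part of this change (Decimal being org.apache.spark.sql.types.Decimal):

  // Hypothetical sketch only:
  def setDecimal(ordinal: Int, value: Decimal)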

/**
@@ -197,9 +198,10 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
  override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
  override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
  override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int, value: String) {
+ override def setString(ordinal: Int, value: String): Unit = {
    values(ordinal) = UTF8String.fromString(value)
  }

  override def setNullAt(i: Int): Unit = { values(i) = null }

  override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
@@ -17,12 +17,14 @@

package org.apache.spark.sql.catalyst.expressions

+import java.sql.{Date, Timestamp}
import java.util.Arrays

import org.scalatest.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods

@@ -74,6 +76,34 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
    unsafeRow.getString(2) should be ("World")
  }

test("basic conversion with primitive, string, date and timestamp types") {
val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
val converter = new UnsafeRowConverter(fieldTypes)

val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setString(1, "Hello")
row.update(2, DateUtils.fromJavaDate(Date.valueOf("1970-01-01")))
row.update(3, DateUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))

val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 4) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)

val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.getLong(0) should be (0)
unsafeRow.getString(1) should be ("Hello")
// Date is represented as Int in unsafeRow
DateUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01"))
// Timestamp is represented as Long in unsafeRow
DateUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be
(Timestamp.valueOf("2015-05-08 08:10:25"))
}
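
For reference, the expected size works out to 56 bytes: 8 for the null-tracking bitset word, 8 * 4 = 32 for the four fixed-width field slots, and roundNumberOfBytesToNearestWord(5 + 8) = 16 for the variable-length "Hello" data (5 UTF-8 bytes plus an 8-byte header, rounded up to a whole word).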

test("null handling") {
val fieldTypes: Array[DataType] = Array(
NullType,