Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add ColumnAccessor.extract(ColumnVector, rowId)
  • Loading branch information
kiszk committed Aug 7, 2017
commit cb3e631d48d73d74c04d437b0b1ac2b263f31959
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,11 @@ public UTF8String getUTF8String(int rowId) {
}
}

/**
* Returns the UTF8String from a compressed column
*/
public UTF8String getUTF8StringFromCompressible(int rowId) { return null; }

/**
* Returns the byte array for rowId.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,15 @@ public final class OnHeapCachedBatch extends ColumnVector implements java.io.Ser
// e.g. when isNullAt() and getInt() are called, extractTo() must be called only once
private boolean[] calledExtractTo;

// a row where the compressed data is extracted
private transient UnsafeRow unsafeRow;
private transient BufferHolder bufferHolder;
private transient UnsafeRowWriter rowWriter;
private transient MutableUnsafeRow mutableRow;

// accessor for a column
private transient ColumnAccessor columnAccessor;

// an accessor uses only column 0
private final int ORDINAL = 0;
// a row where the compressed data is extracted
private transient ColumnVector columnVector;

// an accessor uses only row 0 in columnVector
private final int ROWID = 0;


protected OnHeapCachedBatch(int capacity, DataType type) {
super(capacity, type, VectorType.Compressible, MemoryMode.ON_HEAP);
Expand All @@ -73,8 +71,8 @@ private void initialize() {
if (columnAccessor == null) {
setColumnAccessor();
}
if (mutableRow == null) {
setRowSetter();
if (columnVector == null) {
columnVector = new OnHeapColumnVector(1, type);
}
}

Expand All @@ -84,20 +82,11 @@ private void setColumnAccessor() {
calledExtractTo = new boolean[capacity];
}

private void setRowSetter() {
unsafeRow = new UnsafeRow(1);
bufferHolder = new BufferHolder(unsafeRow);
rowWriter = new UnsafeRowWriter(bufferHolder, 1);
mutableRow = new MutableUnsafeRow(rowWriter);
}

// call extractTo() before getting actual data
private void prepareRowAccess(int rowId) {
private void prepareAccess(int rowId) {
if (!calledExtractTo[rowId]) {
assert (columnAccessor.hasNext());
bufferHolder.reset();
rowWriter.zeroOutNullBytes();
columnAccessor.extractTo(mutableRow, ORDINAL);
columnAccessor.extractTo(columnVector, ROWID);
calledExtractTo[rowId] = true;
}
}
Expand Down Expand Up @@ -146,8 +135,8 @@ public void putNotNulls(int rowId, int count) {

@Override
public boolean isNullAt(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.isNullAt(ORDINAL);
prepareAccess(rowId);
return columnVector.isNullAt(ROWID);
}

//
Expand All @@ -166,8 +155,8 @@ public void putBooleans(int rowId, int count, boolean value) {

@Override
public boolean getBoolean(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.getBoolean(ORDINAL);
prepareAccess(rowId);
return columnVector.getBoolean(ROWID);
}

@Override
Expand Down Expand Up @@ -198,8 +187,8 @@ public void putBytes(int rowId, int count, byte[] src, int srcIndex) {

@Override
public byte getByte(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.getByte(ORDINAL);
prepareAccess(rowId);
return columnVector.getByte(ROWID);
}

@Override
Expand Down Expand Up @@ -228,8 +217,8 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) {

@Override
public short getShort(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.getShort(ORDINAL);
prepareAccess(rowId);
return columnVector.getShort(ROWID);
}

@Override
Expand Down Expand Up @@ -263,8 +252,8 @@ public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex)

@Override
public int getInt(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.getInt(ORDINAL);
prepareAccess(rowId);
return columnVector.getInt(ROWID);
}

@Override
Expand Down Expand Up @@ -307,8 +296,8 @@ public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex)

@Override
public long getLong(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.getLong(ORDINAL);
prepareAccess(rowId);
return columnVector.getLong(ROWID);
}

@Override
Expand Down Expand Up @@ -342,8 +331,8 @@ public void putFloats(int rowId, int count, byte[] src, int srcIndex) {

@Override
public float getFloat(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.getFloat(ORDINAL);
prepareAccess(rowId);
return columnVector.getFloat(ROWID);
}

@Override
Expand Down Expand Up @@ -377,8 +366,8 @@ public void putDoubles(int rowId, int count, byte[] src, int srcIndex) {

@Override
public double getDouble(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.getDouble(ORDINAL);
prepareAccess(rowId);
return columnVector.getDouble(ROWID);
}

@Override
Expand Down Expand Up @@ -418,9 +407,9 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) {
throw new UnsupportedOperationException();
}

public final UTF8String getUTF8StringInternal(int rowId) {
prepareRowAccess(rowId);
return unsafeRow.getUTF8String(ORDINAL);
public final UTF8String getUTF8StringFromCompressible(int rowId) {
prepareAccess(rowId);
return columnVector.getUTF8String(ROWID);
}

// Split this function out since it is the slow path.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow}
import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
import org.apache.spark.sql.execution.vectorized.ColumnVector
import org.apache.spark.sql.types._

/**
Expand All @@ -41,6 +42,8 @@ private[columnar] trait ColumnAccessor {

def extractTo(row: InternalRow, ordinal: Int): Unit

def extractTo(column: ColumnVector, rowId: Int): Unit

protected def underlyingBuffer: ByteBuffer
}

Expand All @@ -57,10 +60,18 @@ private[columnar] abstract class BasicColumnAccessor[JvmType](
extractSingle(row, ordinal)
}

override def extractTo(column: ColumnVector, rowId: Int): Unit = {
extractSingle(column, rowId)
}

def extractSingle(row: InternalRow, ordinal: Int): Unit = {
columnType.extract(buffer, row, ordinal)
}

def extractSingle(column: ColumnVector, rowId: Int): Unit = {
columnType.extract(buffer, column, rowId)
}

protected def underlyingBuffer = buffer
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.vectorized.ColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -96,6 +97,10 @@ private[columnar] sealed abstract class ColumnType[JvmType] {
setField(row, ordinal, extract(buffer))
}

def extract(buffer: ByteBuffer, column: ColumnVector, rowId: Int): Unit = {
setField(column, rowId, extract(buffer))
}

/**
* Appends the given value v of type T into the given ByteBuffer.
*/
Expand Down Expand Up @@ -127,6 +132,8 @@ private[columnar] sealed abstract class ColumnType[JvmType] {
*/
def setField(row: InternalRow, ordinal: Int, value: JvmType): Unit

def setField(column: ColumnVector, rowId: Int, value: JvmType): Unit

/**
* Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid
* boxing/unboxing costs whenever possible.
Expand All @@ -150,6 +157,7 @@ private[columnar] object NULL extends ColumnType[Any] {
override def append(v: Any, buffer: ByteBuffer): Unit = {}
override def extract(buffer: ByteBuffer): Any = null
override def setField(row: InternalRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal)
override def setField(column: ColumnVector, rowId: Int, value: Any): Unit = column.putNull(rowId)
override def getField(row: InternalRow, ordinal: Int): Any = null
}

Expand Down Expand Up @@ -185,6 +193,10 @@ private[columnar] object INT extends NativeColumnType(IntegerType, 4) {
row.setInt(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: Int): Unit = {
column.putInt(rowId, value)
}

override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal)


Expand Down Expand Up @@ -214,6 +226,10 @@ private[columnar] object LONG extends NativeColumnType(LongType, 8) {
row.setLong(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: Long): Unit = {
column.putLong(rowId, value)
}

override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal)

override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) {
Expand Down Expand Up @@ -242,6 +258,10 @@ private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) {
row.setFloat(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: Float): Unit = {
column.putFloat(rowId, value)
}

override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal)

override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) {
Expand Down Expand Up @@ -270,6 +290,10 @@ private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) {
row.setDouble(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: Double): Unit = {
column.putDouble(rowId, value)
}

override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal)

override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) {
Expand All @@ -296,6 +320,10 @@ private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) {
row.setBoolean(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: Boolean): Unit = {
column.putBoolean(rowId, value)
}

override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal)

override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) {
Expand Down Expand Up @@ -324,6 +352,10 @@ private[columnar] object BYTE extends NativeColumnType(ByteType, 1) {
row.setByte(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: Byte): Unit = {
column.putByte(rowId, value)
}

override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal)

override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) {
Expand Down Expand Up @@ -352,6 +384,10 @@ private[columnar] object SHORT extends NativeColumnType(ShortType, 2) {
row.setShort(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: Short): Unit = {
column.putShort(rowId, value)
}

override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal)

override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) {
Expand Down Expand Up @@ -415,6 +451,10 @@ private[columnar] object STRING
}
}

override def setField(column: ColumnVector, rowId: Int, value: UTF8String): Unit = {
column.putByteArray(rowId, value.getBytes)
}

override def getField(row: InternalRow, ordinal: Int): UTF8String = {
row.getUTF8String(ordinal)
}
Expand Down Expand Up @@ -463,6 +503,10 @@ private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int)
row.setDecimal(ordinal, value, precision)
}

override def setField(column: ColumnVector, rowId: Int, value: Decimal): Unit = {
throw new UnsupportedOperationException
}

override def copyField(from: InternalRow, fromOrdinal: Int, to: InternalRow, toOrdinal: Int) {
setField(to, toOrdinal, getField(from, fromOrdinal))
}
Expand Down Expand Up @@ -501,6 +545,10 @@ private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) {
row.update(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: Array[Byte]): Unit = {
throw new UnsupportedOperationException
}

override def getField(row: InternalRow, ordinal: Int): Array[Byte] = {
row.getBinary(ordinal)
}
Expand All @@ -526,6 +574,10 @@ private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int)
row.setDecimal(ordinal, value, precision)
}

override def setField(column: ColumnVector, rowId: Int, value: Decimal): Unit = {
throw new UnsupportedOperationException
}

override def actualSize(row: InternalRow, ordinal: Int): Int = {
4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1
}
Expand Down Expand Up @@ -557,6 +609,10 @@ private[columnar] case class STRUCT(dataType: StructType)
row.update(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: UnsafeRow): Unit = {
throw new UnsupportedOperationException
}

override def getField(row: InternalRow, ordinal: Int): UnsafeRow = {
row.getStruct(ordinal, numOfFields).asInstanceOf[UnsafeRow]
}
Expand Down Expand Up @@ -595,6 +651,10 @@ private[columnar] case class ARRAY(dataType: ArrayType)
row.update(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: UnsafeArrayData): Unit = {
throw new UnsupportedOperationException
}

override def getField(row: InternalRow, ordinal: Int): UnsafeArrayData = {
row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
}
Expand Down Expand Up @@ -634,6 +694,10 @@ private[columnar] case class MAP(dataType: MapType)
row.update(ordinal, value)
}

override def setField(column: ColumnVector, rowId: Int, value: UnsafeMapData): Unit = {
throw new UnsupportedOperationException
}

override def getField(row: InternalRow, ordinal: Int): UnsafeMapData = {
row.getMap(ordinal).asInstanceOf[UnsafeMapData]
}
Expand Down
Loading