Refactor ArrowConverters.

ueshin committed Jul 17, 2017
commit 579def2db0a0f015760a458032d3bd916669201c
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests.py
@@ -2923,8 +2923,8 @@ def assertFramesEqual(self, df_with_arrow, df_without):
         self.assertTrue(df_without.equals(df_with_arrow), msg=msg)

     def test_unsupported_datatype(self):
-        schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)])
-        df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema)
+        schema = StructType([StructField("dt", DateType(), True)])
+        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
         with QuietTest(self.sc):
             self.assertRaises(Exception, lambda: df.toPandas())
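
The unsupported-type test switches from ArrayType to DateType because the refactored path converts schemas through ArrowUtils.toArrowSchema, which evidently handles arrays now, so the test needs a type that is still unsupported. A minimal sketch of probing support for a schema, assuming (as assertRaises(Exception) above suggests) that unsupported types surface as an exception from the conversion; arrowSupports is a hypothetical helper, and ArrowUtils is package-private to org.apache.spark.sql:

    import org.apache.spark.sql.execution.arrow.ArrowUtils
    import org.apache.spark.sql.types._

    // Hypothetical probe, assuming toArrowSchema rejects unsupported types
    // with an UnsupportedOperationException.
    def arrowSupports(schema: StructType): Boolean =
      try { ArrowUtils.toArrowSchema(schema); true }
      catch { case _: UnsupportedOperationException => false }

    arrowSupports(StructType(Seq(StructField("xs", ArrayType(IntegerType, false)))))  // expected: true
    arrowSupports(StructType(Seq(StructField("dt", DateType))))                       // expected: false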

sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow
 import java.io.ByteArrayOutputStream
 import java.nio.channels.Channels

-import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.file._
 import org.apache.arrow.vector.schema.ArrowRecordBatch
@@ -50,19 +50,6 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
   def asPythonSerializable: Array[Byte] = payload
 }

-private[sql] object ArrowPayload {
-
-  /**
-   * Create an ArrowPayload from an ArrowRecordBatch and Spark schema.
-   */
-  def apply(
-      batch: ArrowRecordBatch,
-      schema: StructType,
-      allocator: BufferAllocator): ArrowPayload = {
-    new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator))
-  }
-}
-
 private[sql] object ArrowConverters {

   /**
@@ -73,89 +60,45 @@ private[sql] object ArrowConverters {
       rowIter: Iterator[InternalRow],
       schema: StructType,
       maxRecordsPerBatch: Int): Iterator[ArrowPayload] = {
-
     new Iterator[ArrowPayload] {
-      private val _allocator = new RootAllocator(Long.MaxValue)
-      private var _nextPayload = if (rowIter.nonEmpty) convert() else null
-
-      override def hasNext: Boolean = _nextPayload != null
+      private val arrowSchema = ArrowUtils.toArrowSchema(schema)
+      private val allocator =
+        ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue)
+
+      private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+      private val arrowWriter = ArrowWriter.create(root)
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }

       override def next(): ArrowPayload = {
-        val obj = _nextPayload
-        if (hasNext) {
-          if (rowIter.hasNext) {
-            _nextPayload = convert()
-          } else {
-            _allocator.close()
-            _nextPayload = null
+        val out = new ByteArrayOutputStream()
+        val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
+
+        Utils.tryWithSafeFinally {
+          var rowId = 0
[Inline review comment on the line above]
Member: nit: maybe rowCount instead of rowId because it is a count of how many rows in the batch so far and not a unique id?
Member (author): Thanks! I'll update it.
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowId < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowId += 1
           }
+          arrowWriter.finish()
+          writer.writeBatch()
+        } {
+          arrowWriter.reset()
+          writer.close()
         }
-        obj
-      }

-      private def convert(): ArrowPayload = {
-        val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch)
-        ArrowPayload(batch, schema, _allocator)
+        new ArrowPayload(out.toByteArray)
       }
     }
   }
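
The rewritten iterator allocates the Arrow schema, a child allocator, a single VectorSchemaRoot and ArrowWriter up front and reuses them for every batch: next() fills and serializes one batch, and the hasNext call that observes exhaustion closes the shared resources. A stripped-down sketch of the same pattern with generic names (not the Spark code; the guard flag is added here so repeated hasNext calls stay safe):

    // Batches an input iterator and releases a shared resource once drained.
    class BatchingIterator[A, B](
        rows: Iterator[A],
        maxPerBatch: Int,       // <= 0 means no limit, as with maxRecordsPerBatch
        encode: Seq[A] => B,    // stand-in for "write rows via ArrowWriter, emit bytes"
        cleanup: () => Unit     // stand-in for closing root and allocator
    ) extends Iterator[B] {
      private var closed = false

      override def hasNext: Boolean = rows.hasNext || {
        if (!closed) { cleanup(); closed = true }  // runs on first hasNext after exhaustion
        false
      }

      override def next(): B = {
        val batch = scala.collection.mutable.ArrayBuffer.empty[A]
        while (rows.hasNext && (maxPerBatch <= 0 || batch.length < maxPerBatch)) {
          batch += rows.next()
        }
        encode(batch)
      }
    }

The trade-off mirrors the diff: hasNext is no longer a pure query, and a consumer that abandons the iterator early never triggers the cleanup branch, so the pattern suits call sites that are known to drain the iterator.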

-  /**
-   * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed
-   * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0,
-   * then rowIter will be fully consumed.
-   */
-  private def internalRowIterToArrowBatch(
-      rowIter: Iterator[InternalRow],
-      schema: StructType,
-      allocator: BufferAllocator,
-      maxRecordsPerBatch: Int = 0): ArrowRecordBatch = {
-
-    val arrowSchema = ArrowUtils.toArrowSchema(schema)
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val arrowWriter = ArrowWriter.create(root)
-
-    var recordsInBatch = 0
-    while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) {
-      val row = rowIter.next()
-      arrowWriter.write(row)
-      recordsInBatch += 1
-    }
-    arrowWriter.finish()
-
-    Utils.tryWithSafeFinally {
-      val unloader = new VectorUnloader(arrowWriter.root)
-      unloader.getRecordBatch()
-    } {
-      root.close()
-    }
-  }
-
-  /**
-   * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed,
-   * the batch can no longer be used.
-   */
-  private[arrow] def batchToByteArray(
-      batch: ArrowRecordBatch,
-      schema: StructType,
-      allocator: BufferAllocator): Array[Byte] = {
-    val arrowSchema = ArrowUtils.toArrowSchema(schema)
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val out = new ByteArrayOutputStream()
-    val writer = new ArrowFileWriter(root, null, Channels.newChannel(out))
-
-    // Write a batch to byte stream, ensure the batch, allocator and writer are closed
-    Utils.tryWithSafeFinally {
-      val loader = new VectorLoader(root)
-      loader.load(batch)
-      writer.writeBatch() // writeBatch can throw IOException
-    } {
-      batch.close()
-      root.close()
-      writer.close()
-    }
-    out.toByteArray
-  }
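
Both the deleted helper and the new next() lean on Utils.tryWithSafeFinally, whose curried second parameter list is what produces the bare "} {" lines in this diff. A simplified sketch of its shape (the real helper lives in org.apache.spark.util.Utils; this version only keeps the original exception primary):

    def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
      var original: Throwable = null
      try {
        block
      } catch {
        case t: Throwable =>
          original = t
          throw t
      } finally {
        try {
          finallyBlock
        } catch {
          // If both blocks throw, suppress the cleanup exception so the
          // original failure from `block` is the one that propagates.
          case t: Throwable if original != null => original.addSuppressed(t)
        }
      }
    }

In the code above this means arrowWriter.reset() and writer.close() (or batch.close(), root.close() and writer.close() in the deleted helper) always run, and an IOException from writer.writeBatch() still wins over any failure during cleanup.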

   /**
    * Convert a byte array to an ArrowRecordBatch.
    */