-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-7076][SPARK-7077][SPARK-7080][SQL] Use managed memory for aggregations #5725
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
480a74a
ab68e08
f03e9c1
5d55cef
8a8f9df
1ff814d
53ba9b7
fc4c3a8
1a483c5
079f1bf
f764d13
c754ae1
ae39694
c7f0b56
62ab054
c55bf66
738fa33
c1b3813
7df6008
58ac393
d2bb986
b3eaccd
bade966
d85eeff
1f4b716
92d5a06
628f936
23a440a
765243d
b26f1d3
49aed30
29a7575
ef6b3d3
06e929d
854201a
f3dcbfe
afe8dca
a95291e
31eaabc
6ffdaa1
9c19fc0
cde4132
0925847
a8e4a3f
b45f070
162caf7
3ca84b2
529e571
ce3c565
78a5b84
a19e066
de5e001
6e4b192
70a39e4
50e9671
017b2dc
1bc36cc
81f34f8
eeee512
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,16 +21,88 @@ import org.apache.spark.sql.types._ | |
| import org.apache.spark.unsafe.PlatformDependent | ||
| import org.apache.spark.unsafe.array.ByteArrayMethods | ||
|
|
||
| /** Write a column into an UnsafeRow */ | ||
| /** | ||
| * Converts Rows into UnsafeRow format. This class is NOT thread-safe. | ||
| * | ||
| * @param fieldTypes the data types of the row's columns. | ||
| */ | ||
| class UnsafeRowConverter(fieldTypes: Array[DataType]) { | ||
|
|
||
| def this(schema: StructType) { | ||
| this(schema.fields.map(_.dataType)) | ||
| } | ||
|
|
||
| /** Re-used pointer to the unsafe row being written */ | ||
| private[this] val unsafeRow = new UnsafeRow() | ||
|
|
||
| /** Functions for encoding each column */ | ||
| private[this] val writers: Array[UnsafeColumnWriter[Any]] = { | ||
| fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]]) | ||
| } | ||
|
|
||
| /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */ | ||
| private[this] val fixedLengthSize: Int = | ||
| (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) | ||
|
|
||
| /** | ||
| * Compute the amount of space, in bytes, required to encode the given row. | ||
| */ | ||
| def getSizeRequirement(row: Row): Int = { | ||
| var fieldNumber = 0 | ||
| var variableLengthFieldSize: Int = 0 | ||
| while (fieldNumber < writers.length) { | ||
| if (!row.isNullAt(fieldNumber)) { | ||
| variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber)) | ||
| } | ||
| fieldNumber += 1 | ||
| } | ||
| fixedLengthSize + variableLengthFieldSize | ||
| } | ||
|
|
||
| /** | ||
| * Convert the given row into UnsafeRow format. | ||
| * | ||
| * @param row the row to convert | ||
| * @param baseObject the base object of the destination address | ||
| * @param baseOffset the base offset of the destination address | ||
| * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. | ||
| */ | ||
| def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { | ||
| unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) | ||
| var fieldNumber = 0 | ||
| var appendCursor: Int = fixedLengthSize | ||
| while (fieldNumber < writers.length) { | ||
| if (row.isNullAt(fieldNumber)) { | ||
| unsafeRow.setNullAt(fieldNumber) | ||
| // TODO: type-specific null value writing? | ||
| } else { | ||
| appendCursor += writers(fieldNumber).write( | ||
| row(fieldNumber), | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is pretty minor, but is there a reason we are boxing here instead of passing the row itself in, allowing a specific accessor to be used for extraction?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's a really good idea; can't believe I didn't think of that... Should be an easy change, so I'll do this shortly. |
||
| fieldNumber, | ||
| unsafeRow, | ||
| baseObject, | ||
| baseOffset, | ||
| appendCursor) | ||
| } | ||
| fieldNumber += 1 | ||
| } | ||
| appendCursor | ||
| } | ||
|
|
||
| } | ||
|
|
||
| /** | ||
| * Function for writing a column into an UnsafeRow. | ||
| */ | ||
| private abstract class UnsafeColumnWriter[T] { | ||
| /** | ||
| * Write a value into an UnsafeRow. | ||
| * | ||
| * @param value the value to write | ||
| * @param columnNumber what column to write it to | ||
| * @param row a pointer to the unsafe row | ||
| * @param baseObject | ||
| * @param baseOffset | ||
| * @param baseObject the base object of the target row's address | ||
| * @param baseOffset the base offset of the target row's address | ||
| * @param appendCursor the offset from the start of the unsafe row to the end of the row; | ||
| * used for calculating where variable-length data should be written | ||
| * @return the number of variable-length bytes written | ||
|
|
@@ -50,6 +122,12 @@ private abstract class UnsafeColumnWriter[T] { | |
| } | ||
|
|
||
| private object UnsafeColumnWriter { | ||
| private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter | ||
| private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter | ||
| private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter | ||
| private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter | ||
| private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter | ||
|
|
||
| def forType(dataType: DataType): UnsafeColumnWriter[_] = { | ||
| dataType match { | ||
| case IntegerType => IntUnsafeColumnWriter | ||
|
|
@@ -63,34 +141,7 @@ private object UnsafeColumnWriter { | |
| } | ||
| } | ||
|
|
||
| private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] { | ||
| def getSize(value: UTF8String): Int = { | ||
| // round to nearest word | ||
| val numBytes = value.getBytes.length | ||
| 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) | ||
| } | ||
|
|
||
| override def write( | ||
| value: UTF8String, | ||
| columnNumber: Int, | ||
| row: UnsafeRow, | ||
| baseObject: Object, | ||
| baseOffset: Long, | ||
| appendCursor: Int): Int = { | ||
| val numBytes = value.getBytes.length | ||
| PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) | ||
| PlatformDependent.copyMemory( | ||
| value.getBytes, | ||
| PlatformDependent.BYTE_ARRAY_OFFSET, | ||
| baseObject, | ||
| baseOffset + appendCursor + 8, | ||
| numBytes | ||
| ) | ||
| row.setLong(columnNumber, appendCursor) | ||
| 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) | ||
| } | ||
| } | ||
| private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter | ||
| // ------------------------------------------------------------------------------------------------ | ||
|
|
||
| private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] { | ||
| def getSize(value: T): Int = 0 | ||
|
|
@@ -108,7 +159,6 @@ private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrite | |
| 0 | ||
| } | ||
| } | ||
| private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter | ||
|
|
||
| private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] { | ||
| override def write( | ||
|
|
@@ -122,7 +172,6 @@ private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrit | |
| 0 | ||
| } | ||
| } | ||
| private case object LongUnsafeColumnWriter extends LongUnsafeColumnWriter | ||
|
|
||
| private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Float] { | ||
| override def write( | ||
|
|
@@ -136,7 +185,6 @@ private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWri | |
| 0 | ||
| } | ||
| } | ||
| private case object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter | ||
|
|
||
| private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Double] { | ||
| override def write( | ||
|
|
@@ -150,55 +198,29 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr | |
| 0 | ||
| } | ||
| } | ||
| private case object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter | ||
|
|
||
| class UnsafeRowConverter(fieldTypes: Array[DataType]) { | ||
|
|
||
| def this(schema: StructType) { | ||
| this(schema.fields.map(_.dataType)) | ||
| } | ||
|
|
||
| private[this] val unsafeRow = new UnsafeRow() | ||
|
|
||
| private[this] val writers: Array[UnsafeColumnWriter[Any]] = { | ||
| fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]]) | ||
| } | ||
|
|
||
| private[this] val fixedLengthSize: Int = | ||
| (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) | ||
|
|
||
| def getSizeRequirement(row: Row): Int = { | ||
| var fieldNumber = 0 | ||
| var variableLengthFieldSize: Int = 0 | ||
| while (fieldNumber < writers.length) { | ||
| if (!row.isNullAt(fieldNumber)) { | ||
| variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber)) | ||
| } | ||
| fieldNumber += 1 | ||
| } | ||
| fixedLengthSize + variableLengthFieldSize | ||
| private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] { | ||
| def getSize(value: UTF8String): Int = { | ||
| 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.getBytes.length) | ||
| } | ||
|
|
||
| def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { | ||
| unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) | ||
| var fieldNumber = 0 | ||
| var appendCursor: Int = fixedLengthSize | ||
| while (fieldNumber < writers.length) { | ||
| if (row.isNullAt(fieldNumber)) { | ||
| unsafeRow.setNullAt(fieldNumber) | ||
| // TODO: type-specific null value writing? | ||
| } else { | ||
| appendCursor += writers(fieldNumber).write( | ||
| row(fieldNumber), | ||
| fieldNumber, | ||
| unsafeRow, | ||
| baseObject, | ||
| baseOffset, | ||
| appendCursor) | ||
| } | ||
| fieldNumber += 1 | ||
| } | ||
| appendCursor | ||
| override def write( | ||
| value: UTF8String, | ||
| columnNumber: Int, | ||
| row: UnsafeRow, | ||
| baseObject: Object, | ||
| baseOffset: Long, | ||
| appendCursor: Int): Int = { | ||
| val numBytes = value.getBytes.length | ||
| PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) | ||
| PlatformDependent.copyMemory( | ||
| value.getBytes, | ||
| PlatformDependent.BYTE_ARRAY_OFFSET, | ||
| baseObject, | ||
| baseOffset + appendCursor + 8, | ||
| numBytes | ||
| ) | ||
| row.setLong(columnNumber, appendCursor) | ||
| 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) | ||
| } | ||
|
|
||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: if we have comments, let's add a blank line