-
Notifications
You must be signed in to change notification settings - Fork 29k
[Spark-14138][SQL] Fix generated SpecificColumnarIterator code can exceed JVM size limit for cached DataFrames #11984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
ab67d33
fea2a52
e56406e
60f6719
226bad5
f3307a7
9346793
beb9840
16cf602
a310bfc
c1acf82
60cebd5
3a05ddf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
group a lot of calls into a method
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,8 @@ | |
|
|
||
| package org.apache.spark.sql.execution.columnar | ||
|
|
||
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.Logging | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
|
|
@@ -68,6 +70,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera | |
| protected def create(columnTypes: Seq[DataType]): ColumnarIterator = { | ||
| val ctx = newCodeGenContext() | ||
| val numFields = columnTypes.size | ||
| val accessorClasses = new mutable.HashMap[String, String] | ||
| val accessorStructClasses = new mutable.HashMap[(String, DataType), (String, String)] | ||
| val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => | ||
| val accessorName = ctx.freshName("accessor") | ||
| val accessorCls = dt match { | ||
|
|
@@ -88,16 +92,20 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera | |
| case array: ArrayType => classOf[ArrayColumnAccessor].getName | ||
| case t: MapType => classOf[MapColumnAccessor].getName | ||
| } | ||
| ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;") | ||
|
|
||
| val createCode = dt match { | ||
| case t if ctx.isPrimitiveType(dt) => | ||
| s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" | ||
| case NullType | StringType | BinaryType => | ||
| s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" | ||
| case other => | ||
| s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), | ||
| (${dt.getClass.getName}) columnTypes[$index]);""" | ||
| ctx.addMutableState(accessorCls, accessorName, "") | ||
|
|
||
| val createCode = { | ||
| val shortAccCls = accessorCls.substring(accessorCls.lastIndexOf(".") + 1) | ||
| dt match { | ||
| case t if ctx.isPrimitiveType(dt) => | ||
| s"$accessorName = get${accessorClasses.getOrElseUpdate(accessorCls, shortAccCls)}($index);" | ||
| case NullType | StringType | BinaryType => | ||
| s"$accessorName = get${accessorClasses.getOrElseUpdate(accessorCls, shortAccCls)}($index);" | ||
| case other => | ||
| val shortDTCls = dt.getClass.getName.substring(dt.getClass.getName.lastIndexOf(".") + 1) | ||
| accessorStructClasses.getOrElseUpdate((accessorCls, dt), (shortAccCls, shortDTCls)) | ||
| s"$accessorName = get${shortAccCls}_${shortDTCls}($index);" | ||
| } | ||
| } | ||
|
|
||
| val extract = s"$accessorName.extractTo(mutableRow, $index);" | ||
|
|
@@ -114,6 +122,57 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera | |
| (createCode, extract + patch) | ||
| }.unzip | ||
|
|
||
| val accessorCode = accessorClasses.map { case (accessorCls, shortAccCls) => | ||
| s""" | ||
| private $accessorCls get${shortAccCls}(int idx) { | ||
| byte[] buffer = batch.buffers()[columnIndexes[idx]]; | ||
| return new $accessorCls(ByteBuffer.wrap(buffer).order(nativeOrder)); | ||
| } | ||
| """ | ||
| } | ||
| val accessorStructCode = accessorStructClasses.map { | ||
| case ((accessorCls, dt), (shortAccCls, shortDTCls)) => | ||
| s""" | ||
| private $accessorCls get${shortAccCls}_${shortDTCls}(int idx) { | ||
| byte[] buffer = batch.buffers()[columnIndexes[idx]]; | ||
| return new $accessorCls(ByteBuffer.wrap(buffer).order(nativeOrder), | ||
| (${dt.getClass.getName}) columnTypes[idx]); | ||
| } | ||
| """ | ||
| } | ||
|
|
||
| /* 4000 = 64000 bytes / 16 (up to 16 bytes per one call)) */ | ||
| val numberOfStatementsThreshold = 4000 | ||
|
||
| val (initializerAccessorFuncs, initializerAccessorCalls, extractorFuncs, extractorCalls) = | ||
|
||
| if (initializeAccessors.length < numberOfStatementsThreshold) { | ||
| ("", initializeAccessors.mkString("\n"), "", extractors.mkString("\n")) | ||
| } else { | ||
| val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) | ||
| var groupedAccessorsLength = 0 | ||
| val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) | ||
| var groupedExtractorsLength = 0 | ||
| ( | ||
| groupedAccessorsItr.zipWithIndex.map { case (body, i) => | ||
| groupedAccessorsLength += 1 | ||
| s""" | ||
| |private void accessors$i() { | ||
| | ${body.mkString("\n")} | ||
| |} | ||
| """.stripMargin | ||
| }.mkString(""), | ||
| (0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"), | ||
| groupedExtractorsItr.zipWithIndex.map { case (body, i) => | ||
| groupedExtractorsLength += 1 | ||
| s""" | ||
| |private void extractors$i() { | ||
| | ${body.mkString("\n")} | ||
| |} | ||
| """.stripMargin | ||
| }.mkString(""), | ||
| (0 to groupedExtractorsLength - 1).map { i => s"extractors$i();" }.mkString("\n") | ||
| ) | ||
| } | ||
|
|
||
| val code = s""" | ||
| import java.nio.ByteBuffer; | ||
| import java.nio.ByteOrder; | ||
|
|
@@ -130,7 +189,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera | |
| class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} { | ||
|
|
||
| private ByteOrder nativeOrder = null; | ||
| private byte[][] buffers = null; | ||
| private UnsafeRow unsafeRow = new UnsafeRow(); | ||
| private BufferHolder bufferHolder = new BufferHolder(); | ||
| private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); | ||
|
|
@@ -142,15 +200,13 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera | |
| private scala.collection.Iterator input = null; | ||
| private DataType[] columnTypes = null; | ||
| private int[] columnIndexes = null; | ||
| ${classOf[CachedBatch].getName} batch = null; | ||
|
|
||
| ${declareMutableStates(ctx)} | ||
|
|
||
| public SpecificColumnarIterator() { | ||
| this.nativeOrder = ByteOrder.nativeOrder(); | ||
| this.buffers = new byte[${columnTypes.length}][]; | ||
| this.mutableRow = new MutableUnsafeRow(rowWriter); | ||
|
|
||
| ${initMutableStates(ctx)} | ||
| } | ||
|
|
||
| public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { | ||
|
|
@@ -159,6 +215,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera | |
| this.columnIndexes = columnIndexes; | ||
| } | ||
|
|
||
| ${accessorCode.mkString("\n")} | ||
| ${accessorStructCode.mkString("\n")} | ||
|
|
||
| ${initializerAccessorFuncs} | ||
| ${extractorFuncs} | ||
|
|
||
| public boolean hasNext() { | ||
| if (currentRow < numRowsInBatch) { | ||
| return true; | ||
|
|
@@ -167,13 +229,10 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera | |
| return false; | ||
| } | ||
|
|
||
| ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next(); | ||
| batch = (${classOf[CachedBatch].getName}) input.next(); | ||
| currentRow = 0; | ||
| numRowsInBatch = batch.numRows(); | ||
| for (int i = 0; i < columnIndexes.length; i ++) { | ||
| buffers[i] = batch.buffers()[columnIndexes[i]]; | ||
| } | ||
| ${initializeAccessors.mkString("\n")} | ||
| ${initializerAccessorCalls} | ||
|
|
||
| return hasNext(); | ||
| } | ||
|
|
@@ -182,7 +241,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera | |
| currentRow += 1; | ||
| bufferHolder.reset(); | ||
| rowWriter.initialize(bufferHolder, $numFields); | ||
| ${extractors.mkString("\n")} | ||
| ${extractorCalls} | ||
| unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); | ||
| return unsafeRow; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ctx.addFunction ?