-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23935][SQL] Adding map_entries function #21236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
086e223
b9e2409
4739977
d05ad9b
6aa90ef
56ff20a
1bd0d5e
baa61e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder | |
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.Platform | ||
| import org.apache.spark.unsafe.array.ByteArrayMethods | ||
| import org.apache.spark.unsafe.types.{ByteArray, UTF8String} | ||
|
|
||
|
|
@@ -118,6 +119,162 @@ case class MapValues(child: Expression) | |
| override def prettyName: String = "map_values" | ||
| } | ||
|
|
||
| /** | ||
| * Returns an unordered array of all entries in the given map. | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(map(1, 'a', 2, 'b')); | ||
| [(1,"a"),(2,"b")] | ||
| """, | ||
| since = "2.4.0") | ||
| case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(MapType) | ||
|
|
||
| lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] | ||
|
|
||
| override def dataType: DataType = { | ||
| ArrayType( | ||
| StructType( | ||
| StructField("key", childDataType.keyType, false) :: | ||
| StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: | ||
| Nil), | ||
| false) | ||
| } | ||
|
|
||
| override protected def nullSafeEval(input: Any): Any = { | ||
| val childMap = input.asInstanceOf[MapData] | ||
| val keys = childMap.keyArray() | ||
| val values = childMap.valueArray() | ||
| val length = childMap.numElements() | ||
| val resultData = new Array[AnyRef](length) | ||
| var i = 0; | ||
| while (i < length) { | ||
| val key = keys.get(i, childDataType.keyType) | ||
| val value = values.get(i, childDataType.valueType) | ||
| val row = new GenericInternalRow(Array[Any](key, value)) | ||
| resultData.update(i, row) | ||
| i += 1 | ||
| } | ||
| new GenericArrayData(resultData) | ||
| } | ||
|
|
||
| override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| nullSafeCodeGen(ctx, ev, c => { | ||
| val numElements = ctx.freshName("numElements") | ||
| val keys = ctx.freshName("keys") | ||
| val values = ctx.freshName("values") | ||
| val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) | ||
| val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) | ||
| val code = if (isKeyPrimitive && isValuePrimitive) { | ||
| genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) | ||
| } else { | ||
| genCodeForAnyElements(ctx, keys, values, ev.value, numElements) | ||
| } | ||
| s""" | ||
| |final int $numElements = $c.numElements(); | ||
| |final ArrayData $keys = $c.keyArray(); | ||
| |final ArrayData $values = $c.valueArray(); | ||
| |$code | ||
| """.stripMargin | ||
| }) | ||
| } | ||
|
|
||
| private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") | ||
|
|
||
| private def getValue(varName: String) = { | ||
| CodeGenerator.getValue(varName, childDataType.valueType, "z") | ||
| } | ||
|
|
||
| private def genCodeForPrimitiveElements( | ||
| ctx: CodegenContext, | ||
| keys: String, | ||
| values: String, | ||
| arrayData: String, | ||
| numElements: String): String = { | ||
| val byteArraySize = ctx.freshName("byteArraySize") | ||
| val data = ctx.freshName("byteArray") | ||
| val unsafeRow = ctx.freshName("unsafeRow") | ||
| val unsafeArrayData = ctx.freshName("unsafeArrayData") | ||
| val structsOffset = ctx.freshName("structsOffset") | ||
| val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray" | ||
| val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" | ||
|
|
||
| val baseOffset = Platform.BYTE_ARRAY_OFFSET | ||
| val longSize = LongType.defaultSize | ||
| val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2 | ||
| val structSizeAsLong = structSize + "L" | ||
| val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) | ||
| val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) | ||
|
|
||
| val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" | ||
| val valueAssignmentChecked = if (childDataType.valueContainsNull) { | ||
| s""" | ||
| |if ($values.isNullAt(z)) { | ||
| | $unsafeRow.setNullAt(1); | ||
| |} else { | ||
| | $valueAssignment | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| valueAssignment | ||
| } | ||
|
|
||
| s""" | ||
| |final long $byteArraySize = $calculateArraySize($numElements, ${longSize + structSize}); | ||
| |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | ||
| | ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)} | ||
|
||
| |} else { | ||
| | final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; | ||
| | final byte[] $data = new byte[(int)$byteArraySize]; | ||
| | UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); | ||
| | Platform.putLong($data, $baseOffset, $numElements); | ||
| | $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); | ||
| | UnsafeRow $unsafeRow = new UnsafeRow(2); | ||
| | for (int z = 0; z < $numElements; z++) { | ||
| | long offset = $structsOffset + z * $structSizeAsLong; | ||
| | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); | ||
| | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize); | ||
| | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); | ||
| | $valueAssignmentChecked | ||
| | } | ||
| | $arrayData = $unsafeArrayData; | ||
| |} | ||
| """.stripMargin | ||
| } | ||
|
|
||
| private def genCodeForAnyElements( | ||
| ctx: CodegenContext, | ||
| keys: String, | ||
| values: String, | ||
| arrayData: String, | ||
| numElements: String): String = { | ||
| val genericArrayClass = classOf[GenericArrayData].getName | ||
| val rowClass = classOf[GenericInternalRow].getName | ||
| val data = ctx.freshName("internalRowArray") | ||
|
|
||
| val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) | ||
| val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { | ||
| s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" | ||
| } else { | ||
| getValue(values) | ||
| } | ||
|
||
|
|
||
| s""" | ||
| |final Object[] $data = new Object[$numElements]; | ||
| |for (int z = 0; z < $numElements; z++) { | ||
| | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); | ||
| |} | ||
| |$arrayData = new $genericArrayClass($data); | ||
| """.stripMargin | ||
| } | ||
|
|
||
| override def prettyName: String = "map_entries" | ||
| } | ||
|
|
||
| /** | ||
| * Common base class for [[SortArray]] and [[ArraySort]]. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { | |
| if (expected.isNaN) result.isNaN else expected == result | ||
| case (result: Float, expected: Float) => | ||
| if (expected.isNaN) result.isNaN else expected == result | ||
| case (result: UnsafeRow, expected: GenericInternalRow) => | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mn-mikke I was just looking over compiler warnings, and noticed it claims this case is never triggered. I think it's because it would always first match the (InternalRow, InternalRow) case above. Should it go before that then?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Roger that, looks like Wenchen just did so. Thanks! |
||
| val structType = exprDataType.asInstanceOf[StructType] | ||
| result.toSeq(structType) == expected.toSeq(structType) | ||
| case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) | ||
| case _ => | ||
| result == expected | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering it is right to use
longSizehere?I know the value is
8and is same as the word size, but feel like the meaning is different?cc @gatorsmile @cloud-fan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ueshin Really good question. I'm eager to learn about the true purpose of the
DataType.defaultSizefunction. Currently, it's used in this meaning at more places (e.g.GenArrayData.genCodeToCreateArrayDataandCodeGenerator.createUnsafeArray.)What about using
Long.BYTESfrom Java 8 instead?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMHO,
8is the better choice since it is not related to an element size oflong.To my best guess, it would be best to define a new constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kiszk Thanks for your suggestion, but it seems to me that
LongType.defaultSizecould be used in this case. It seems that the purpose ofdefaultSizeis not only the calculation of estimated data size in statistics.GenerateUnsafeProjection.writeArrayToBuffer,InterpretedUnsafeProjection.getElementSizeand other parts utilizedefaultSizein the same way.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not for the element size of arrays. I agree with @kiszk to use
8.Maybe we need to add a constant to represent the word size in
UnsafeRowor somewhere in the future pr.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh OK, I misunderstood the comments. Thanks guys!