-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23935][SQL] Adding map_entries function #21236
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
086e223
b9e2409
4739977
d05ad9b
6aa90ef
56ff20a
1bd0d5e
baa61e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -118,6 +118,162 @@ case class MapValues(child: Expression) | |
| override def prettyName: String = "map_values" | ||
| } | ||
|
|
||
| /** | ||
| * Returns an unordered array of all entries in the given map. | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(map(1, 'a', 2, 'b')); | ||
| [(1,"a"),(2,"b")] | ||
| """, | ||
| since = "2.4.0") | ||
| case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(MapType) | ||
|
|
||
| lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] | ||
|
|
||
| override def dataType: DataType = { | ||
| ArrayType( | ||
| StructType( | ||
| StructField("key", childDataType.keyType, false) :: | ||
| StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: | ||
| Nil), | ||
| false) | ||
| } | ||
|
|
||
| override protected def nullSafeEval(input: Any): Any = { | ||
| val childMap = input.asInstanceOf[MapData] | ||
| val keys = childMap.keyArray() | ||
| val values = childMap.valueArray() | ||
| val length = childMap.numElements() | ||
| val resultData = new Array[AnyRef](length) | ||
| var i = 0; | ||
| while (i < length) { | ||
| val key = keys.get(i, childDataType.keyType) | ||
| val value = values.get(i, childDataType.valueType) | ||
| val row = new GenericInternalRow(Array[Any](key, value)) | ||
| resultData.update(i, row) | ||
| i += 1 | ||
| } | ||
| new GenericArrayData(resultData) | ||
| } | ||
|
|
||
| override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| nullSafeCodeGen(ctx, ev, c => { | ||
| val numElements = ctx.freshName("numElements") | ||
| val keys = ctx.freshName("keys") | ||
| val values = ctx.freshName("values") | ||
| val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) | ||
| val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) | ||
| val code = if (isKeyPrimitive && isValuePrimitive) { | ||
| genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) | ||
| } else { | ||
| genCodeForAnyElements(ctx, keys, values, ev.value, numElements) | ||
| } | ||
| s""" | ||
| |final int $numElements = $c.numElements(); | ||
| |final ArrayData $keys = $c.keyArray(); | ||
| |final ArrayData $values = $c.valueArray(); | ||
| |$code | ||
| """.stripMargin | ||
| }) | ||
| } | ||
|
|
||
| private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") | ||
|
|
||
| private def getValue(varName: String) = { | ||
| CodeGenerator.getValue(varName, childDataType.valueType, "z") | ||
| } | ||
|
|
||
| private def genCodeForPrimitiveElements( | ||
| ctx: CodegenContext, | ||
| keys: String, | ||
| values: String, | ||
| arrayData: String, | ||
| numElements: String): String = { | ||
| val byteArraySize = ctx.freshName("byteArraySize") | ||
| val data = ctx.freshName("byteArray") | ||
| val unsafeRow = ctx.freshName("unsafeRow") | ||
| val structSize = ctx.freshName("structSize") | ||
| val unsafeArrayData = ctx.freshName("unsafeArrayData") | ||
| val structsOffset = ctx.freshName("structsOffset") | ||
| val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray" | ||
| val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" | ||
|
|
||
| val baseOffset = Platform.BYTE_ARRAY_OFFSET | ||
| val longSize = LongType.defaultSize | ||
| val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) | ||
| val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) | ||
|
|
||
| val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" | ||
| val valueAssignmentChecked = if (childDataType.valueContainsNull) { | ||
| s""" | ||
| |if ($values.isNullAt(z)) { | ||
| | $unsafeRow.setNullAt(1); | ||
| |} else { | ||
| | $valueAssignment | ||
| |} | ||
| """.stripMargin | ||
| } else { | ||
| valueAssignment | ||
| } | ||
|
|
||
| s""" | ||
| |final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2}; | ||
| |final long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize); | ||
| |final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; | ||
|
||
| |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | ||
| | ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)} | ||
|
||
| |} else { | ||
| | final byte[] $data = new byte[(int)$byteArraySize]; | ||
| | UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); | ||
| | Platform.putLong($data, $baseOffset, $numElements); | ||
| | $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); | ||
| | UnsafeRow $unsafeRow = new UnsafeRow(2); | ||
| | for (int z = 0; z < $numElements; z++) { | ||
| | long offset = $structsOffset + z * $structSize; | ||
|
||
| | $unsafeArrayData.setLong(z, (offset << 32) + $structSize); | ||
| | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize); | ||
| | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); | ||
| | $valueAssignmentChecked | ||
| | } | ||
| | $arrayData = $unsafeArrayData; | ||
| |} | ||
| """.stripMargin | ||
| } | ||
|
|
||
| private def genCodeForAnyElements( | ||
| ctx: CodegenContext, | ||
| keys: String, | ||
| values: String, | ||
| arrayData: String, | ||
| numElements: String): String = { | ||
| val genericArrayClass = classOf[GenericArrayData].getName | ||
| val rowClass = classOf[GenericInternalRow].getName | ||
| val data = ctx.freshName("internalRowArray") | ||
|
|
||
| val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) | ||
| val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { | ||
| s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" | ||
| } else { | ||
| getValue(values) | ||
| } | ||
|
||
|
|
||
| s""" | ||
| |final Object[] $data = new Object[$numElements]; | ||
| |for (int z = 0; z < $numElements; z++) { | ||
| | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); | ||
| |} | ||
| |$arrayData = new $genericArrayClass($data); | ||
| """.stripMargin | ||
| } | ||
|
|
||
| override def prettyName: String = "map_entries" | ||
| } | ||
|
|
||
| /** | ||
| * Sorts the input array in ascending / descending order according to the natural ordering of | ||
| * the array elements and returns it. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { | |
| if (expected.isNaN) result.isNaN else expected == result | ||
| case (result: Float, expected: Float) => | ||
| if (expected.isNaN) result.isNaN else expected == result | ||
| case (result: UnsafeRow, expected: GenericInternalRow) => | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mn-mikke I was just looking over compiler warnings, and noticed it claims this case is never triggered. I think it's because it would always first match the (InternalRow, InternalRow) case above. Should it go before that then?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Roger that, looks like Wenchen just did so. Thanks! |
||
| val structType = exprDataType.asInstanceOf[StructType] | ||
| result.toSeq(structType) == expected.toSeq(structType) | ||
| case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) | ||
| case _ => | ||
| result == expected | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can calculate
structSizebeforehand and inline it?