20 changes: 20 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2273,6 +2273,26 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))


@since(2.4)
def map_entries(col):
"""
Collection function: Returns an unordered array of all entries in the given map.

:param col: name of column or expression

>>> from pyspark.sql.functions import map_entries
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
>>> df.select(map_entries("data").alias("entries")).show()
+----------------+
|         entries|
+----------------+
|[[1, a], [2, b]]|
+----------------+
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.map_entries(_to_java_column(col)))


# ---------------------------- User Defined Function ----------------------------------

class PandasUDFType(object):
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -408,6 +408,7 @@ object FunctionRegistry {
expression[ElementAt]("element_at"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
expression[Size]("size"),
expression[Size]("cardinality"),
expression[SortArray]("sort_array"),
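With the expression registered under its SQL name, map_entries resolves through FunctionRegistry for SQL queries as well. A minimal usage sketch (assuming an active SparkSession named spark; the output shape is transcribed from the PySpark docstring above):

// Hedged sketch: exercises the registered SQL name end to end
val entriesDF = spark.sql("SELECT map_entries(map(1, 'a', 2, 'b')) AS entries")
entriesDF.show(truncate = false)
// +----------------+
// |entries         |
// +----------------+
// |[[1, a], [2, b]]|
// +----------------+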
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -118,6 +118,162 @@ case class MapValues(child: Expression)
override def prettyName: String = "map_values"
}

/**
* Returns an unordered array of all entries in the given map.
*/
@ExpressionDescription(
usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
[(1,"a"),(2,"b")]
""",
since = "2.4.0")
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]

override def dataType: DataType = {
ArrayType(
StructType(
StructField("key", childDataType.keyType, false) ::
StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
Nil),
false)
}

override protected def nullSafeEval(input: Any): Any = {
val childMap = input.asInstanceOf[MapData]
val keys = childMap.keyArray()
val values = childMap.valueArray()
val length = childMap.numElements()
val resultData = new Array[AnyRef](length)
var i = 0
while (i < length) {
val key = keys.get(i, childDataType.keyType)
val value = values.get(i, childDataType.valueType)
val row = new GenericInternalRow(Array[Any](key, value))
resultData.update(i, row)
i += 1
}
new GenericArrayData(resultData)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val numElements = ctx.freshName("numElements")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val code = if (isKeyPrimitive && isValuePrimitive) {
genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
} else {
genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
}
s"""
|final int $numElements = $c.numElements();
|final ArrayData $keys = $c.keyArray();
|final ArrayData $values = $c.valueArray();
|$code
""".stripMargin
})
}

private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")

private def getValue(varName: String) = {
CodeGenerator.getValue(varName, childDataType.valueType, "z")
}

private def genCodeForPrimitiveElements(
ctx: CodegenContext,
keys: String,
values: String,
arrayData: String,
numElements: String): String = {
val byteArraySize = ctx.freshName("byteArraySize")
val data = ctx.freshName("byteArray")
val unsafeRow = ctx.freshName("unsafeRow")
val structSize = ctx.freshName("structSize")
val unsafeArrayData = ctx.freshName("unsafeArrayData")
val structsOffset = ctx.freshName("structsOffset")
val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val longSize = LongType.defaultSize
val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)

val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
val valueAssignmentChecked = if (childDataType.valueContainsNull) {
s"""
|if ($values.isNullAt(z)) {
| $unsafeRow.setNullAt(1);
|} else {
| $valueAssignment
|}
""".stripMargin
} else {
valueAssignment
}

s"""
|final int $structSize = ${UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2};
Review comment (Member): We can calculate structSize beforehand and inline it?
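A sketch of what that suggestion could look like (an assumption about the reviewer's intent, not part of the patch): compute the size once as a Scala constant, so every existing s"...$structSize..." interpolation splices the literal value into the generated Java and the final int declaration disappears.

// Sketch (assumption): structSize as a codegen-time Scala constant
val structSize: Int = UnsafeRow.calculateBitSetWidthInBytes(2) + LongType.defaultSize * 2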

|final long $byteArraySize = $calculateArraySize($numElements, $longSize + $structSize);
|final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
Review comment (Member): We can move this into else-clause?

|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)}
Review comment (Member): Hmm, should we use this idiom for other array functions? WDYT?

Reply (Contributor, author): For now, I have separated out the logic so that I can reuse it for the map_from_entries function. It should also be possible to replace UnsafeArrayData.createUnsafeArray with this logic, but I will do that in a different PR.

|} else {
| final byte[] $data = new byte[(int)$byteArraySize];
| UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
| Platform.putLong($data, $baseOffset, $numElements);
| $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
| UnsafeRow $unsafeRow = new UnsafeRow(2);
| for (int z = 0; z < $numElements; z++) {
| long offset = $structsOffset + z * $structSize;
Review comment (Member): nit: $structSize -> ${$structSize}L
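The nit appears to be about promoting the multiplication to long arithmetic before it can overflow as an int; a minimal, self-contained illustration with made-up values:

// Illustration (hypothetical values): z * structSize is evaluated in int
// arithmetic and only then widened to long, so it can overflow first.
val structSize = 24
val z = 100000000                        // pretend 1e8 entries
val wrong: Long = z * structSize         // int overflow: -1894967296
val right: Long = z * structSize.toLong  // widen first:   2400000000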

| $unsafeArrayData.setLong(z, (offset << 32) + $structSize);
| $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
| $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
| $valueAssignmentChecked
| }
| $arrayData = $unsafeArrayData;
|}
""".stripMargin
}

private def genCodeForAnyElements(
ctx: CodegenContext,
keys: String,
values: String,
arrayData: String,
numElements: String): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val rowClass = classOf[GenericInternalRow].getName
val data = ctx.freshName("internalRowArray")

val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
} else {
getValue(values)
}
Review comment (Member): nit: indent

s"""
|final Object[] $data = new Object[$numElements];
|for (int z = 0; z < $numElements; z++) {
| $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
|}
|$arrayData = new $genericArrayClass($data);
""".stripMargin
}

override def prettyName: String = "map_entries"
}

/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapValues(m2), null)
}

test("MapEntries") {
def r(values: Any*): InternalRow = create_row(values: _*)

// Primitive-type keys/values
val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType))
val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType))
val mi2 = Literal.create(null, MapType(IntegerType, IntegerType))

checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2)))
checkEvaluation(MapEntries(mi1), Seq.empty)
checkEvaluation(MapEntries(mi2), null)

// Non-primitive-type keys/values
val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType))
val ms1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
val ms2 = Literal.create(null, MapType(StringType, StringType))

checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null)))
checkEvaluation(MapEntries(ms1), Seq.empty)
checkEvaluation(MapEntries(ms2), null)
}

test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
if (expected.isNaN) result.isNaN else expected == result
case (result: UnsafeRow, expected: GenericInternalRow) =>
Review comment (Member): @mn-mikke I was just looking over compiler warnings and noticed the compiler claims this case is never triggered. I think that's because such values would always match the (InternalRow, InternalRow) case above first. Should this case go before that one, then?

Reply (Contributor, author): Hi @srowen, the (InternalRow, InternalRow) case was introduced later in 21838 and covers the logic of the UnsafeRow case, so we can just remove the unreachable piece of code.

Reply (Member): Roger that, looks like Wenchen just did so. Thanks!

val structType = exprDataType.asInstanceOf[StructType]
result.toSeq(structType) == expected.toSeq(structType)
case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
case _ =>
result == expected
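The unreachable-case issue from the thread above is easy to reproduce in isolation. In this self-contained sketch (the types are hypothetical, not Spark's), the compiler flags the second case as unreachable because the more general pair pattern above it always matches first:

sealed trait Base
final class SubA extends Base
final class SubB extends Base

def check(result: Base, expected: Base): String = (result, expected) match {
  case (_: Base, _: Base) => "general case"  // subsumes every pair...
  case (_: SubA, _: SubB) => "specific case" // ...so this never triggers
}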
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3392,6 +3392,13 @@ object functions {
*/
def map_values(e: Column): Column = withExpr { MapValues(e.expr) }

/**
* Returns an unordered array of all entries in the given map.
* @group collection_funcs
* @since 2.4.0
*/
def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }

// scalastyle:off line.size.limit
// scalastyle:off parameter.number

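A short usage sketch of the new Scala API (assuming an active SparkSession named spark with its implicits imported; the show output mirrors the PySpark docstring above):

import org.apache.spark.sql.functions.map_entries
import spark.implicits._  // assumes an active SparkSession named `spark`

val df = Seq(Map(1 -> "a", 2 -> "b")).toDF("data")
df.select(map_entries($"data").as("entries")).show()
// +----------------+
// |         entries|
// +----------------+
// |[[1, a], [2, b]]|
// +----------------+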
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -381,6 +381,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("map_entries") {
val dummyFilter = (c: Column) => c.isNotNull || c.isNull

// Primitive-type elements
val idf = Seq(
Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300),
Map[Int, Int](),
null
).toDF("m")
val iExpected = Seq(
Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))),
Row(Seq.empty),
Row(null)
)

checkAnswer(idf.select(map_entries('m)), iExpected)
checkAnswer(idf.selectExpr("map_entries(m)"), iExpected)
checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected)
checkAnswer(
spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"),
Seq(Row(Seq(Row(1, null), Row(2, null)))))
checkAnswer(
spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"),
Seq(Row(Seq(Row(1, null), Row(2, null)))))

// Non-primitive-type elements
val sdf = Seq(
Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"),
Map[String, String]("a" -> null, "b" -> null),
Map[String, String](),
null
).toDF("m")
val sExpected = Seq(
Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))),
Row(Seq(Row("a", null), Row("b", null))),
Row(Seq.empty),
Row(null)
)

checkAnswer(sdf.select(map_entries('m)), sExpected)
checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected)
checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected)
}

test("array contains function") {
val df = Seq(
(Seq[Int](1, 2), "x"),