19 changes: 19 additions & 0 deletions python/pyspark/sql/functions.py
@@ -1819,6 +1819,25 @@ def create_map(*cols):
    return Column(jc)


@since(2.4)
def map_from_arrays(col1, col2):
    """Creates a new map from two arrays: the first array supplies the keys, the second
    the values. Both arrays must have the same length.

    :param col1: name of column containing an array of keys. All elements must be non-null
    :param col2: name of column containing an array of values

    >>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
    >>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
    +----------------+
    |             map|
    +----------------+
    |[2 -> a, 5 -> b]|
    +----------------+
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2)))


@since(1.4)
def array(*cols):
    """Creates a new array column.
@@ -417,6 +417,7 @@ object FunctionRegistry {
    expression[CreateMap]("map"),
    expression[CreateNamedStruct]("named_struct"),
    expression[ElementAt]("element_at"),
    expression[MapFromArrays]("map_from_arrays"),
    expression[MapKeys]("map_keys"),
    expression[MapValues]("map_values"),
    expression[MapEntries]("map_entries"),
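With the registry entry above in place, the expression is resolvable by name from SQL. A minimal spark-shell sketch (a hypothetical Spark 2.4+ session; the exact output framing may vary):

spark.sql("SELECT map_from_arrays(array(1.0, 3.0), array('2', '4')) AS m").show(false)
// +--------------------+
// |m                   |
// +--------------------+
// |[1.0 -> 2, 3.0 -> 4]|
// +--------------------+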
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -236,6 +236,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
  override def prettyName: String = "map"
}

/**
 * Returns a Catalyst Map whose keys and values come from the two child array expressions.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements
      in keys should not be null""",

Contributor (inline review comment): and duplicated.

  examples = """
    Examples:
      > SELECT _FUNC_([1.0, 3.0], ['2', '4']);
       {1.0:"2",3.0:"4"}
  """,
  since = "2.4.0")
case class MapFromArrays(left: Expression, right: Expression)
  extends BinaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

  override def dataType: DataType = {
    MapType(
      keyType = left.dataType.asInstanceOf[ArrayType].elementType,
      valueType = right.dataType.asInstanceOf[ArrayType].elementType,
      valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
  }

  override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
    val keyArrayData = keyArray.asInstanceOf[ArrayData]
    val valueArrayData = valueArray.asInstanceOf[ArrayData]
    if (keyArrayData.numElements != valueArrayData.numElements) {
      throw new RuntimeException("The given two arrays should have the same length")
    }
    val leftArrayType = left.dataType.asInstanceOf[ArrayType]
    if (leftArrayType.containsNull) {
      var i = 0
      while (i < keyArrayData.numElements) {
        if (keyArrayData.isNullAt(i)) {
          throw new RuntimeException("Cannot use null as map key!")
        }
        i += 1
      }
    }
    new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy())
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => {
      val arrayBasedMapData = classOf[ArrayBasedMapData].getName
      val leftArrayType = left.dataType.asInstanceOf[ArrayType]
      val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else {
        val i = ctx.freshName("i")
        s"""
           |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
           |  if ($keyArrayData.isNullAt($i)) {
           |    throw new RuntimeException("Cannot use null as map key!");
           |  }
           |}
         """.stripMargin
      }
      s"""
         |if ($keyArrayData.numElements() != $valueArrayData.numElements()) {
         |  throw new RuntimeException("The given two arrays should have the same length");
         |}
         |$keyArrayElemNullCheck
         |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy());
       """.stripMargin
    })
  }

  override def prettyName: String = "map_from_arrays"
}
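To make the dataType derivation above concrete: the key type is taken from the left array, while both the value type and valueContainsNull are taken from the right array. A small sketch against Catalyst Literals (assuming the same imports as the test suite below):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

val keys = Literal.create(Seq(1, 2), ArrayType(IntegerType, containsNull = false))
val values = Literal.create(Seq("a", null), ArrayType(StringType, containsNull = true))
MapFromArrays(keys, values).dataType
// => MapType(IntegerType, StringType, valueContainsNull = true)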

/**
* An expression representing a not yet available attribute name. This expression is unevaluable
* and as its name suggests it is a temporary place holder until we're able to determine the
Expand Down
@@ -186,6 +186,50 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
    }
  }

test("MapFromArrays") {
def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
// catalyst map is order-sensitive, so we create ListMap here to preserve the elements order.
scala.collection.immutable.ListMap(keys.zip(values): _*)
}

val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val strSeq = intSeq.map(_.toString)
val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25)
val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25)
val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_))

val intArray = Literal.create(intSeq, ArrayType(IntegerType, false))
val longArray = Literal.create(longSeq, ArrayType(LongType, false))
val strArray = Literal.create(strSeq, ArrayType(StringType, false))

val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true))
val intWithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true))
val longWithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true))

val nullArray = Literal.create(null, ArrayType(StringType, false))

checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))

checkEvaluation(
MapFromArrays(strArray, intWithNullArray), createMap(strSeq, intWithNullSeq))
checkEvaluation(
MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
checkEvaluation(
MapFromArrays(strArray, longWithNullArray), createMap(strSeq, longWithNullSeq))
checkEvaluation(MapFromArrays(nullArray, nullArray), null)

intercept[RuntimeException] {
checkEvaluation(MapFromArrays(intWithNullArray, strArray), null)
}
intercept[RuntimeException] {
checkEvaluation(
MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
}
}

test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1070,6 +1070,17 @@ object functions {
  @scala.annotation.varargs
  def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }

  /**
   * Creates a new map column. The array in the first column is used for the keys and the
   * array in the second column for the values. All elements in the key array must be non-null.

Contributor (inline review comment): and duplicated

   *
   * @group normal_funcs
   * @since 2.4.0
   */
  def map_from_arrays(keys: Column, values: Column): Column = withExpr {
    MapFromArrays(keys.expr, values.expr)
  }
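A minimal DataFrame-side sketch of the new builder (a hypothetical Spark 2.4+ session; array and lit assemble the key and value arrays inline):

import org.apache.spark.sql.functions.{array, lit, map_from_arrays}

val df = spark.range(1).select(
  map_from_arrays(array(lit(2), lit(5)), array(lit("a"), lit("b"))).as("map"))
df.show()
// +----------------+
// |             map|
// +----------------+
// |[2 -> a, 5 -> b]|
// +----------------+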

  /**
   * Marks a DataFrame as small enough for use in broadcast joins.
   *
@@ -62,6 +62,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
    assert(row.getMap[Int, String](0) === Map(2 -> "a"))
  }

test("map with arrays") {
val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v")
val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
val row = df1.select(map_from_arrays($"k", $"v")).first()
assert(row.schema(0).dataType === expectedType)
assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b"))
checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b"))))

val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))

val df3 = Seq((null, null)).toDF("k", "v")
checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))

val df4 = Seq((1, "a")).toDF("k", "v")
intercept[AnalysisException] {
df4.select(map_from_arrays($"k", $"v"))
}

val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
intercept[RuntimeException] {
df5.select(map_from_arrays($"k", $"v")).collect
}

val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")
intercept[RuntimeException] {
df6.select(map_from_arrays($"k", $"v")).collect
}
}

test("struct with column name") {
val df = Seq((1, "str")).toDF("a", "b")
val row = df.select(struct("a", "b")).first()
Expand Down