Skip to content
Closed
Prev Previous commit
Next Next commit
address review comments
  • Loading branch information
kiszk committed Jun 8, 2018
commit 228fcc66e2b85b957833da739a20229867d51cbc
4 changes: 2 additions & 2 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1821,14 +1821,14 @@ def create_map(*cols):

@ignore_unicode_prefix
@since(2.4)
def create_map_from_arrays(col1, col2):
def map_from_arrays(col1, col2):
"""Creates a new map from two arrays.

:param col1: name of column containing a set of keys. All elements should not be null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and duplicated?

:param col2: name of column containing a set of values

>>> df = spark.createDataFrame([([2, 5], ["Alice", "Bob"])], ['k', 'v'])
>>> df.select(create_map_from_arrays(df.k, df.v).alias("map")).collect()
>>> df.select(map_from_arrays(df.k, df.v).alias("map")).collect()
[Row(map={2: u'Alice', 5: u'Bob'})]
"""
sc = SparkContext._active_spark_context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,9 +415,9 @@ object FunctionRegistry {
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
expression[CreateMap]("map"),
expression[CreateMapFromArrays]("map_from_arrays"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
expression[MapFromArrays]("map_from_arrays"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,29 +248,18 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
> SELECT _FUNC_([1.0, 3.0], ['2', '4']);
{1.0:"2",3.0:"4"}
""", since = "2.4.0")
case class CreateMapFromArrays(left: Expression, right: Expression)
case class MapFromArrays(left: Expression, right: Expression)
extends BinaryExpression with ExpectsInputTypes {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indent


override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (ArrayType(_, _), ArrayType(_, _)) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure("The given two arguments should be an array")
}
}

override def dataType: DataType = {
MapType(
keyType = left.dataType.asInstanceOf[ArrayType].elementType,
valueType = right.dataType.asInstanceOf[ArrayType].elementType,
valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
}

override def nullable: Boolean = left.nullable || right.nullable

override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
val keyArrayData = keyArray.asInstanceOf[ArrayData]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you detect duplicities first?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please let us know where this specification is described or derived from? It is not written here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it's not specified, duplicated keys can lead to non-determinism of returned values in the future. Currently, GetMapValueUtil.getValueEval returns a value for the first key in the map, but there is a TODO to change the O(n) algorithm. So I'm wondering how it would behave if some hashing were introduced.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. According to the current Spark implementation, for example, CreateMap allows us to have duplicated keys.
It would be good to discuss such a behavior change in another PR. WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we don't have to change it now. But I would like to agree on a consistent approach for the new functions, since this is also related to SPARK-23934 and SPARK-23936.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to err on the safe side here. CreateMap should be fixed IMO.

val valueArrayData = valueArray.asInstanceOf[ArrayData]
Expand All @@ -279,8 +268,12 @@ case class CreateMapFromArrays(left: Expression, right: Expression)
}
val leftArrayType = left.dataType.asInstanceOf[ArrayType]
if (leftArrayType.containsNull) {
if (keyArrayData.toArray(leftArrayType.elementType).contains(null)) {
throw new RuntimeException("Cannot use null as map key!")
var i = 0
while (i < keyArrayData.numElements) {
if (keyArrayData.isNullAt(i)) {
throw new RuntimeException("Cannot use null as map key!")
}
i += 1
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use loop to null-check without converting to object array?

}
new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy())
Expand All @@ -291,13 +284,10 @@ case class CreateMapFromArrays(left: Expression, right: Expression)
val arrayBasedMapData = classOf[ArrayBasedMapData].getName
val leftArrayType = left.dataType.asInstanceOf[ArrayType]
val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else {
val leftArrayTypeTerm = ctx.addReferenceObj("leftArrayType", leftArrayType.elementType)
val array = ctx.freshName("array")
val i = ctx.freshName("i")
s"""
|Object[] $array = $keyArrayData.toObjectArray($leftArrayTypeTerm);
|for (int $i = 0; $i < $array.length; $i++) {
| if ($array[$i] == null) {
|for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
| if ($keyArrayData.isNullAt($i)) {
| throw new RuntimeException("Cannot use null as map key!");
| }
|}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can null-check without converting to object array.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, thanks

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, I realized we have to evaluate each element as CreateMap does. I think that we have to update eval and codegen.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, but I couldn't get it. I might miss something, but I thought we can simply do like:

for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
  if ($keyArrayData.isNullAt($i)) {
    throw new RuntimeException("Cannot use null as map key!");
  }
}

Doesn't this work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code should work if we evaluate each element to make isNullAt() valid.

I think my mistake was that we do not currently evaluate each element in keyArrayData and valueArrayData.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. An array has been evaluated.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("CreateMapFromArrays") {
test("MapFromArrays") {
def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
// catalyst map is order-sensitive, so we create ListMap here to preserve the elements order.
scala.collection.immutable.ListMap(keys.zip(values): _*)
Expand All @@ -209,24 +209,24 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {

val nullArray = Literal.create(null, ArrayType(StringType, false))

checkEvaluation(CreateMapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
checkEvaluation(CreateMapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
checkEvaluation(CreateMapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))
checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))

checkEvaluation(
CreateMapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq))
MapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq))
checkEvaluation(
CreateMapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
checkEvaluation(
CreateMapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
checkEvaluation(CreateMapFromArrays(nullArray, nullArray), null)
MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
checkEvaluation(MapFromArrays(nullArray, nullArray), null)

intercept[RuntimeException] {
checkEvaluation(CreateMapFromArrays(intwithNullArray, strArray), null)
checkEvaluation(MapFromArrays(intwithNullArray, strArray), null)
}
intercept[RuntimeException] {
checkEvaluation(
CreateMapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ object functions {
* @since 2.4
*/
def map_from_arrays(keys: Column, values: Column): Column = withExpr {
CreateMapFromArrays(keys.expr, values.expr)
MapFromArrays(keys.expr, values.expr)
}

/**
Expand Down