Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ object FunctionRegistry {
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayAggregate]("aggregate"),
expression[TransformKeys]("transform_keys"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,65 @@ case class ArrayAggregate(

override def prettyName: String = "aggregate"
}

/**
* Transform Keys for every entry of the map by applying the transform_keys function.
* Returns map with transformed key entries
*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + 1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we need one more right parenthesis after the second array(1, 2, 3)?

map(array(2, 3, 4), array(1, 2, 3))
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

map(array(2, 4, 6), array(1, 2, 3))
""",
since = "2.4.0")
case class TransformKeys(
input: Expression,
function: Expression)
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

override def dataType: DataType = {
val map = input.dataType.asInstanceOf[MapType]
MapType(function.dataType, map.valueType, map.valueContainsNull)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use valueType and valueContainsNull from the following val?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about this?

}

override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)

@transient val (keyType, valueType, valueContainsNull) =
HigherOrderFunction.mapKeyValueArgumentType(input.dataType)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = {
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}

@transient lazy val (keyVar, valueVar) = {
val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
(keyVar, valueVar)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: how about:

@transient lazy val LambdaFunction(_,
  (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant we don't need to surround by:

@transient lazy val (keyVar, valueVar) = {
  ...
  (keyVar, valueVar)
}

just

@transient lazy val LambdaFunction(_,
  (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

should work.


override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val map = value.asInstanceOf[MapData]
val f = functionForEval
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use functionForEval directly?

val resultKeys = new GenericArrayData(new Array[Any](map.numElements))
var i = 0
while (i < map.numElements) {
keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
val result = f.eval(inputRow)
if (result == null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: extra space between == and null.

throw new RuntimeException("Cannot use null as map key!")
}
resultKeys.update(i, result)
i += 1
}
new ArrayBasedMapData(resultKeys, map.valueArray())
}

override def prettyName: String = "transform_keys"
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
}

def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val valueType = expr.dataType.asInstanceOf[MapType].valueType
val keyType = expr.dataType.asInstanceOf[MapType].keyType
TransformKeys(expr, createLambda(keyType, false, valueType, true, f))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use valueContainsNull instead of true?

}

def aggregate(
expr: Expression,
zero: Expression,
Expand Down Expand Up @@ -230,4 +236,56 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
(acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)),
15)
}

test("TransformKeys") {
val ai0 = Literal.create(
Map(1 -> 1, 2 -> 2, 3 -> 3),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's maybe irrelevant but WDYT about adding test cases with null values?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this!
Included test cases, both here and in DataFrameFunctionsSuite.

MapType(IntegerType, IntegerType))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add valueContainsNull explicitly?

val ai1 = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType))
val ai2 = Literal.create(
Map(1 -> 1, 2 -> null, 3 -> 3),
MapType(IntegerType, IntegerType))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add tests for Literal.create(null, MapType(IntegerType, IntegerType))?


val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1
val plusValue: (Expression, Expression) => Expression = (k, v) => k + v

checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3))
checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3))
checkEvaluation(
transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3))
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(
transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3))
checkEvaluation(
transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3))

val as0 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType))
val as1 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), MapType(StringType, StringType))
val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType))

val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
val convertKeyToKeyLength: (Expression, Expression) => Expression =
(k, v) => Length(k) + 1

checkEvaluation(
transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx"))
checkEvaluation(
transformKeys(transformKeys(as0, concatValue), concatValue),
Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx"))
checkEvaluation(transformKeys(asn, concatValue), Map.empty[String, String])
checkEvaluation(
transformKeys(transformKeys(asn, concatValue), convertKeyToKeyLength),
Map.empty[Int, String])
checkEvaluation(transformKeys(as0, convertKeyToKeyLength),
Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
checkEvaluation(transformKeys(as1, convertKeyToKeyLength),
Map(2 -> "xy", 3 -> "yz", 4 -> null))
checkEvaluation(transformKeys(asn, convertKeyToKeyLength), Map.empty[Int, String])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,17 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as

-- Aggregate a null array
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;

create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

  (1, map(1, 1, 2, 2, 3, 3)),
  (2, map(4, 4, 5, 5, 6, 6))

as t(x, ys);

-- Identity Transform Keys in a map
select transform_keys(ys, (k, v) -> k) as v from nested;

-- Transform Keys in a map by adding constant
select transform_keys(ys, (k, v) -> k + 1) as v from nested;

-- Transform Keys in a map using values
select transform_keys(ys, (k, v) -> k + v) as v from nested;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 15
-- Number of queries: 19


-- !query 0
Expand Down Expand Up @@ -145,3 +145,40 @@ select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) a
struct<v:int>
-- !query 14 output
NULL


-- !query 15
create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
as t(x, ys)
-- !query 15 schema
struct<>
-- !query 15 output


-- !query 16
select transform_keys(ys, (k, v) -> k) as v from nested
-- !query 16 schema
struct<v:map<int,int>>
-- !query 16 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 17
select transform_keys(ys, (k, v) -> k + 1) as v from nested
-- !query 17 schema
struct<v:map<int,int>>
-- !query 17 output
{2:1,3:2,4:3}
{5:4,6:5,7:6}


-- !query 18
select transform_keys(ys, (k, v) -> k + v) as v from nested
-- !query 18 schema
struct<v:map<int,int>>
-- !query 18 output
{10:5,12:6,8:4}
{2:1,4:2,6:3}
Loading