Added Support for transform_keys function
codeatri committed Aug 6, 2018
commit 0a19cc44bf694f76f8f1be8faeaa16dc47f9bb86
@@ -444,6 +444,7 @@ object FunctionRegistry {
expression[ArrayTransform]("transform"),
expression[ArrayFilter]("filter"),
expression[ArrayAggregate]("aggregate"),
expression[TransformKeys]("transform_keys"),
CreateStruct.registryEntry,

// misc functions
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

/**
@@ -365,3 +365,69 @@ case class ArrayAggregate(

override def prettyName: String = "aggregate"
}

/**
* Transform Keys in a map using the transform_keys function.
Contributor: maybe a better comment?

*/
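A possible rewording in the spirit of the reviewer's note (an editorial suggestion, not from the commit):

/**
 * Transforms the keys of a map by applying the given function to each
 * (key, value) pair, leaving the values unchanged.
 */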
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k,v) -> k + 1);
Contributor: nit: missing space -> k, v

map(array(2, 3, 4), array(1, 2, 3))
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v);
Member: ditto.

map(array(2, 4, 6), array(1, 2, 3))
""",
since = "2.4.0")
case class TransformKeys(
input: Expression,
function: Expression)
extends ArrayBasedHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

override def dataType: DataType = {
val valueType = input.dataType.asInstanceOf[MapType].valueType
MapType(function.dataType, valueType, input.nullable)
Contributor: I think input.nullable is wrong here. This argument should indicate whether the values can contain null, not whether the returned map itself can be null.

}
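A minimal sketch of the fix the reviewer is pointing at, propagating the input map's valueContainsNull flag instead (this assumes input has already been type-checked as a MapType):

override def dataType: DataType = {
  val MapType(_, valueType, valueContainsNull) = input.dataType
  MapType(function.dataType, valueType, valueContainsNull)
}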

override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction):
TransformKeys = {
val (keyElementType, valueElementType, containsNull) = input.dataType match {
case MapType(keyType, valueType, containsNullValue) =>
(keyType, valueType, containsNullValue)
case _ =>
val MapType(keyType, valueType, containsNullValue) = MapType.defaultConcreteType
(keyType, valueType, containsNullValue)
}
copy(function = f(function, (keyElementType, false) :: (valueElementType, containsNull) :: Nil))
}

@transient lazy val (keyVar, valueVar) = {
val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
(keyVar, valueVar)
}
Member: nit: how about:

@transient lazy val LambdaFunction(_,
  (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

Member: Sorry, I meant we don't need to surround it with:

@transient lazy val (keyVar, valueVar) = {
  ...
  (keyVar, valueVar)
}

just

@transient lazy val LambdaFunction(_,
  (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

should work.


override def eval(input: InternalRow): Any = {
val arr = this.input.eval(input).asInstanceOf[MapData]
if (arr == null) {
null
} else {
val f = functionForEval
val resultKeys = new GenericArrayData(new Array[Any](arr.numElements))
var i = 0
while (i < arr.numElements) {
keyVar.value.set(arr.keyArray().get(i, keyVar.dataType))
valueVar.value.set(arr.valueArray().get(i, valueVar.dataType))
resultKeys.update(i, f.eval(input))
@hvanhovell (Contributor, Aug 7, 2018): This assumes that the transformation will return a unique key, right? If it doesn't, you'll break the map semantics. For example: transform_keys(some_map, (k, v) -> 0)

Contributor: I'm not a fan of duplicated keys either, but other functions transforming maps have the same problem. See the discussions here and here.

Example:

scala> spark.range(1).selectExpr("map(0,1,0,2)").show()
+----------------+
| map(0, 1, 0, 2)|
+----------------+
|[0 -> 1, 0 -> 2]|
+----------------+

Contributor: Maybe I'm missing something, but couldn't f.eval(input) evaluate to null? Keys are not allowed to be null. Other functions usually have a null check and throw a RuntimeException for such cases.
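A minimal sketch of the null-key guard the reviewer describes (the exact message is an assumption, mirroring similar key checks elsewhere in Spark):

// Reject null keys before writing them into the result array.
val result = f.eval(input)
if (result == null) {
  throw new RuntimeException("Cannot use null as map key!")
}
resultKeys.update(i, result)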

i += 1
}
new ArrayBasedMapData(resultKeys, arr.valueArray())
}
}

override def prettyName: String = "transform_keys"
}
@@ -59,6 +59,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
}

def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val valueType = expr.dataType.asInstanceOf[MapType].valueType
val keyType = expr.dataType.asInstanceOf[MapType].keyType
TransformKeys(expr, createLambda(keyType, false, valueType, true, f))
Member: We should use valueContainsNull instead of true?

}
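A sketch of the helper with the reviewer's suggestion applied, destructuring the MapType rather than hard-coding true:

def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  // Propagate the map's own valueContainsNull flag into the lambda binding.
  val MapType(keyType, valueType, valueContainsNull) = expr.dataType
  TransformKeys(expr, createLambda(keyType, false, valueType, valueContainsNull, f))
}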

def aggregate(
expr: Expression,
zero: Expression,
@@ -181,4 +187,46 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
(acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)),
15)
}

test("TransformKeys") {
val ai0 = Literal.create(
Map(1 -> 1, 2 -> 2, 3 -> 3),
Contributor: It's maybe irrelevant but WDYT about adding test cases with null values?

Contributor (author): Thanks for catching this! Included test cases, both here and in DataFrameFunctionsSuite.

MapType(IntegerType, IntegerType))
Member: Can you add valueContainsNull explicitly?
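For illustration, the same literal with the flag spelled out (valueContainsNull = false, since this map holds no null values):

val ai0 = Literal.create(
  Map(1 -> 1, 2 -> 2, 3 -> 3),
  MapType(IntegerType, IntegerType, valueContainsNull = false))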

val ai1 = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType))

val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1
val plusValue: (Expression, Expression) => Expression = (k, v) => k + v

checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3))
checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3))
checkEvaluation(
transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3))
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(
transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])

val as0 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType))
val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType))

val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
val convertKeyAndConcatValue: (Expression, Expression) => Expression =
(k, v) => Length(k) + 1

checkEvaluation(
transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx"))
checkEvaluation(
transformKeys(transformKeys(as0, concatValue), concatValue),
Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx"))
checkEvaluation(transformKeys(asn, concatValue), Map.empty[String, String])
checkEvaluation(
transformKeys(transformKeys(asn, concatValue), convertKeyAndConcatValue),
Map.empty[Int, String])
checkEvaluation(transformKeys(as0, convertKeyAndConcatValue),
Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
checkEvaluation(transformKeys(asn, convertKeyAndConcatValue), Map.empty[Int, String])
}
}
@@ -45,3 +45,17 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as

-- Aggregate a null array
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;

create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
Member: nit:

  (1, map(1, 1, 2, 2, 3, 3)),
  (2, map(4, 4, 5, 5, 6, 6))

as t(x, ys);

-- Identity Transform Keys in a map
select transform_keys(ys, (k, v) -> k) as v from nested;

-- Transform Keys in a map by adding constant
select transform_keys(ys, (k, v) -> k + 1) as v from nested;

-- Transform Keys in a map using values
select transform_keys(ys, (k, v) -> k + v) as v from nested;
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 15
-- Number of queries: 19


-- !query 0
@@ -145,3 +145,40 @@ select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) a
struct<v:int>
-- !query 14 output
NULL


-- !query 15
create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
as t(x, ys)
-- !query 15 schema
struct<>
-- !query 15 output


-- !query 16
select transform_keys(ys, (k, v) -> k) as v from nested
-- !query 16 schema
struct<v:map<int,int>>
-- !query 16 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 17
select transform_keys(ys, (k, v) -> k + 1) as v from nested
-- !query 17 schema
struct<v:map<int,int>>
-- !query 17 output
{2:1,3:2,4:3}
{5:4,6:5,7:6}


-- !query 18
select transform_keys(ys, (k, v) -> k + v) as v from nested
-- !query 18 schema
struct<v:map<int,int>>
-- !query 18 output
{10:5,12:6,8:4}
{2:1,4:2,6:3}
@@ -2071,6 +2071,158 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
}

test("transform keys function - test various primitive data types combinations") {
Member: We don't need so many cases here. We only need to verify the API works end to end. Evaluation checks of the function should be in HigherOrderFunctionsSuite.

val dfExample1 = Seq(
Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
).toDF("i")

val dfExample2 = Seq(
Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c")
).toDF("x")

val dfExample3 = Seq(
Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3)
).toDF("y")

val dfExample4 = Seq(
Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0)
).toDF("z")

val dfExample5 = Seq(
Map[Int, Boolean](25 -> true, 26 -> false)
).toDF("a")

val dfExample6 = Seq(
Map[Int, String](25 -> "ab", 26 -> "cd")
).toDF("b")

val dfExample7 = Seq(
Map[Array[Int], Boolean](Array(1, 2) -> false)
).toDF("c")


def testMapOfPrimitiveTypesCombination(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))

checkAnswer(dfExample2.selectExpr("transform_keys(x, (k, v) -> k + 1)"),
Seq(Row(Map(2 -> "a", 3 -> "b", 4 -> "c"))))

checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> v * v)"),
Seq(Row(Map(1 -> 1, 4 -> 2, 9 -> 3))))

checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> length(k) + v)"),
Seq(Row(Map(2 -> 1, 3 -> 2, 4 -> 3))))

checkAnswer(
dfExample3.selectExpr("transform_keys(y, (k, v) -> concat(k, cast(v as String)))"),
Seq(Row(Map("a1" -> 1, "b2" -> 2, "c3" -> 3))))

checkAnswer(dfExample4.selectExpr("transform_keys(z, " +
"(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"),
Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))))

checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))

checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> k + v)"),
Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))

checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> k % 2 = 0 OR v)"),
Seq(Row(Map(true -> true, true -> false))))

checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"),
Seq(Row(Map(50 -> true, 78 -> false))))


checkAnswer(dfExample6.selectExpr(
"transform_keys(b, (k, v) -> concat(conv(k, 10, 16) , substr(v, 1, 1)))"),
Seq(Row(Map("19a" -> "ab", "1Ac" -> "cd"))))

checkAnswer(dfExample7.selectExpr("transform_keys(c, (k, v) -> array_contains(k, 3) AND v)"),
Seq(Row(Map(false -> false))))
}
// Test with local relation, the Project will be evaluated without codegen
testMapOfPrimitiveTypesCombination()
dfExample1.cache()
dfExample2.cache()
dfExample3.cache()
dfExample4.cache()
dfExample5.cache()
dfExample6.cache()
// Test with cached relation, the Project will be evaluated with codegen
testMapOfPrimitiveTypesCombination()
Contributor: Do we have to do that if the expression implements CodegenFallback?

}

test("transform keys function - test empty") {
val dfExample1 = Seq(
Map.empty[Int, Int]
).toDF("i")

val dfExample2 = Seq(
Map.empty[BigInt, String]
).toDF("j")

def testEmpty(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> NULL)"),
Seq(Row(Map.empty[Null, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k)"),
Seq(Row(Map.empty[Null, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> v)"),
Seq(Row(Map.empty[Null, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 0)"),
Seq(Row(Map.empty[Int, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 'key')"),
Seq(Row(Map.empty[String, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> true)"),
Seq(Row(Map.empty[Boolean, Null])))

checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + cast(v as BIGINT))"),
Seq(Row(Map.empty[BigInt, Null])))

checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> v)"),
Seq(Row(Map())))
}
testEmpty()
dfExample1.cache()
dfExample2.cache()
testEmpty()
}

test("transform keys function - Invalid lambda functions") {
val dfExample1 = Seq(
Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
).toDF("i")

val dfExample2 = Seq(
Map[String, String]("a" -> "b")
).toDF("j")

def testInvalidLambdaFunctions(): Unit = {
val ex1 = intercept[AnalysisException] {
dfExample1.selectExpr("transform_keys(i, k -> k )")
Member: nit: extra space after k -> k.

}
assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))

val ex2 = intercept[AnalysisException] {
dfExample2.selectExpr("transform_keys(j, (k, v, x) -> k + 1)")
}
assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match"))
}

testInvalidLambdaFunctions()
dfExample1.cache()
dfExample2.cache()
testInvalidLambdaFunctions()
Member: We need dfExample3.cache() as well?

Contributor: @ueshin I would like to ask you a generic question regarding higher-order functions. Is it necessary to perform checks with codegen paths if all the newly added functions extend CodegenFallback? Eventually, is there a plan to add codegen for these functions in the future?

}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {