From 0a19cc44bf694f76f8f1be8faeaa16dc47f9bb86 Mon Sep 17 00:00:00 2001 From: codeatri Date: Mon, 6 Aug 2018 11:32:47 -0700 Subject: [PATCH 01/13] Added Support for transform_keys function --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 68 +++++++- .../HigherOrderFunctionsSuite.scala | 48 ++++++ .../inputs/higher-order-functions.sql | 14 ++ .../results/higher-order-functions.sql.out | 39 ++++- .../spark/sql/DataFrameFunctionsSuite.scala | 152 ++++++++++++++++++ 6 files changed, 320 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed2f67da6f2b..2d3b7fedc10b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -444,6 +444,7 @@ object FunctionRegistry { expression[ArrayTransform]("transform"), expression[ArrayFilter]("filter"), expression[ArrayAggregate]("aggregate"), + expression[TransformKeys]("transform_keys"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 20c7f7d43b9d..ee98089233bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** @@ -365,3 +365,69 @@ case class ArrayAggregate( override def prettyName: String = "aggregate" } + +/** + * Transform Keys in a map using the transform_keys function. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k,v) -> k + 1); + map(array(2, 3, 4), array(1, 2, 3)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v); + map(array(2, 4, 6), array(1, 2, 3)) + """, + since = "2.4.0") +case class TransformKeys( + input: Expression, + function: Expression) + extends ArrayBasedHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = input.nullable + + override def dataType: DataType = { + val valueType = input.dataType.asInstanceOf[MapType].valueType + MapType(function.dataType, valueType, input.nullable) + } + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): + TransformKeys = { + val (keyElementType, valueElementType, containsNull) = input.dataType match { + case MapType(keyType, valueType, containsNullValue) => + (keyType, valueType, containsNullValue) + case _ => + val MapType(keyType, valueType, containsNullValue) = MapType.defaultConcreteType + (keyType, valueType, containsNullValue) + } + copy(function = f(function, (keyElementType, false) :: (valueElementType, containsNull) :: Nil)) + } + + @transient lazy val (keyVar, valueVar) = { + val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + (keyVar, valueVar) + } + + override def eval(input: InternalRow): Any = { + val arr = this.input.eval(input).asInstanceOf[MapData] + if (arr == null) { + null + } else { + val f = functionForEval + val resultKeys = new GenericArrayData(new Array[Any](arr.numElements)) + var i = 0 + while (i < arr.numElements) { + keyVar.value.set(arr.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(arr.valueArray().get(i, valueVar.dataType)) + resultKeys.update(i, f.eval(input)) + i += 1 + } + new ArrayBasedMapData(resultKeys, arr.valueArray()) + } + } + + override def prettyName: String = "transform_keys" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 40cfc0ccc7c0..a8d3edd9746d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -59,6 +59,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f)) } + def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val valueType = expr.dataType.asInstanceOf[MapType].valueType + val keyType = expr.dataType.asInstanceOf[MapType].keyType + TransformKeys(expr, createLambda(keyType, false, valueType, true, f)) + } + def aggregate( expr: Expression, zero: Expression, @@ -181,4 +187,46 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), 15) } + + test("TransformKeys") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3), + MapType(IntegerType, IntegerType)) + val ai1 = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1 + val plusValue: (Expression, Expression) => Expression = (k, v) => k + v + + checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3)) + checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3)) + checkEvaluation( + transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3)) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) + checkEvaluation( + transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType)) + val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val convertKeyAndConcatValue: (Expression, Expression) => Expression = + (k, v) => Length(k) + 1 + + checkEvaluation( + transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx")) + checkEvaluation( + transformKeys(transformKeys(as0, concatValue), concatValue), + Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) + checkEvaluation(transformKeys(asn, concatValue), Map.empty[String, String]) + checkEvaluation( + transformKeys(transformKeys(asn, concatValue), convertKeyAndConcatValue), + Map.empty[Int, String]) + checkEvaluation(transformKeys(as0, convertKeyAndConcatValue), + Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + checkEvaluation(transformKeys(asn, convertKeyAndConcatValue), Map.empty[Int, String]) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 136396d9553d..16e75e2565fc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -45,3 +45,17 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as -- Aggregate a null array select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v; + +create or replace temporary view nested as values + (1, map(1,1,2,2,3,3)), + (2, map(4,4,5,5,6,6)) + as t(x, ys); + +-- Identity Transform Keys in a map +select transform_keys(ys, (k, v) -> k) as v from nested; + +-- Transform Keys in a map by adding constant +select transform_keys(ys, (k, v) -> k + 1) as v from nested; + +-- Transform Keys in a map using values +select transform_keys(ys, (k, v) -> k + v) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index e6f62f2e1bb6..23628a9ac549 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 19 -- !query 0 @@ -145,3 +145,40 @@ select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) a struct -- !query 14 output NULL + + +-- !query 15 +create or replace temporary view nested as values + (1, map(1,1,2,2,3,3)), + (2, map(4,4,5,5,6,6)) + as t(x, ys) +-- !query 15 schema +struct<> +-- !query 15 output + + +-- !query 16 +select transform_keys(ys, (k, v) -> k) as v from nested +-- !query 16 schema +struct> +-- !query 16 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 17 +select transform_keys(ys, (k, v) -> k + 1) as v from nested +-- !query 17 schema +struct> +-- !query 17 output +{2:1,3:2,4:3} +{5:4,6:5,7:6} + + +-- !query 18 +select transform_keys(ys, (k, v) -> k + v) as v from nested +-- !query 18 schema +struct> +-- !query 18 output +{10:5,12:6,8:4} +{2:1,4:2,6:3} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3c5831f33b23..ac7cf0fc5216 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2071,6 +2071,158 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) } + test("transform keys function - test various primitive data types combinations") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c") + ).toDF("x") + + val dfExample3 = Seq( + Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3) + ).toDF("y") + + val dfExample4 = Seq( + Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0) + ).toDF("z") + + val dfExample5 = Seq( + Map[Int, Boolean](25 -> true, 26 -> false) + ).toDF("a") + + val dfExample6 = Seq( + Map[Int, String](25 -> "ab", 26 -> "cd") + ).toDF("b") + + val dfExample7 = Seq( + Map[Array[Int], Boolean](Array(1, 2) -> false) + ).toDF("c") + + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + + checkAnswer(dfExample2.selectExpr("transform_keys(x, (k, v) -> k + 1)"), + Seq(Row(Map(2 -> "a", 3 -> "b", 4 -> "c")))) + + checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> v * v)"), + Seq(Row(Map(1 -> 1, 4 -> 2, 9 -> 3)))) + + checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> length(k) + v)"), + Seq(Row(Map(2 -> 1, 3 -> 2, 4 -> 3)))) + + checkAnswer( + dfExample3.selectExpr("transform_keys(y, (k, v) -> concat(k, cast(v as String)))"), + Seq(Row(Map("a1" -> 1, "b2" -> 2, "c3" -> 3)))) + + checkAnswer(dfExample4.selectExpr("transform_keys(z, " + + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + + checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + + checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> k + v)"), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + + checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> k % 2 = 0 OR v)"), + Seq(Row(Map(true -> true, true -> false)))) + + checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"), + Seq(Row(Map(50 -> true, 78 -> false)))) + + checkAnswer(dfExample6.selectExpr( + "transform_keys(b, (k, v) -> concat(conv(k, 10, 16) , substr(v, 1, 1)))"), + Seq(Row(Map("19a" -> "ab", "1Ac" -> "cd")))) + + checkAnswer(dfExample7.selectExpr("transform_keys(c, (k, v) -> array_contains(k, 3) AND v)"), + Seq(Row(Map(false -> false)))) + } + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + dfExample5.cache() + dfExample6.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform keys function - test empty") { + val dfExample1 = Seq( + Map.empty[Int, Int] + ).toDF("i") + + val dfExample2 = Seq( + Map.empty[BigInt, String] + ).toDF("j") + + def testEmpty(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> NULL)"), + Seq(Row(Map.empty[Null, Null]))) + + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k)"), + Seq(Row(Map.empty[Null, Null]))) + + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> v)"), + Seq(Row(Map.empty[Null, Null]))) + + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 0)"), + Seq(Row(Map.empty[Int, Null]))) + + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 'key')"), + Seq(Row(Map.empty[String, Null]))) + + checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> true)"), + Seq(Row(Map.empty[Boolean, Null]))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + cast(v as BIGINT))"), + Seq(Row(Map.empty[BigInt, Null]))) + + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> v)"), + Seq(Row(Map()))) + } + testEmpty() + dfExample1.cache() + dfExample2.cache() + testEmpty() + } + + test("transform keys function - Invalid lambda functions") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[String, String]("a" -> "b") + ).toDF("j") + + def testInvalidLambdaFunctions(): Unit = { + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, k -> k )") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_keys(j, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + } + + testInvalidLambdaFunctions() + dfExample1.cache() + dfExample2.cache() + testInvalidLambdaFunctions() + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 5806ac46707772fd1e4befa445157ed0f9c75084 Mon Sep 17 00:00:00 2001 From: codeatri Date: Tue, 7 Aug 2018 20:06:11 -0700 Subject: [PATCH 02/13] Addressed Review Commenst --- .../expressions/higherOrderFunctions.scala | 54 +++++++++---------- .../HigherOrderFunctionsSuite.scala | 8 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 39 +++++++++++++- 3 files changed, 66 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index f616edc9d8c2..0c33b868af85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -444,13 +444,14 @@ case class ArrayAggregate( } /** - * Transform Keys in a map using the transform_keys function. + * Transform Keys for every entry of the map by applying the transform_keys function. + * Returns map with transformed key entries */ @ExpressionDescription( usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k,v) -> k + 1); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + 1); map(array(2, 3, 4), array(1, 2, 3)) > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v); map(array(2, 4, 6), array(1, 2, 3)) @@ -459,27 +460,22 @@ case class ArrayAggregate( case class TransformKeys( input: Expression, function: Expression) - extends ArrayBasedHigherOrderFunction with CodegenFallback { + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { override def nullable: Boolean = input.nullable override def dataType: DataType = { - val valueType = input.dataType.asInstanceOf[MapType].valueType - MapType(function.dataType, valueType, input.nullable) + val map = input.dataType.asInstanceOf[MapType] + MapType(function.dataType, map.valueType, map.valueContainsNull) } override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) - override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): - TransformKeys = { - val (keyElementType, valueElementType, containsNull) = input.dataType match { - case MapType(keyType, valueType, containsNullValue) => - (keyType, valueType, containsNullValue) - case _ => - val MapType(keyType, valueType, containsNullValue) = MapType.defaultConcreteType - (keyType, valueType, containsNullValue) - } - copy(function = f(function, (keyElementType, false) :: (valueElementType, containsNull) :: Nil)) + @transient val (keyType, valueType, valueContainsNull) = + HigherOrderFunction.mapKeyValueArgumentType(input.dataType) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } @transient lazy val (keyVar, valueVar) = { @@ -488,22 +484,22 @@ case class TransformKeys( (keyVar, valueVar) } - override def eval(input: InternalRow): Any = { - val arr = this.input.eval(input).asInstanceOf[MapData] - if (arr == null) { - null - } else { - val f = functionForEval - val resultKeys = new GenericArrayData(new Array[Any](arr.numElements)) - var i = 0 - while (i < arr.numElements) { - keyVar.value.set(arr.keyArray().get(i, keyVar.dataType)) - valueVar.value.set(arr.valueArray().get(i, valueVar.dataType)) - resultKeys.update(i, f.eval(input)) - i += 1 + override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { + val map = value.asInstanceOf[MapData] + val f = functionForEval + val resultKeys = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + val result = f.eval(inputRow) + if (result == null) { + throw new RuntimeException("Cannot use null as map key!") } - new ArrayBasedMapData(resultKeys, arr.valueArray()) + resultKeys.update(i, f.eval(inputRow)) + i += 1 } + new ArrayBasedMapData(resultKeys, map.valueArray()) } override def prettyName: String = "transform_keys" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index fbc59c947452..245670c1f6a9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -262,7 +262,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType)) val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) - val convertKeyAndConcatValue: (Expression, Expression) => Expression = + val convertKeyToKeyLength: (Expression, Expression) => Expression = (k, v) => Length(k) + 1 checkEvaluation( @@ -272,10 +272,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) checkEvaluation(transformKeys(asn, concatValue), Map.empty[String, String]) checkEvaluation( - transformKeys(transformKeys(asn, concatValue), convertKeyAndConcatValue), + transformKeys(transformKeys(asn, concatValue), convertKeyToKeyLength), Map.empty[Int, String]) - checkEvaluation(transformKeys(as0, convertKeyAndConcatValue), + checkEvaluation(transformKeys(as0, convertKeyToKeyLength), Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) - checkEvaluation(transformKeys(asn, convertKeyAndConcatValue), Map.empty[Int, String]) + checkEvaluation(transformKeys(asn, convertKeyToKeyLength), Map.empty[Int, String]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 50e1384ae540..272c185dc2f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2242,7 +2242,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { testEmpty() } - test("transform keys function - Invalid lambda functions") { + test("transform keys function - Invalid lambda functions and exceptions") { val dfExample1 = Seq( Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) ).toDF("i") @@ -2251,6 +2251,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Map[String, String]("a" -> "b") ).toDF("j") + val dfExample3 = Seq( + Map[String, String]("a" -> null) + ).toDF("x") + def testInvalidLambdaFunctions(): Unit = { val ex1 = intercept[AnalysisException] { dfExample1.selectExpr("transform_keys(i, k -> k )") @@ -2260,7 +2264,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val ex2 = intercept[AnalysisException] { dfExample2.selectExpr("transform_keys(j, (k, v, x) -> k + 1)") } - assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + assert(ex2.getMessage.contains( + "The number of lambda function arguments '3' does not match")) } testInvalidLambdaFunctions() @@ -2269,6 +2274,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { testInvalidLambdaFunctions() } + test("transform keys function - test null") { + val dfExample1 = Seq( + Map[Boolean, Integer](true -> 1, false -> null) + ).toDF("a") + + def testNullValues(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(a, (k, v) -> if(k, NOT k, v IS NULL))"), + Seq(Row(Map(false -> 1, true -> null)))) + } + + testNullValues() + dfExample1.cache() + testNullValues() + } + + test("transform keys function - test duplicate keys") { + val dfExample1 = Seq( + Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c", 4 -> "d") + ).toDF("a") + + def testNullValues(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_keys(a, (k, v) -> k%3)"), + Seq(Row(Map(1 -> "a", 2 -> "b", 0 -> "c", 1 -> "d")))) + } + + testNullValues() + dfExample1.cache() + testNullValues() + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 150a6a5c405c78e7a5f7dd9b3f3c72f95290ec71 Mon Sep 17 00:00:00 2001 From: codeatri Date: Tue, 7 Aug 2018 20:26:15 -0700 Subject: [PATCH 03/13] Additional Tests containing nulls --- .../expressions/HigherOrderFunctionsSuite.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 245670c1f6a9..6977299bddda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -244,6 +244,9 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val ai1 = Literal.create( Map.empty[Int, Int], MapType(IntegerType, IntegerType)) + val ai2 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType)) val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1 val plusValue: (Expression, Expression) => Expression = (k, v) => k + v @@ -256,9 +259,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation( transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int]) + checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) + checkEvaluation( + transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) val as0 = Literal.create( Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), MapType(StringType, StringType)) val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType)) val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) @@ -276,6 +284,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Map.empty[Int, String]) checkEvaluation(transformKeys(as0, convertKeyToKeyLength), Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) + checkEvaluation(transformKeys(as1, convertKeyToKeyLength), + Map(2 -> "xy", 3 -> "yz", 4 -> null)) checkEvaluation(transformKeys(asn, convertKeyToKeyLength), Map.empty[Int, String]) } } From 9f6a8abae75b70c5be89c6bbccf3a574bd7fb17d Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 8 Aug 2018 11:57:39 -0700 Subject: [PATCH 04/13] nit: style --- .../spark/sql/catalyst/expressions/higherOrderFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 0c33b868af85..ea3ecda15919 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -496,7 +496,7 @@ case class TransformKeys( if (result == null) { throw new RuntimeException("Cannot use null as map key!") } - resultKeys.update(i, f.eval(inputRow)) + resultKeys.update(i, result) i += 1 } new ArrayBasedMapData(resultKeys, map.valueArray()) From 652663077e383f8b188743c4494d697e34d5d02c Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 8 Aug 2018 13:01:22 -0700 Subject: [PATCH 05/13] added additional unit test --- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 272c185dc2f0..847b5cd18fed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2266,6 +2266,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } assert(ex2.getMessage.contains( "The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[RuntimeException] { + dfExample3.selectExpr("transform_keys(x, (k, v) -> v)").show() + } + assert(ex3.getMessage.contains("Cannot use null as map key!")) + + } testInvalidLambdaFunctions() From f7fd2313dddfea3555bda61fc96339c24afb71b0 Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 8 Aug 2018 13:03:45 -0700 Subject: [PATCH 06/13] nit:space --- .../scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 847b5cd18fed..a4c92155a88f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2271,8 +2271,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { dfExample3.selectExpr("transform_keys(x, (k, v) -> v)").show() } assert(ex3.getMessage.contains("Cannot use null as map key!")) - - } testInvalidLambdaFunctions() From 1cbaf0c6adc508299d42a82628f4f0954bed7a95 Mon Sep 17 00:00:00 2001 From: codeatri Date: Thu, 9 Aug 2018 11:32:59 -0700 Subject: [PATCH 07/13] addressed review comments --- .../expressions/higherOrderFunctions.scala | 4 +- .../HigherOrderFunctionsSuite.scala | 43 ++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 125 ++---------------- 3 files changed, 44 insertions(+), 128 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index ea3ecda15919..9d8dd13ea6b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -451,9 +451,9 @@ case class ArrayAggregate( usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + 1); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1); map(array(2, 3, 4), array(1, 2, 3)) - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); map(array(2, 4, 6), array(1, 2, 3)) """, since = "2.4.0") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 6977299bddda..982a08bed198 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData import org.apache.spark.sql.types._ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -60,9 +61,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val valueType = expr.dataType.asInstanceOf[MapType].valueType - val keyType = expr.dataType.asInstanceOf[MapType].keyType - TransformKeys(expr, createLambda(keyType, false, valueType, true, f)) + val map = expr.dataType.asInstanceOf[MapType] + TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) } def aggregate( @@ -239,22 +239,26 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("TransformKeys") { val ai0 = Literal.create( - Map(1 -> 1, 2 -> 2, 3 -> 3), - MapType(IntegerType, IntegerType)) + Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) val ai1 = Literal.create( Map.empty[Int, Int], - MapType(IntegerType, IntegerType)) + MapType(IntegerType, IntegerType, valueContainsNull = true)) val ai2 = Literal.create( Map(1 -> 1, 2 -> null, 3 -> 3), - MapType(IntegerType, IntegerType)) + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1 val plusValue: (Expression, Expression) => Expression = (k, v) => k + v + val modKey: (Expression, Expression) => Expression = (k, v) => k % 3 - checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3)) - checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3)) + checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4)) + checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4)) checkEvaluation( - transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3)) + transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) + checkEvaluation(transformKeys(ai0, modKey), + ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation( @@ -262,12 +266,18 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3)) checkEvaluation( transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3)) + checkEvaluation(transformKeys(ai3, plusOne), null) val as0 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType)) + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) val as1 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), MapType(StringType, StringType)) - val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType)) + Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(null, + MapType(StringType, StringType, valueContainsNull = false)) + val asn = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) val convertKeyToKeyLength: (Expression, Expression) => Expression = @@ -286,6 +296,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) checkEvaluation(transformKeys(as1, convertKeyToKeyLength), Map(2 -> "xy", 3 -> "yz", 4 -> null)) + checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) checkEvaluation(transformKeys(asn, convertKeyToKeyLength), Map.empty[Int, String]) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a4c92155a88f..d4a0dd7f6090 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2123,71 +2123,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ).toDF("i") val dfExample2 = Seq( - Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c") - ).toDF("x") - - val dfExample3 = Seq( - Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3) - ).toDF("y") - - val dfExample4 = Seq( Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0) - ).toDF("z") + ).toDF("j") - val dfExample5 = Seq( + val dfExample3 = Seq( Map[Int, Boolean](25 -> true, 26 -> false) - ).toDF("a") - - val dfExample6 = Seq( - Map[Int, String](25 -> "ab", 26 -> "cd") - ).toDF("b") + ).toDF("x") - val dfExample7 = Seq( + val dfExample4 = Seq( Map[Array[Int], Boolean](Array(1, 2) -> false) - ).toDF("c") + ).toDF("y") def testMapOfPrimitiveTypesCombination(): Unit = { checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) - checkAnswer(dfExample2.selectExpr("transform_keys(x, (k, v) -> k + 1)"), - Seq(Row(Map(2 -> "a", 3 -> "b", 4 -> "c")))) - - checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> v * v)"), - Seq(Row(Map(1 -> 1, 4 -> 2, 9 -> 3)))) - - checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> length(k) + v)"), - Seq(Row(Map(2 -> 1, 3 -> 2, 4 -> 3)))) - - checkAnswer( - dfExample3.selectExpr("transform_keys(y, (k, v) -> concat(k, cast(v as String)))"), - Seq(Row(Map("a1" -> 1, "b2" -> 2, "c3" -> 3)))) - - checkAnswer(dfExample4.selectExpr("transform_keys(z, " + + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) - checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) - checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> k + v)"), + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) - checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> k % 2 = 0 OR v)"), + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), Seq(Row(Map(true -> true, true -> false)))) - checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"), + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), Seq(Row(Map(50 -> true, 78 -> false)))) - checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"), + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), Seq(Row(Map(50 -> true, 78 -> false)))) - checkAnswer(dfExample6.selectExpr( - "transform_keys(b, (k, v) -> concat(conv(k, 10, 16) , substr(v, 1, 1)))"), - Seq(Row(Map("19a" -> "ab", "1Ac" -> "cd")))) - - checkAnswer(dfExample7.selectExpr("transform_keys(c, (k, v) -> array_contains(k, 3) AND v)"), + checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), Seq(Row(Map(false -> false)))) } // Test with local relation, the Project will be evaluated without codegen @@ -2196,52 +2167,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { dfExample2.cache() dfExample3.cache() dfExample4.cache() - dfExample5.cache() - dfExample6.cache() // Test with cached relation, the Project will be evaluated with codegen testMapOfPrimitiveTypesCombination() } - test("transform keys function - test empty") { - val dfExample1 = Seq( - Map.empty[Int, Int] - ).toDF("i") - - val dfExample2 = Seq( - Map.empty[BigInt, String] - ).toDF("j") - - def testEmpty(): Unit = { - checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> NULL)"), - Seq(Row(Map.empty[Null, Null]))) - - checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k)"), - Seq(Row(Map.empty[Null, Null]))) - - checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> v)"), - Seq(Row(Map.empty[Null, Null]))) - - checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 0)"), - Seq(Row(Map.empty[Int, Null]))) - - checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 'key')"), - Seq(Row(Map.empty[String, Null]))) - - checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> true)"), - Seq(Row(Map.empty[Boolean, Null]))) - - checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + cast(v as BIGINT))"), - Seq(Row(Map.empty[BigInt, Null]))) - - checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> v)"), - Seq(Row(Map()))) - } - testEmpty() - dfExample1.cache() - dfExample2.cache() - testEmpty() - } - test("transform keys function - Invalid lambda functions and exceptions") { val dfExample1 = Seq( Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) @@ -2279,36 +2208,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { testInvalidLambdaFunctions() } - test("transform keys function - test null") { - val dfExample1 = Seq( - Map[Boolean, Integer](true -> 1, false -> null) - ).toDF("a") - - def testNullValues(): Unit = { - checkAnswer(dfExample1.selectExpr("transform_keys(a, (k, v) -> if(k, NOT k, v IS NULL))"), - Seq(Row(Map(false -> 1, true -> null)))) - } - - testNullValues() - dfExample1.cache() - testNullValues() - } - - test("transform keys function - test duplicate keys") { - val dfExample1 = Seq( - Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c", 4 -> "d") - ).toDF("a") - - def testNullValues(): Unit = { - checkAnswer(dfExample1.selectExpr("transform_keys(a, (k, v) -> k%3)"), - Seq(Row(Map(1 -> "a", 2 -> "b", 0 -> "c", 1 -> "d")))) - } - - testNullValues() - dfExample1.cache() - testNullValues() - } - private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From 621213dd1658fbc8cb19e15dd77c9c389653d4db Mon Sep 17 00:00:00 2001 From: codeatri Date: Tue, 14 Aug 2018 12:05:40 -0700 Subject: [PATCH 08/13] fix style --- .../sql/catalyst/expressions/HigherOrderFunctionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 325685041405..92004742fb32 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -357,7 +357,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z")) } - + test("MapZipWith") { def map_zip_with( left: Expression, From 5db526be7bad0fa38dc9743c919014b475cf8aeb Mon Sep 17 00:00:00 2001 From: codeatri Date: Tue, 14 Aug 2018 16:15:32 -0700 Subject: [PATCH 09/13] Merge master Refactoring changes --- .../expressions/higherOrderFunctions.scala | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index d3bb9a56f709..43f3858dbc4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -512,21 +512,18 @@ case class ArrayAggregate( """, since = "2.4.0") case class TransformKeys( - input: Expression, + argument: Expression, function: Expression) extends MapBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable override def dataType: DataType = { - val map = input.dataType.asInstanceOf[MapType] + val map = argument.dataType.asInstanceOf[MapType] MapType(function.dataType, map.valueType, map.valueContainsNull) } - override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) - - @transient val (keyType, valueType, valueContainsNull) = - HigherOrderFunction.mapKeyValueArgumentType(input.dataType) + @transient val MapType(keyType, valueType, valueContainsNull) = argument.dataType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) @@ -538,8 +535,8 @@ case class TransformKeys( (keyVar, valueVar) } - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val map = value.asInstanceOf[MapData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] val f = functionForEval val resultKeys = new GenericArrayData(new Array[Any](map.numElements)) var i = 0 From e5d9b051b027cf86fbcd82701f54e50f1aeac7f6 Mon Sep 17 00:00:00 2001 From: codeatri Date: Tue, 14 Aug 2018 23:36:53 -0700 Subject: [PATCH 10/13] review comments --- .../expressions/higherOrderFunctions.scala | 8 +++--- .../HigherOrderFunctionsSuite.scala | 8 +++--- .../inputs/higher-order-functions.sql | 4 +-- .../results/higher-order-functions.sql.out | 4 +-- .../spark/sql/DataFrameFunctionsSuite.scala | 27 ++++++++++--------- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 43f3858dbc4f..338331881d10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -523,14 +523,14 @@ case class TransformKeys( MapType(function.dataType, map.valueType, map.valueContainsNull) } - @transient val MapType(keyType, valueType, valueContainsNull) = argument.dataType + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } @transient lazy val (keyVar, valueVar) = { - val LambdaFunction( + @transient lazy val LambdaFunction( _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function (keyVar, valueVar) } @@ -544,7 +544,7 @@ case class TransformKeys( keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) val result = f.eval(inputRow) - if (result == null) { + if (result == null) { throw new RuntimeException("Cannot use null as map key!") } resultKeys.update(i, result) @@ -554,7 +554,7 @@ case class TransformKeys( } override def prettyName: String = "transform_keys" - } +} /** * Merges two given maps into a single map by applying function to the pair of values with diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 92004742fb32..12ef01816835 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -328,7 +328,7 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper MapType(StringType, StringType, valueContainsNull = true)) val as2 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) - val asn = Literal.create(Map.empty[StringType, StringType], + val as3 = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType, valueContainsNull = true)) val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) @@ -340,16 +340,16 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( transformKeys(transformKeys(as0, concatValue), concatValue), Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx")) - checkEvaluation(transformKeys(asn, concatValue), Map.empty[String, String]) + checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String]) checkEvaluation( - transformKeys(transformKeys(asn, concatValue), convertKeyToKeyLength), + transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength), Map.empty[Int, String]) checkEvaluation(transformKeys(as0, convertKeyToKeyLength), Map(2 -> "xy", 3 -> "yz", 4 -> "zx")) checkEvaluation(transformKeys(as1, convertKeyToKeyLength), Map(2 -> "xy", 3 -> "yz", 4 -> null)) checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null) - checkEvaluation(transformKeys(asn, convertKeyToKeyLength), Map.empty[Int, String]) + checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String]) val ax0 = Literal.create( Map(1 -> "x", 2 -> "y", 3 -> "z"), diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index f23505a33747..9a8454455ae7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -53,8 +53,8 @@ select exists(ys, y -> y > 30) as v from nested; select exists(cast(null as array), y -> y > 30) as v; create or replace temporary view nested as values - (1, map(1,1,2,2,3,3)), - (2, map(4,4,5,5,6,6)) + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys); -- Identity Transform Keys in a map diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 1a5896983133..b77bda7bb267 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -167,8 +167,8 @@ NULL -- !query 17 create or replace temporary view nested as values - (1, map(1,1,2,2,3,3)), - (2, map(4,4,5,5,6,6)) + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys) -- !query 17 schema struct<> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 12fad0e4db09..f5e9983340b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2302,13 +2302,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) } - test("transform keys function - test various primitive data types combinations") { + test("transform keys function - primitive data types") { val dfExample1 = Seq( Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) ).toDF("i") val dfExample2 = Seq( - Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0) + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) ).toDF("j") val dfExample3 = Seq( @@ -2357,34 +2357,37 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("transform keys function - Invalid lambda functions and exceptions") { + val dfExample1 = Seq( - Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + Map[String, String]("a" -> null) ).toDF("i") val dfExample2 = Seq( - Map[String, String]("a" -> "b") + Seq(1, 2, 3, 4) ).toDF("j") - val dfExample3 = Seq( - Map[String, String]("a" -> null) - ).toDF("x") - def testInvalidLambdaFunctions(): Unit = { val ex1 = intercept[AnalysisException] { - dfExample1.selectExpr("transform_keys(i, k -> k )") + dfExample1.selectExpr("transform_keys(i, k -> k)") } assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) val ex2 = intercept[AnalysisException] { - dfExample2.selectExpr("transform_keys(j, (k, v, x) -> k + 1)") + dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)") } assert(ex2.getMessage.contains( - "The number of lambda function arguments '3' does not match")) + "The number of lambda function arguments '3' does not match")) val ex3 = intercept[RuntimeException] { - dfExample3.selectExpr("transform_keys(x, (k, v) -> v)").show() + dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() } assert(ex3.getMessage.contains("Cannot use null as map key!")) + + val ex4 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") + } + assert(ex4.getMessage.contains( + "data type mismatch: argument 1 requires map type")) } testInvalidLambdaFunctions() From fb885f4797e72d0c2cbfa23980199c71e0c5aaee Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 15 Aug 2018 11:12:38 -0700 Subject: [PATCH 11/13] review comments --- .../expressions/higherOrderFunctions.scala | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 338331881d10..a05afe31acc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -518,22 +518,19 @@ case class TransformKeys( override def nullable: Boolean = argument.nullable + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType + override def dataType: DataType = { - val map = argument.dataType.asInstanceOf[MapType] - MapType(function.dataType, map.valueType, map.valueContainsNull) + MapType(function.dataType, valueType, valueContainsNull) } - @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType - override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } - @transient lazy val (keyVar, valueVar) = { - @transient lazy val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function - (keyVar, valueVar) - } + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] From 58b60b2f851fb1464743257fe1cca075a1e77ba9 Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 15 Aug 2018 12:43:28 -0700 Subject: [PATCH 12/13] fix unit test --- .../spark/sql/DataFrameFunctionsSuite.scala | 41 ++++++++----------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f5e9983340b8..22f191209f87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2366,34 +2366,27 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(1, 2, 3, 4) ).toDF("j") - def testInvalidLambdaFunctions(): Unit = { - val ex1 = intercept[AnalysisException] { - dfExample1.selectExpr("transform_keys(i, k -> k)") - } - assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) - - val ex2 = intercept[AnalysisException] { - dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)") - } - assert(ex2.getMessage.contains( - "The number of lambda function arguments '3' does not match")) + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, k -> k)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) - val ex3 = intercept[RuntimeException] { - dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() - } - assert(ex3.getMessage.contains("Cannot use null as map key!")) + val ex2 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains( + "The number of lambda function arguments '3' does not match")) - val ex4 = intercept[AnalysisException] { - dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") - } - assert(ex4.getMessage.contains( - "data type mismatch: argument 1 requires map type")) + val ex3 = intercept[RuntimeException] { + dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() } + assert(ex3.getMessage.contains("Cannot use null as map key!")) - testInvalidLambdaFunctions() - dfExample1.cache() - dfExample2.cache() - testInvalidLambdaFunctions() + val ex4 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)") + } + assert(ex4.getMessage.contains( + "data type mismatch: argument 1 requires map type")) } private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { From 2f4943f3cec0705c296b2988c415ac3372b7ea86 Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 15 Aug 2018 12:51:39 -0700 Subject: [PATCH 13/13] review --- .../sql/catalyst/expressions/higherOrderFunctions.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index a05afe31acc7..a305a05add7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -520,9 +520,7 @@ case class TransformKeys( @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType - override def dataType: DataType = { - MapType(function.dataType, valueType, valueContainsNull) - } + override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) @@ -534,13 +532,12 @@ case class TransformKeys( override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] - val f = functionForEval val resultKeys = new GenericArrayData(new Array[Any](map.numElements)) var i = 0 while (i < map.numElements) { keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) - val result = f.eval(inputRow) + val result = functionForEval.eval(inputRow) if (result == null) { throw new RuntimeException("Cannot use null as map key!") }