addressed review comments
codeatri committed Aug 9, 2018
commit 1cbaf0c6adc508299d42a82628f4f0954bed7a95
@@ -451,9 +451,9 @@ case class ArrayAggregate(
usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + 1);
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
map(array(2, 3, 4), array(1, 2, 3))
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v);
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
map(array(2, 4, 6), array(1, 2, 3))
""",
since = "2.4.0")
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types._

class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -60,9 +61,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
}

def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val valueType = expr.dataType.asInstanceOf[MapType].valueType
val keyType = expr.dataType.asInstanceOf[MapType].keyType
TransformKeys(expr, createLambda(keyType, false, valueType, true, f))
val map = expr.dataType.asInstanceOf[MapType]
TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
}

def aggregate(
@@ -239,35 +239,45 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper

test("TransformKeys") {
val ai0 = Literal.create(
Map(1 -> 1, 2 -> 2, 3 -> 3),
MapType(IntegerType, IntegerType))
Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val ai1 = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType))
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai2 = Literal.create(
Map(1 -> 1, 2 -> null, 3 -> 3),
MapType(IntegerType, IntegerType))
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1
val plusValue: (Expression, Expression) => Expression = (k, v) => k + v
val modKey: (Expression, Expression) => Expression = (k, v) => k % 3

checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3))
checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3))
checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4))
checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4))
checkEvaluation(
transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3))
transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
checkEvaluation(transformKeys(ai0, modKey),
ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4)))
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(
transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3))
checkEvaluation(
transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3))
checkEvaluation(transformKeys(ai3, plusOne), null)

val as0 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType))
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
MapType(StringType, StringType, valueContainsNull = false))
val as1 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), MapType(StringType, StringType))
val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType))
Map("a" -> "xy", "bb" -> "yz", "ccc" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val as2 = Literal.create(null,
MapType(StringType, StringType, valueContainsNull = false))
val asn = Literal.create(Map.empty[StringType, StringType],
MapType(StringType, StringType, valueContainsNull = true))
Review comment (Member): as3?

val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
val convertKeyToKeyLength: (Expression, Expression) => Expression =
@@ -286,6 +296,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
checkEvaluation(transformKeys(as1, convertKeyToKeyLength),
Map(2 -> "xy", 3 -> "yz", 4 -> null))
checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null)
checkEvaluation(transformKeys(asn, convertKeyToKeyLength), Map.empty[Int, String])

val ax0 = Literal.create(
Map(1 -> "x", 2 -> "y", 3 -> "z"),
MapType(IntegerType, StringType, valueContainsNull = false))

checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
}
}
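One detail in the test above that is easy to miss: the expected value for the modKey case is written with ArrayBasedMapData rather than a Scala Map, because k % 3 sends the keys 1 and 4 to the same new key and a Scala Map cannot hold that duplicate. A tiny standalone illustration of the difference, for intuition only (not part of the patch):

object DuplicateKeyIntuition extends App {
  // k % 3 over the keys 1, 2, 3, 4 yields 1, 2, 0, 1 -- the key 1 repeats.
  val keys = Seq(1, 2, 3, 4).map(_ % 3)   // Seq(1, 2, 0, 1)
  val values = Seq(1, 2, 3, 4)

  // A Scala Map keeps only one entry per key, so the collision is silently collapsed.
  println(keys.zip(values).toMap.size)    // 3 -- the duplicate key 1 loses an entry

  // Parallel key/value sequences (the shape ArrayBasedMapData stores) keep all four
  // pairs, which is why the test spells the modKey expectation with ArrayBasedMapData.
  println(keys.zip(values))               // List((1,1), (2,2), (0,3), (1,4))
}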
@@ -2123,71 +2123,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
).toDF("i")

val dfExample2 = Seq(
Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c")
).toDF("x")

val dfExample3 = Seq(
Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3)
).toDF("y")

val dfExample4 = Seq(
Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0)
).toDF("z")
).toDF("j")
Review comment (Member): Do we need E0?

val dfExample5 = Seq(
val dfExample3 = Seq(
Map[Int, Boolean](25 -> true, 26 -> false)
).toDF("a")

val dfExample6 = Seq(
Map[Int, String](25 -> "ab", 26 -> "cd")
).toDF("b")
).toDF("x")

val dfExample7 = Seq(
val dfExample4 = Seq(
Map[Array[Int], Boolean](Array(1, 2) -> false)
).toDF("c")
).toDF("y")


def testMapOfPrimitiveTypesCombination(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))

checkAnswer(dfExample2.selectExpr("transform_keys(x, (k, v) -> k + 1)"),
Seq(Row(Map(2 -> "a", 3 -> "b", 4 -> "c"))))

checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> v * v)"),
Seq(Row(Map(1 -> 1, 4 -> 2, 9 -> 3))))

checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> length(k) + v)"),
Seq(Row(Map(2 -> 1, 3 -> 2, 4 -> 3))))

checkAnswer(
dfExample3.selectExpr("transform_keys(y, (k, v) -> concat(k, cast(v as String)))"),
Seq(Row(Map("a1" -> 1, "b2" -> 2, "c3" -> 3))))

checkAnswer(dfExample4.selectExpr("transform_keys(z, " +
checkAnswer(dfExample2.selectExpr("transform_keys(j, " +
"(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"),
Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))))

checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))

checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> k + v)"),
checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"),
Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))

checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> k % 2 = 0 OR v)"),
checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
Seq(Row(Map(true -> true, true -> false))))

checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"),
checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
Seq(Row(Map(50 -> true, 78 -> false))))

checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"),
checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
Seq(Row(Map(50 -> true, 78 -> false))))

checkAnswer(dfExample6.selectExpr(
"transform_keys(b, (k, v) -> concat(conv(k, 10, 16) , substr(v, 1, 1)))"),
Seq(Row(Map("19a" -> "ab", "1Ac" -> "cd"))))

checkAnswer(dfExample7.selectExpr("transform_keys(c, (k, v) -> array_contains(k, 3) AND v)"),
checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"),
Seq(Row(Map(false -> false))))
}
// Test with local relation, the Project will be evaluated without codegen
@@ -2196,52 +2167,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
dfExample2.cache()
dfExample3.cache()
dfExample4.cache()
dfExample5.cache()
dfExample6.cache()
// Test with cached relation, the Project will be evaluated with codegen
testMapOfPrimitiveTypesCombination()
Review comment (Contributor): Do we have to do that if the expression implements CodegenFallback?
}
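The comments around the cache() calls describe why each combination test runs twice: the first call evaluates the Project over a local relation without codegen, and the second, after caching, goes through the generated-code path. A condensed sketch of that dual-path pattern as it might sit inside a QueryTest-based suite; the class and method names are illustrative, not part of the patch.

import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class DualPathSketchSuite extends QueryTest with SharedSQLContext {
  // Run the same assertion once against an uncached local relation (interpreted
  // Project) and once against the cached relation (whole-stage codegen Project).
  private def checkBothPaths(df: DataFrame, expr: String, expected: Seq[Row]): Unit = {
    checkAnswer(df.selectExpr(expr), expected)   // local relation: evaluated without codegen
    df.cache()
    checkAnswer(df.selectExpr(expr), expected)   // cached relation: evaluated with codegen
  }
}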

test("transform keys function - test empty") {
val dfExample1 = Seq(
Map.empty[Int, Int]
).toDF("i")

val dfExample2 = Seq(
Map.empty[BigInt, String]
).toDF("j")

def testEmpty(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> NULL)"),
Seq(Row(Map.empty[Null, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k)"),
Seq(Row(Map.empty[Null, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> v)"),
Seq(Row(Map.empty[Null, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 0)"),
Seq(Row(Map.empty[Int, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 'key')"),
Seq(Row(Map.empty[String, Null])))

checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> true)"),
Seq(Row(Map.empty[Boolean, Null])))

checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + cast(v as BIGINT))"),
Seq(Row(Map.empty[BigInt, Null])))

checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> v)"),
Seq(Row(Map())))
}
testEmpty()
dfExample1.cache()
dfExample2.cache()
testEmpty()
}

test("transform keys function - Invalid lambda functions and exceptions") {
val dfExample1 = Seq(
Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
@@ -2279,36 +2208,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
testInvalidLambdaFunctions()
Review comment (Member): We need dfExample3.cache() as well?
Review comment (Contributor): @ueshin I would like to ask you a generic question regarding higher-order functions. Is it necessary to perform checks with codegen paths if all the newly added functions extend CodegenFallback? Eventually, is there a plan to add codegen for these functions in the future?
}
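On the CodegenFallback question raised in the comments: CodegenFallback supplies a doGenCode whose generated code simply calls back into the expression's interpreted eval, so the cached/codegen run of these tests ends up exercising the same evaluation logic plus the fallback plumbing around it. A minimal, hypothetical expression (Spark 2.4-era API, not part of this patch) showing the shape of such an expression:

import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types.{DataType, IntegerType}

// Hypothetical expression that adds one to an integer child. Mixing in
// CodegenFallback means only the interpreted path needs to be written; the
// "generated" code for this expression just invokes eval at runtime.
case class PlusOneFallback(child: Expression)
  extends UnaryExpression with CodegenFallback {

  override def dataType: DataType = IntegerType

  override protected def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Int] + 1
}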

test("transform keys function - test null") {
val dfExample1 = Seq(
Map[Boolean, Integer](true -> 1, false -> null)
).toDF("a")

def testNullValues(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_keys(a, (k, v) -> if(k, NOT k, v IS NULL))"),
Seq(Row(Map(false -> 1, true -> null))))
}

testNullValues()
dfExample1.cache()
testNullValues()
}

test("transform keys function - test duplicate keys") {
val dfExample1 = Seq(
Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c", 4 -> "d")
).toDF("a")

def testNullValues(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_keys(a, (k, v) -> k%3)"),
Seq(Row(Map(1 -> "a", 2 -> "b", 0 -> "c", 1 -> "d"))))
}

testNullValues()
dfExample1.cache()
testNullValues()
}
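The duplicate-keys test pins down a behavior worth noting in review: at this point transform_keys does not deduplicate keys that collide after the lambda runs, so k % 3 over the keys 1 to 4 leaves two entries under key 1. A small end-to-end sketch of the same observation, assuming a local session (the object name and setup are illustrative):

import org.apache.spark.sql.SparkSession

object TransformKeysCollisionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("transform_keys-collisions")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(Map(1 -> "a", 2 -> "b", 3 -> "c", 4 -> "d")).toDF("a")

    // Keys 1 and 4 both become 1 under k % 3; as the test above asserts,
    // both entries are kept in the resulting map.
    df.selectExpr("transform_keys(a, (k, v) -> k % 3)").show(false)

    spark.stop()
  }
}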

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {