From 27733f9ad56657925c176ae394114e0429aa9a0b Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Sun, 17 Jun 2018 20:17:15 +0200 Subject: [PATCH 1/8] array_contains function deals with Column type for the second parameter. --- .../src/main/scala/org/apache/spark/sql/functions.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index a2aae9a708ff..5cc349cb3ae6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3077,12 +3077,16 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns null if the array is null, true if the array contains `value`, and false otherwise. + * Returns null if the array is null, true if the array contains `value` or the content of + * `value` if it is of type Column, and false otherwise. * @group collection_funcs * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - ArrayContains(column.expr, Literal(value)) + value match { + case c: Column => ArrayContains(column.expr, c.expr) + case _ => ArrayContains(column.expr, Literal(value)) + } } /** From 28aa51554f4c730fae3c8090ac3c268e1ddfa4f8 Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Sun, 17 Jun 2018 21:58:57 +0200 Subject: [PATCH 2/8] add unit test for Column type --- .../org/apache/spark/sql/DataFrameFunctionsSuite.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 59119bbbd8a2..89c9412e2021 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -558,9 +558,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array contains function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") // Simple test cases checkAnswer( @@ -571,6 +571,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) + checkAnswer( + df.select(array_contains(df("a"), df("c"))), + Seq(Row(true), Row(false)) + ) // In hive, this errors because null has no type information intercept[AnalysisException] { From e65611e451bf49b38850af84a89510ac8a749cbd Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Mon, 18 Jun 2018 14:06:44 +0200 Subject: [PATCH 3/8] array_position deals with Column Type --- .../main/scala/org/apache/spark/sql/functions.scala | 5 ++++- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 11 +++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5cc349cb3ae6..dc4fd96ee931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3150,7 +3150,10 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - ArrayPosition(column.expr, Literal(value)) + value match { + case c: Column => ArrayPosition(column.expr, c.expr) + case _ => ArrayPosition(column.expr, Literal(value)) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 89c9412e2021..463e360c12f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -789,9 +789,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array position function") { val df = Seq( - (Seq[Int](1, 2), "x"), - (Seq[Int](), "x") - ).toDF("a", "b") + (Seq[Int](1, 2), "x", 1), + (Seq[Int](), "x", 1) + ).toDF("a", "b", "c") checkAnswer( df.select(array_position(df("a"), 1)), @@ -801,7 +801,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(a, 1)"), Seq(Row(1L), Row(0L)) ) - + checkAnswer( + df.select(array_position(df("a"), df("c"))), + Seq(Row(1L), Row(0L)) + ) checkAnswer( df.select(array_position(df("a"), null)), Seq(Row(null), Row(null)) From b6e150c5e9101c6135ed1020ce97f7e754613035 Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Mon, 18 Jun 2018 14:17:08 +0200 Subject: [PATCH 4/8] array_remove function deals with Column Type --- .../scala/org/apache/spark/sql/functions.scala | 5 ++++- .../spark/sql/DataFrameFunctionsSuite.scala | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index dc4fd96ee931..f6c7548015c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3182,7 +3182,10 @@ object functions { * @since 2.4.0 */ def array_remove(column: Column, element: Any): Column = withExpr { - ArrayRemove(column.expr, Literal(element)) + element match { + case c: Column => ArrayRemove(column.expr, c.expr) + case _ => ArrayRemove(column.expr, Literal(element)) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 463e360c12f1..e9804ee62936 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1119,10 +1119,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("array remove") { val df = Seq( - (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")), - (Array.empty[Int], Array.empty[String], Array.empty[String]), - (null, null, null) - ).toDF("a", "b", "c") + (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2), + (Array.empty[Int], Array.empty[String], Array.empty[String], 2), + (null, null, null, 2) + ).toDF("a", "b", "c", "d") checkAnswer( df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")), Seq( @@ -1131,6 +1131,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null, null, null)) ) + checkAnswer( + df.select(array_remove($"a", $"d")), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + checkAnswer( df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", "array_remove(c, \"\")"), From eabdb6ed48f39d5c201a32ab32e863e21cc6cbde Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Mon, 18 Jun 2018 14:24:29 +0200 Subject: [PATCH 5/8] element_at function deals with Column Type --- .../main/scala/org/apache/spark/sql/functions.scala | 5 ++++- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 12 ++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f6c7548015c6..c2a41cc76f72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3164,7 +3164,10 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - ElementAt(column.expr, Literal(value)) + value match { + case c: Column => ElementAt(column.expr, c.expr) + case _ => ElementAt(column.expr, Literal(value)) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index e9804ee62936..f9caff67534e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -831,10 +831,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("element_at function") { val df = Seq( - (Seq[String]("1", "2", "3")), - (Seq[String](null, "")), - (Seq[String]()) - ).toDF("a") + (Seq[String]("1", "2", "3"), 1), + (Seq[String](null, ""), -1), + (Seq[String](), 2) + ).toDF("a", "b") intercept[Exception] { checkAnswer( @@ -852,6 +852,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), 4)), Seq(Row(null), Row(null), Row(null)) ) + checkAnswer( + df.select(element_at(df("a"), df("b"))), + Seq(Row("1"), Row(""), Row(null)) + ) checkAnswer( df.select(element_at(df("a"), 1)), From 8440d5b08bd33321533635c4b6329ce5c7b843d2 Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Mon, 18 Jun 2018 14:26:54 +0200 Subject: [PATCH 6/8] reset doc for array_contains function --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c2a41cc76f72..6fbe55ea2850 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3077,8 +3077,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns null if the array is null, true if the array contains `value` or the content of - * `value` if it is of type Column, and false otherwise. + * Returns null if the array is null, true if the array contains `value`, and false otherwise. * @group collection_funcs * @since 1.5.0 */ From f8c5b43ddc3b9209d9b7972e28d237d205190136 Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Wed, 20 Jun 2018 10:44:00 +0200 Subject: [PATCH 7/8] use lit() and add unit tests --- .../org/apache/spark/sql/functions.scala | 8 ++++---- .../spark/sql/DataFrameFunctionsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6fbe55ea2850..77ce7123babe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3083,7 +3083,7 @@ object functions { */ def array_contains(column: Column, value: Any): Column = withExpr { value match { - case c: Column => ArrayContains(column.expr, c.expr) + case c: Column => ArrayContains(column.expr, lit(c).expr) case _ => ArrayContains(column.expr, Literal(value)) } } @@ -3150,7 +3150,7 @@ object functions { */ def array_position(column: Column, value: Any): Column = withExpr { value match { - case c: Column => ArrayPosition(column.expr, c.expr) + case c: Column => ArrayPosition(column.expr, lit(c).expr) case _ => ArrayPosition(column.expr, Literal(value)) } } @@ -3164,7 +3164,7 @@ object functions { */ def element_at(column: Column, value: Any): Column = withExpr { value match { - case c: Column => ElementAt(column.expr, c.expr) + case c: Column => ElementAt(column.expr, lit(c).expr) case _ => ElementAt(column.expr, Literal(value)) } } @@ -3185,7 +3185,7 @@ object functions { */ def array_remove(column: Column, element: Any): Column = withExpr { element match { - case c: Column => ArrayRemove(column.expr, c.expr) + case c: Column => ArrayRemove(column.expr, lit(c).expr) case _ => ArrayRemove(column.expr, Literal(element)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f9caff67534e..92ceaa2cc97b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -575,6 +575,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(array_contains(df("a"), df("c"))), Seq(Row(true), Row(false)) ) + checkAnswer( + df.selectExpr("array_contains(a, c)"), + Seq(Row(true), Row(false)) + ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -801,6 +805,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_position(a, 1)"), Seq(Row(1L), Row(0L)) ) + checkAnswer( + df.selectExpr("array_position(a, c)"), + Seq(Row(1L), Row(0L)) + ) checkAnswer( df.select(array_position(df("a"), df("c"))), Seq(Row(1L), Row(0L)) @@ -856,6 +864,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.select(element_at(df("a"), df("b"))), Seq(Row("1"), Row(""), Row(null)) ) + checkAnswer( + df.selectExpr("element_at(a, b)"), + Seq(Row("1"), Row(""), Row(null)) + ) checkAnswer( df.select(element_at(df("a"), 1)), @@ -1143,6 +1155,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row(null)) ) + checkAnswer( + df.selectExpr("array_remove(a, d)"), + Seq( + Row(Seq(1, 3)), + Row(Seq.empty[Int]), + Row(null)) + ) + checkAnswer( df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")", "array_remove(c, \"\")"), From ddd94f7e27f3622db89abd8c4f85975fa0034fff Mon Sep 17 00:00:00 2001 From: Chongguang LIU Date: Wed, 20 Jun 2018 15:45:59 +0200 Subject: [PATCH 8/8] use lit() to unify the cases --- .../org/apache/spark/sql/functions.scala | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 77ce7123babe..b5ee405cd634 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3082,10 +3082,7 @@ object functions { * @since 1.5.0 */ def array_contains(column: Column, value: Any): Column = withExpr { - value match { - case c: Column => ArrayContains(column.expr, lit(c).expr) - case _ => ArrayContains(column.expr, Literal(value)) - } + ArrayContains(column.expr, lit(value).expr) } /** @@ -3149,10 +3146,7 @@ object functions { * @since 2.4.0 */ def array_position(column: Column, value: Any): Column = withExpr { - value match { - case c: Column => ArrayPosition(column.expr, lit(c).expr) - case _ => ArrayPosition(column.expr, Literal(value)) - } + ArrayPosition(column.expr, lit(value).expr) } /** @@ -3163,10 +3157,7 @@ object functions { * @since 2.4.0 */ def element_at(column: Column, value: Any): Column = withExpr { - value match { - case c: Column => ElementAt(column.expr, lit(c).expr) - case _ => ElementAt(column.expr, Literal(value)) - } + ElementAt(column.expr, lit(value).expr) } /** @@ -3184,10 +3175,7 @@ object functions { * @since 2.4.0 */ def array_remove(column: Column, element: Any): Column = withExpr { - element match { - case c: Column => ArrayRemove(column.expr, lit(c).expr) - case _ => ArrayRemove(column.expr, Literal(element)) - } + ArrayRemove(column.expr, lit(element).expr) } /**