[SPARK-23931][SQL] Adds arrays_zip function to sparksql #21045
Changes from 1 commit
Signed-off-by: DylanGuedes <[email protected]>
```diff
@@ -2395,21 +2395,20 @@ def array_repeat(col, count):
 
 
 @since(2.4)
-def zip(col1, col2):
+def zip(*cols):
    """
    Merge two columns into one, such that the M-th element of the N-th argument will be
    the N-th field of the M-th output element.
 
-    :param col1: name of the first column
-    :param col2: name of the second column
+    :param cols: columns in input
 
    >>> from pyspark.sql.functions import zip
    >>> df = spark.createDataFrame([(([1, 2, 3], [2, 3, 4]))], ['vals1', 'vals2'])
    >>> df.select(zip(df.vals1, df.vals2).alias('zipped')).collect()
    [Row(zipped=[1, 2]), Row(zipped=[2, 3]), Row(zipped=[3, 4])]
    """
    sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.zip(_to_java_column(col1), _to_java_column(col2)))
+    return Column(sc._jvm.functions.zip(_to_seq(sc, cols, _to_java_column)))
 
 
 # ---------------------------- User Defined Function ----------------------------------
```
```diff
@@ -129,106 +129,84 @@ case class MapKeys(child: Expression)
 }
 
 @ExpressionDescription(
-  usage = """_FUNC_(a1, a2) - Returns a merged array matching N-th element of first
-    array with the N-th element of second.""",
+  usage = """_FUNC_(a1, a2, ...) - Returns a merged array containing in the N-th position the
+    N-th value of each array given.""",
   examples = """
     Examples:
-      > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
-       [[1, 2], [2, 3], [3, 4]]
+      > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
+       [[1, 2, 3], [2, 3, 4]]
   """,
   since = "2.4.0")
-case class Zip(left: Expression, right: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+case class Zip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
+  private[this] val childrenArray = children.toArray
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)
+  override def inputTypes: Seq[AbstractDataType] = Seq.fill(childrenArray.length)(ArrayType)
 
-  override def dataType: DataType = ArrayType(StructType(
-    StructField("_1", left.dataType.asInstanceOf[ArrayType].elementType, true) ::
-    StructField("_2", right.dataType.asInstanceOf[ArrayType].elementType, true) ::
-    Nil))
+  def mountSchema(): StructType = {
+    val arrayAT = childrenArray.map(_.dataType.asInstanceOf[ArrayType])
+    val n = childrenArray.length
+    var i = n - 1
+    var myList = List[StructField]()
+    while (i >= 0) {
+      myList = StructField(s"_$i", arrayAT(i).elementType, arrayAT(i).containsNull) :: myList
+      i -= 1
+    }
+    StructType(myList)
+  }
+
+  override def dataType: DataType = ArrayType(mountSchema())
 
   override def prettyName: String = "zip"
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, (arr1, arr2) => {
-      val genericArrayData = classOf[GenericArrayData].getName
-      val genericInternalRow = classOf[GenericInternalRow].getName
-
-      val i = ctx.freshName("i")
-      val values = ctx.freshName("values")
-      val len1 = ctx.freshName("len1")
-      val len2 = ctx.freshName("len2")
-      val pair = ctx.freshName("pair")
-      val getValue1 = CodeGenerator.getValue(
-        arr1, left.dataType.asInstanceOf[ArrayType].elementType, i)
-      val getValue2 = CodeGenerator.getValue(
-        arr2, right.dataType.asInstanceOf[ArrayType].elementType, i)
-
-      s"""
-         |int $len1 = $arr1.numElements();
-         |int $len2 = $arr2.numElements();
-         |Object[] $values;
-         |Object[] $pair;
-         |if ($len1 > $len2) {
-         |  $values = new Object[$len1];
-         |  for (int $i = 0; $i < $len1; $i ++) {
-         |    $pair = new Object[2];
-         |    $pair[0] = $getValue1;
-         |    if ($i >= $len2) {
-         |      $pair[1] = null;
-         |    } else {
-         |      $pair[1] = $getValue2;
-         |    }
-         |    $values[$i] = new $genericInternalRow($pair);
-         |  }
-         |} else {
-         |  $values = new Object[$len2];
-         |  for (int $i = 0; $i < $len2; $i ++) {
-         |    $pair = new Object[2];
-         |    $pair[1] = $getValue2;
-         |    if ($i >= $len1) {
-         |      $pair[0] = null;
-         |    } else {
-         |      $pair[0] = $getValue1;
-         |    }
-         |    $values[$i] = new $genericInternalRow($pair);
-         |  }
-         |}
-         |${ev.value} = new $genericArrayData($values);
-       """.stripMargin
-    })
-  }
-
-  def extendWithNull(a1: Array[AnyRef], a2: Array[AnyRef]):
-    (Array[AnyRef], Array[AnyRef]) = {
-    val lens = (a1.length, a2.length)
-
-    var arr1 = a1
-    var arr2 = a2
-
-    val diff = lens._1 - lens._2
-    if (lens._1 > lens._2) {
-      arr2 = a2 ++ Array.fill(diff)(null)
-    }
-    if (lens._1 < lens._2) {
-      arr1 = a1 ++ Array.fill(-diff)(null)
-    }
-
-    (arr1, arr2)
+    val genericArrayData = classOf[GenericArrayData].getName
+    val genericInternalRow = classOf[GenericInternalRow].getName
+
+    val evals = children.map(_.genCode(ctx))
+    val numArrs = evals.length
+
+    val values = children.zip(evals).map { case(child, eval) =>
+
+    ev.copy(code =
+      s"""
+       """.stripMargin)
   }
 
-  override def nullSafeEval(a1: Any, a2: Any): Any = {
-    val type1 = left.dataType.asInstanceOf[ArrayType].elementType
-    val type2 = right.dataType.asInstanceOf[ArrayType].elementType
-
-    val arrays = (
-      a1.asInstanceOf[ArrayData].toArray[AnyRef](type1),
-      a2.asInstanceOf[ArrayData].toArray[AnyRef](type2)
-    )
-
-    val extendedArrays = extendWithNull(arrays._1, arrays._2)
-
-    new GenericArrayData(extendedArrays.zipped.map((a, b) => InternalRow.apply(a, b)))
+  override def nullable: Boolean = children.forall(_.nullable)
+
+  override def eval(input: InternalRow): Any = {
+    val inputArrays = childrenArray.map(_.eval(input).asInstanceOf[ArrayData])
+    val arrayTypes = childrenArray.map(_.dataType.asInstanceOf[ArrayType].elementType)
+    val numberOfArrays = childrenArray.length
+
+    var biggestCardinality = 0
+    for (e <- inputArrays) {
+      biggestCardinality = biggestCardinality max e.numElements()
+    }
+
+    var i = 0
+    var j = 0
+    var result = Seq[InternalRow]()
+    while (i < biggestCardinality) {
+      var myList = List[Any]()
+
+      j = numberOfArrays - 1
+      while (j >= 0) {
+        if (inputArrays(j).numElements() > i) {
+          myList = inputArrays(j).get(i, arrayTypes(j)) :: myList
+        } else {
+          myList = null :: myList
+        }
+        j -= 1
+      }
+      result = result :+ InternalRow.apply(myList: _*)
+      i += 1
+    }
+    new GenericArrayData(result)
   }
 }
```
Member: We need

Contributor (Author): Done!
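For reference, the new interpreted `eval` above walks all input arrays up to the longest one (`biggestCardinality`) and pads the shorter arrays with `null`. A minimal sketch of that behaviour on plain Scala collections; the helper name `zipWithNullPadding` and the use of `Seq[Any]` in place of `ArrayData` are illustrative, not code from the patch:

```scala
// Illustrative sketch only: mirrors Zip.eval's null-padding semantics
// on ordinary Scala Seqs instead of Spark's ArrayData/InternalRow.
def zipWithNullPadding(arrays: Seq[Seq[Any]]): Seq[Seq[Any]] = {
  val biggestCardinality = arrays.map(_.length).max   // length of the longest input array
  (0 until biggestCardinality).map { i =>
    // one "struct" per position: the i-th element of each array, or null if that array is too short
    arrays.map(a => if (i < a.length) a(i) else null)
  }
}

// zipWithNullPadding(Seq(Seq(9001, 9002, 9003), Seq(4, 5)))
//   => Seq(Seq(9001, 4), Seq(9002, 5), Seq(9003, null))
```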
```diff
@@ -325,9 +325,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val val2 = List(Row(9001, 4), Row(9002, 5), Row(null, 6))
     val val3 = List(Row("a", 4), Row("b", null), Row(null, null))
 
-    checkEvaluation(Zip(lit1._1, lit1._2), val1)
-    checkEvaluation(Zip(lit2._1, lit2._2), val2)
-    checkEvaluation(Zip(lit3._1, lit3._2), val3)
+    checkEvaluation(Zip(Seq(Literal.create(Seq(1, 0)), Literal.create(Seq(1, 0)))),
+      List(Row(1, 0), Row(1, 0)))
+    checkEvaluation(Zip(Seq(lit1._1, lit1._2)), val1)
+    checkEvaluation(Zip(Seq(lit2._1, lit2._2)), val2)
+    checkEvaluation(Zip(Seq(lit3._1, lit3._2)), val3)
   }
 
   test("Array Min") {
```
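For illustration only, an additional assertion (not part of this patch) that exercises the null padding for arrays of different lengths could look like this, given the `eval` semantics above:

```scala
// Hypothetical extra case, not in the patch: the shorter array is padded with null.
checkEvaluation(
  Zip(Seq(Literal.create(Seq(1, 2, 3)), Literal.create(Seq("a", "b")))),
  List(Row(1, "a"), Row(2, "b"), Row(3, null)))
```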
```diff
@@ -3509,12 +3509,12 @@ object functions {
   def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
 
   /**
-   * Merge two columns into a resulting one.
+   * Merge multiple columns into a resulting one.
    *
    * @group collection_funcs
    * @since 2.4.0
    */
-  def zip(e1: Column, e2: Column): Column = withExpr { Zip(e1.expr, e2.expr) }
+  def zip(e: Column*): Column = withExpr { Zip(e.map(_.expr)) }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Mask functions
```
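A hedged usage sketch of the varargs API above. It assumes a running `SparkSession` named `spark`; the column names `a` and `b` are invented for illustration, and the struct field names follow `mountSchema` from the expression change:

```scala
// Usage sketch against this patch's API, not part of the diff.
import org.apache.spark.sql.functions.zip
import spark.implicits._   // assumes an existing SparkSession named `spark`

val df = Seq((Seq(1, 2, 3), Seq("x", "y", "z"))).toDF("a", "b")
val zipped = df.select(zip($"a", $"b").alias("zipped"))
// Each element of `zipped` is a struct whose fields are named _0, _1, ...
// per mountSchema(); e.g. the first element would be (1, "x").
```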
```diff
@@ -481,16 +481,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
 
   test("dataframe zip function") {
     val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2")
-    val df2 = Seq((Seq(9001, 9002), Seq(4, 5, 6))).toDF("val1", "val2")
+    val df2 = Seq((Seq("a", "b"), Seq(4, 5), Seq(10, 11))).toDF("val1", "val2", "val3")
     val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2")
 
     val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6)))
     checkAnswer(df1.select(zip($"val1", $"val2")), expectedValue1)
     checkAnswer(df1.selectExpr("zip(val1, val2)"), expectedValue1)
 
-    val expectedValue2 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(null, 6)))
-    checkAnswer(df2.select(zip($"val1", $"val2")), expectedValue2)
-    checkAnswer(df2.selectExpr("zip(val1, val2)"), expectedValue2)
+    val expectedValue2 = Row(Seq(Row("a", 4, 10), Row("b", 5, 11)))
+    checkAnswer(df2.select(zip($"val1", $"val2", $"val3")), expectedValue2)
+    checkAnswer(df2.selectExpr("zip(val1, val2, val3)"), expectedValue2)
 
     val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6)))
     checkAnswer(df3.select(zip($"val1", $"val2")), expectedValue3)
```
Review comment: Can you add `Collection function:` like other collection functions?
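Other collection functions in `functions.scala` begin their Scaladoc with a `Collection function:` prefix, so the requested wording would presumably look like the sketch below (illustrative only, not the final patch):

```scala
/**
 * Collection function: Merges multiple columns into a resulting one.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def zip(e: Column*): Column = withExpr { Zip(e.map(_.expr)) }
```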