Skip to content

Commit cd10f9d

Browse files
mgaido91 authored and ueshin committed
[SPARK-23916][SQL] Add array_join function
## What changes were proposed in this pull request? The PR adds the SQL function `array_join`. Its behavior is modeled on Presto's function of the same name. The function accepts an `array` of `string` elements to be joined, a `string` delimiter to place between the elements of the first argument and, optionally, a `string` used to replace `null` values. ## How was this patch tested? Added unit tests. Author: Marco Gaido <[email protected]> Closes #21011 from mgaido91/SPARK-23916.
1 parent 58c55cb commit cd10f9d

File tree

6 files changed

+268
-0
lines changed

6 files changed

+268
-0
lines changed

python/pyspark/sql/functions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,6 +1834,27 @@ def array_contains(col, value):
18341834
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
18351835

18361836

1837+
@ignore_unicode_prefix
@since(2.4)
def array_join(col, delimiter, null_replacement=None):
    """
    Concatenates the elements of `col` using the `delimiter`. Null values are replaced with
    `null_replacement` if set, otherwise they are ignored.

    >>> df = spark.createDataFrame([(["a", "b", "c"],), (["a", None],)], ['data'])
    >>> df.select(array_join(df.data, ",").alias("joined")).collect()
    [Row(joined=u'a,b,c'), Row(joined=u'a')]
    >>> df.select(array_join(df.data, ",", "NULL").alias("joined")).collect()
    [Row(joined=u'a,b,c'), Row(joined=u'a,NULL')]
    """
    sc = SparkContext._active_spark_context
    # Resolve the JVM-side function once; the arity of the call below selects
    # the two- or three-argument overload.
    jvm_array_join = sc._jvm.functions.array_join
    java_col = _to_java_column(col)
    if null_replacement is None:
        return Column(jvm_array_join(java_col, delimiter))
    return Column(jvm_array_join(java_col, delimiter, null_replacement))
1856+
1857+
18371858
@since(1.5)
18381859
@ignore_unicode_prefix
18391860
def concat(*cols):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ object FunctionRegistry {
401401
// collection functions
402402
expression[CreateArray]("array"),
403403
expression[ArrayContains]("array_contains"),
404+
expression[ArrayJoin]("array_join"),
404405
expression[ArrayPosition]("array_position"),
405406
expression[CreateMap]("map"),
406407
expression[CreateNamedStruct]("named_struct"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,175 @@ case class ArrayContains(left: Expression, right: Expression)
378378
override def prettyName: String = "array_contains"
379379
}
380380

381+
/**
 * Creates a String containing all the elements of the input array separated by the delimiter.
 *
 * Null elements of the array are replaced by `nullReplacement` when one is given and are skipped
 * otherwise. The result is null when the array, the delimiter or a given `nullReplacement`
 * evaluates to null.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array
      using the delimiter and an optional string to replace nulls. If no value is set for
      nullReplacement, any null value is filtered.""",
  examples = """
    Examples:
      > SELECT _FUNC_(array('hello', 'world'), ' ');
       hello world
      > SELECT _FUNC_(array('hello', null ,'world'), ' ');
       hello world
      > SELECT _FUNC_(array('hello', null ,'world'), ' ', ',');
       hello , world
  """, since = "2.4.0")
case class ArrayJoin(
    array: Expression,
    delimiter: Expression,
    nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes {

  def this(array: Expression, delimiter: Expression) = this(array, delimiter, None)

  def this(array: Expression, delimiter: Expression, nullReplacement: Expression) =
    this(array, delimiter, Some(nullReplacement))

  override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
    Seq(ArrayType(StringType), StringType, StringType)
  } else {
    Seq(ArrayType(StringType), StringType)
  }

  // The replacement, when present, is the third child.
  override def children: Seq[Expression] = Seq(array, delimiter) ++ nullReplacement.toSeq

  override def nullable: Boolean = children.exists(_.nullable)

  override def foldable: Boolean = children.forall(_.foldable)

  /**
   * Interpreted evaluation. Returns null as soon as the array, the delimiter or the (given)
   * null replacement evaluates to null; otherwise returns the joined UTF8String.
   */
  override def eval(input: InternalRow): Any = {
    val arrayEval = array.eval(input)
    if (arrayEval == null) return null
    val delimiterEval = delimiter.eval(input)
    if (delimiterEval == null) return null
    val nullReplacementEval = nullReplacement.map(_.eval(input))
    if (nullReplacementEval.contains(null)) return null

    val delimiterString = delimiterEval.asInstanceOf[UTF8String]
    val buffer = new UTF8StringBuilder()
    var firstItem = true

    // Appends `value` to the output, preceded by the delimiter unless it is the first piece
    // written to the buffer.
    def appendItem(value: UTF8String): Unit = {
      if (!firstItem) {
        buffer.append(delimiterString)
      }
      buffer.append(value)
      firstItem = false
    }

    arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => {
      if (item == null) {
        // Without a replacement, null elements produce no output and no delimiter.
        nullReplacementEval.foreach(rep => appendItem(rep.asInstanceOf[UTF8String]))
      } else {
        appendItem(item.asInstanceOf[UTF8String])
      }
    })
    buffer.build()
  }

  /**
   * Generates Java code equivalent to `eval`. When this expression is nullable, `ev.isNull` is
   * declared and initialised to true here; the code produced by `genCodeForArrayAndDelimiter`
   * clears it once a result has been computed.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val code = nullReplacement match {
      case Some(replacement) =>
        val replacementGen = replacement.genCode(ctx)
        // Java snippet run for a null element: write the replacement as a regular item.
        val nullHandling = (buffer: String, delimiter: String, firstItem: String) => {
          s"""
             |if (!$firstItem) {
             |  $buffer.append($delimiter);
             |}
             |$buffer.append(${replacementGen.value});
             |$firstItem = false;
           """.stripMargin
        }
        val execCode = if (replacement.nullable) {
          ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) {
            genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
          }
        } else {
          genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
        }
        s"""
           |${replacementGen.code}
           |$execCode
         """.stripMargin
      case None => genCodeForArrayAndDelimiter(ctx, ev,
        (_: String, _: String, _: String) => "// nulls are ignored")
    }
    if (nullable) {
      ev.copy(
        s"""
           |boolean ${ev.isNull} = true;
           |UTF8String ${ev.value} = null;
           |$code
         """.stripMargin)
    } else {
      ev.copy(
        s"""
           |UTF8String ${ev.value} = null;
           |$code
         """.stripMargin, FalseLiteral)
    }
  }

  /**
   * Generates the code that evaluates the array and the delimiter and, when both are non-null,
   * joins the array elements into `ev.value`.
   *
   * @param nullEval given the names of the buffer, delimiter and first-item variables, returns
   *                 the Java snippet to run for a null array element
   */
  private def genCodeForArrayAndDelimiter(
      ctx: CodegenContext,
      ev: ExprCode,
      nullEval: (String, String, String) => String): String = {
    val arrayGen = array.genCode(ctx)
    val delimiterGen = delimiter.genCode(ctx)
    val buffer = ctx.freshName("buffer")
    val bufferClass = classOf[UTF8StringBuilder].getName
    val i = ctx.freshName("i")
    val firstItem = ctx.freshName("firstItem")
    // doGenCode declares `ev.isNull` (initialised to true) only when this expression is
    // nullable. It has to be cleared whenever a result is produced, which includes the case
    // where nullReplacement is the only nullable child; otherwise the computed string would
    // always be reported as null.
    val clearIsNull = if (nullable) s"${ev.isNull} = false;" else ""
    val resultCode =
      s"""
         |$clearIsNull
         |$bufferClass $buffer = new $bufferClass();
         |boolean $firstItem = true;
         |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) {
         |  if (${arrayGen.value}.isNullAt($i)) {
         |    ${nullEval(buffer, delimiterGen.value, firstItem)}
         |  } else {
         |    if (!$firstItem) {
         |      $buffer.append(${delimiterGen.value});
         |    }
         |    $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)});
         |    $firstItem = false;
         |  }
         |}
         |${ev.value} = $buffer.build();""".stripMargin

    if (array.nullable || delimiter.nullable) {
      arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) {
        delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) {
          resultCode
        }
      }
    } else {
      s"""
         |${arrayGen.code}
         |${delimiterGen.code}
         |$resultCode""".stripMargin
    }
  }

  override def dataType: DataType = StringType

}
549+
381550
/**
382551
* Returns the minimum value in the array.
383552
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,41 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
106106
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
107107
}
108108

109+
test("ArrayJoin") {
  val stringArrayType = ArrayType(StringType)
  val comma = Literal(",")

  // Joins every input with "," (and the optional replacement) and compares the outcome with
  // the expected string at the same position.
  def testArrays(
      arrays: Seq[Expression],
      nullReplacement: Option[Expression],
      expected: Seq[String]): Unit = {
    assert(arrays.length == expected.length)
    for ((input, joined) <- arrays.zip(expected)) {
      checkEvaluation(ArrayJoin(input, comma, nullReplacement), joined)
    }
  }

  // Inputs cover nulls in the middle, alone, at the end and at the start of the array.
  val arrays = Seq(
    Seq[String]("a", "b"),
    Seq[String]("a", null, "b"),
    Seq[String](null),
    Seq[String]("a", "b", null),
    Seq[String](null, "a", "b"),
    Seq[String]("a")
  ).map(items => Literal.create(items, stringArrayType))

  val withoutNullReplacement = Seq("a,b", "a,b", "", "a,b", "a,b", "a")
  val withNullReplacement = Seq("a,b", "a,NULL,b", "NULL", "a,b,NULL", "NULL,a,b", "a")
  testArrays(arrays, None, withoutNullReplacement)
  testArrays(arrays, Some(Literal("NULL")), withNullReplacement)

  // A null array, delimiter or replacement each make the whole result null.
  val singleNullArray = Literal.create(Seq[String](null), stringArrayType)
  val nullString = Literal.create(null, StringType)

  checkEvaluation(
    ArrayJoin(Literal.create(null, stringArrayType), Literal(","), None), null)
  checkEvaluation(ArrayJoin(singleNullArray, nullString, None), null)
  checkEvaluation(ArrayJoin(singleNullArray, Literal(","), Some(nullString)), null)
}
143+
109144
test("Array Min") {
110145
checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
111146
checkEvaluation(

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3039,6 +3039,25 @@ object functions {
30393039
ArrayContains(column.expr, Literal(value))
30403040
}
30413041

3042+
/**
 * Joins the elements of the array `column` into a single string, placing `delimiter` between
 * consecutive elements. Null values are replaced with `nullReplacement`.
 * @group collection_funcs
 * @since 2.4.0
 */
def array_join(column: Column, delimiter: String, nullReplacement: String): Column = {
  val arrayExpr = column.expr
  val delimiterExpr = Literal(delimiter)
  val replacementExpr = Literal(nullReplacement)
  withExpr(ArrayJoin(arrayExpr, delimiterExpr, Some(replacementExpr)))
}
3051+
3052+
/**
 * Joins the elements of the array `column` into a single string, placing `delimiter` between
 * consecutive elements. No replacement for null values is supplied.
 * @group collection_funcs
 * @since 2.4.0
 */
def array_join(column: Column, delimiter: String): Column = {
  val arrayExpr = column.expr
  val delimiterExpr = Literal(delimiter)
  withExpr(ArrayJoin(arrayExpr, delimiterExpr, None))
}
3060+
30423061
/**
30433062
* Concatenates multiple input columns together into a single column.
30443063
* The function works with strings, binary and compatible array columns.

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
413413
)
414414
}
415415

416+
test("array_join function") {
  val df = Seq(
    (Seq[String]("a", "b"), ","),
    (Seq[String]("a", null, "b"), ","),
    (Seq.empty[String], ",")
  ).toDF("x", "delimiter")
  val xs = df("x")

  // Column API: the delimiter and the replacement are passed as literals.
  val joinedWithLiteral = df.select(array_join(xs, ";"))
  checkAnswer(joinedWithLiteral, Row("a;b") :: Row("a;b") :: Row("") :: Nil)

  val joinedWithReplacement = df.select(array_join(xs, ";", "NULL"))
  checkAnswer(joinedWithReplacement, Row("a;b") :: Row("a;NULL;b") :: Row("") :: Nil)

  // SQL expressions: the delimiter is read from the `delimiter` column.
  val joinedViaSql = df.selectExpr("array_join(x, delimiter)")
  checkAnswer(joinedViaSql, Row("a,b") :: Row("a,b") :: Row("") :: Nil)

  val joinedViaSqlWithReplacement = df.selectExpr("array_join(x, delimiter, 'NULL')")
  checkAnswer(joinedViaSqlWithReplacement, Row("a,b") :: Row("a,NULL,b") :: Row("") :: Nil)
}
438+
416439
test("array_min function") {
417440
val df = Seq(
418441
Seq[Option[Int]](Some(1), Some(3), Some(2)),

0 commit comments

Comments
 (0)