-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23930][SQL] Add slice function #21040
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
5cbbf7a
367aaf2
f2784f1
b94d067
dc6cb60
72ed607
9d65570
9f0deec
e2eb21e
07604e0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -378,6 +378,138 @@ case class ArrayContains(left: Expression, right: Expression) | |
| override def prettyName: String = "array_contains" | ||
| } | ||
|
|
||
| /** | ||
| * Slices an array according to the requested start index and length | ||
| */ | ||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); | ||
| [2,3] | ||
| > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); | ||
| [3,4] | ||
| """, since = "2.4.0") | ||
| // scalastyle:on line.size.limit | ||
| case class Slice(x: Expression, start: Expression, length: Expression) | ||
| extends TernaryExpression with ImplicitCastInputTypes { | ||
|
|
||
| override def dataType: DataType = x.dataType | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) | ||
|
|
||
| override def children: Seq[Expression] = Seq(x, start, length) | ||
|
|
||
| lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType | ||
|
|
||
| override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { | ||
| val startInt = startVal.asInstanceOf[Int] | ||
| val lengthInt = lengthVal.asInstanceOf[Int] | ||
| val arr = xVal.asInstanceOf[ArrayData] | ||
| val startIndex = if (startInt == 0) { | ||
| throw new RuntimeException( | ||
| s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") | ||
| } else if (startInt < 0) { | ||
| startInt + arr.numElements() | ||
| } else { | ||
| startInt - 1 | ||
| } | ||
| if (lengthInt < 0) { | ||
| throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + | ||
| "length must be greater than or equal to 0.") | ||
| } | ||
| // startIndex can be negative if start is negative and its absolute value is greater than the | ||
| // number of elements in the array | ||
| if (startIndex < 0 || startIndex >= arr.numElements()) { | ||
| return new GenericArrayData(Array.empty[AnyRef]) | ||
| } | ||
| val data = arr.toSeq[AnyRef](elementType) | ||
| new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) | ||
| } | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
| nullSafeCodeGen(ctx, ev, (x, start, length) => { | ||
| val startIdx = ctx.freshName("startIdx") | ||
| val resLength = ctx.freshName("resLength") | ||
| val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) | ||
| s""" | ||
| |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; | ||
| |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; | ||
| |if ($start == 0) { | ||
| | throw new RuntimeException("Unexpected value for start in function $prettyName: " | ||
| | + "SQL array indices start at 1."); | ||
| |} else if ($start < 0) { | ||
| | $startIdx = $start + $x.numElements(); | ||
| |} else { | ||
| | // arrays in SQL are 1-based instead of 0-based | ||
| | $startIdx = $start - 1; | ||
| |} | ||
| |if ($length < 0) { | ||
| | throw new RuntimeException("Unexpected value for length in function $prettyName: " | ||
| | + "length must be greater than or equal to 0."); | ||
| |} else if ($length > $x.numElements() - $startIdx) { | ||
| | $resLength = $x.numElements() - $startIdx; | ||
| |} else { | ||
| | $resLength = $length; | ||
| |} | ||
| |${genCodeForResult(ctx, ev, x, startIdx, resLength)} | ||
| """.stripMargin | ||
| }) | ||
| } | ||
|
|
||
| def genCodeForResult( | ||
| ctx: CodegenContext, | ||
| ev: ExprCode, | ||
| inputArray: String, | ||
| startIdx: String, | ||
| resLength: String): String = { | ||
| val values = ctx.freshName("values") | ||
| val i = ctx.freshName("i") | ||
| val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") | ||
| if (!CodeGenerator.isPrimitiveType(elementType)) { | ||
| val arrayClass = classOf[GenericArrayData].getName | ||
| s""" | ||
| |Object[] $values; | ||
| |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { | ||
| | $values = new Object[0]; | ||
| |} else { | ||
| | $values = new Object[$resLength]; | ||
| | for (int $i = 0; $i < $resLength; $i ++) { | ||
| | $values[$i] = $getValue; | ||
| | } | ||
| |} | ||
| |${ev.value} = new $arrayClass($values); | ||
| """.stripMargin | ||
| } else { | ||
| val sizeInBytes = ctx.freshName("sizeInBytes") | ||
| val bytesArray = ctx.freshName("bytesArray") | ||
| val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) | ||
| s""" | ||
| |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { | ||
| | $resLength = 0; | ||
| |} | ||
| |${CodeGenerator.JAVA_INT} $sizeInBytes = | ||
| | UnsafeArrayData.calculateHeaderPortionInBytes($resLength) + | ||
| | ${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( | ||
| | ${elementType.defaultSize} * $resLength); | ||
| |byte[] $bytesArray = new byte[$sizeInBytes]; | ||
|
||
| |UnsafeArrayData $values = new UnsafeArrayData(); | ||
| |Platform.putLong($bytesArray, ${Platform.BYTE_ARRAY_OFFSET}, $resLength); | ||
| |$values.pointTo($bytesArray, ${Platform.BYTE_ARRAY_OFFSET}, $sizeInBytes); | ||
| |for (int $i = 0; $i < $resLength; $i ++) { | ||
| | if ($inputArray.isNullAt($i + $startIdx)) { | ||
| | $values.setNullAt($i); | ||
| | } else { | ||
| | $values.set$primitiveValueTypeName($i, $getValue); | ||
| | } | ||
| |} | ||
| |${ev.value} = $values; | ||
| """.stripMargin | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Creates a String containing all the elements of the input array separated by the delimiter. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -106,6 +106,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) | ||
| } | ||
|
|
||
| test("Slice") { | ||
| val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) | ||
| val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) | ||
| val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) | ||
| val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType)) | ||
|
|
||
| checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) | ||
| checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) | ||
| checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) | ||
| checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) | ||
| checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), | ||
| "Unexpected value for length") | ||
| checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), | ||
| "Unexpected value for start") | ||
| checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) | ||
| checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String]) | ||
| checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) | ||
| checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) | ||
| checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), | ||
| null) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a case for something like
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And also can you add a case for nullable primitive array like |
||
|
|
||
| checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b")) | ||
| checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null)) | ||
| checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int]) | ||
| checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String]) | ||
| checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4)) | ||
| } | ||
|
|
||
| test("ArrayJoin") { | ||
| def testArrays( | ||
| arrays: Seq[Expression], | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: indent