Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1834,6 +1834,19 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(2.4)
def slice(x, start, length):
"""
Collection function: returns an array containing all the elements in `x` from index `start`
(or starting from the end if `start` is negative) with the specified `length`.
>>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
>>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
[Row(sliced=[2, 3]), Row(sliced=[5])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.slice(_to_java_column(x), start, length))


@ignore_unicode_prefix
@since(2.4)
def array_join(col, delimiter, null_replacement=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ object FunctionRegistry {
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
expression[SortArray]("sort_array"),
expression[ArrayMin]("array_min"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types._
import org.apache.spark.util.{ParentClassLoader, Utils}

Expand Down Expand Up @@ -730,6 +731,39 @@ class CodegenContext {
""".stripMargin
}

/**
* Generates code creating a [[UnsafeArrayData]].
*
* @param arrayName name of the array to create
* @param numElements code representing the number of elements the array should contain
* @param elementType data type of the elements in the array
* @param additionalErrorMessage string to include in the error message
*/
def createUnsafeArray(
arrayName: String,
numElements: String,
elementType: DataType,
additionalErrorMessage: String): String = {
val arraySize = freshName("size")
val arrayBytes = freshName("arrayBytes")

s"""
|long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
| $numElements,
| ${elementType.defaultSize});
|if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| throw new RuntimeException("Unsuccessful try create array with " + $arraySize +
| " bytes of data due to exceeding the limit " +
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." +
| "$additionalErrorMessage");
|}
|byte[] $arrayBytes = new byte[(int)$arraySize];
|UnsafeArrayData $arrayName = new UnsafeArrayData();
|Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
|$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
""".stripMargin
}

/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

Expand Down Expand Up @@ -378,6 +377,129 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}

/**
* Slices an array according to the requested start index and length
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
[2,3]
> SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
[3,4]
""", since = "2.4.0")
// scalastyle:on line.size.limit
case class Slice(x: Expression, start: Expression, length: Expression)
extends TernaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = x.dataType

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)

override def children: Seq[Expression] = Seq(x, start, length)

lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
val startInt = startVal.asInstanceOf[Int]
val lengthInt = lengthVal.asInstanceOf[Int]
val arr = xVal.asInstanceOf[ArrayData]
val startIndex = if (startInt == 0) {
throw new RuntimeException(
s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
} else if (startInt < 0) {
startInt + arr.numElements()
} else {
startInt - 1
}
if (lengthInt < 0) {
throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
"length must be greater than or equal to 0.")
}
// startIndex can be negative if start is negative and its absolute value is greater than the
// number of elements in the array
if (startIndex < 0 || startIndex >= arr.numElements()) {
return new GenericArrayData(Array.empty[AnyRef])
}
val data = arr.toSeq[AnyRef](elementType)
new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (x, start, length) => {
val startIdx = ctx.freshName("startIdx")
val resLength = ctx.freshName("resLength")
val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
s"""
|${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
|${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
|if ($start == 0) {
| throw new RuntimeException("Unexpected value for start in function $prettyName: "
| + "SQL array indices start at 1.");
|} else if ($start < 0) {
| $startIdx = $start + $x.numElements();
|} else {
| // arrays in SQL are 1-based instead of 0-based
| $startIdx = $start - 1;
|}
|if ($length < 0) {
| throw new RuntimeException("Unexpected value for length in function $prettyName: "
| + "length must be greater than or equal to 0.");
|} else if ($length > $x.numElements() - $startIdx) {
| $resLength = $x.numElements() - $startIdx;
|} else {
| $resLength = $length;
|}
|${genCodeForResult(ctx, ev, x, startIdx, resLength)}
""".stripMargin
})
}

def genCodeForResult(
ctx: CodegenContext,
ev: ExprCode,
inputArray: String,
startIdx: String,
resLength: String): String = {
val values = ctx.freshName("values")
val i = ctx.freshName("i")
val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")
if (!CodeGenerator.isPrimitiveType(elementType)) {
val arrayClass = classOf[GenericArrayData].getName
s"""
|Object[] $values;
|if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
| $values = new Object[0];
|} else {
| $values = new Object[$resLength];
| for (int $i = 0; $i < $resLength; $i ++) {
| $values[$i] = $getValue;
| }
|}
|${ev.value} = new $arrayClass($values);
""".stripMargin
} else {
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
|if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
| $resLength = 0;
|}
|${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")}
|for (int $i = 0; $i < $resLength; $i ++) {
| if ($inputArray.isNullAt($i + $startIdx)) {
| $values.setNullAt($i);
| } else {
| $values.set$primitiveValueTypeName($i, $getValue);
| }
|}
|${ev.value} = $values;
""".stripMargin
}
}
}

/**
* Creates a String containing all the elements of the input array separated by the delimiter.
*/
Expand Down Expand Up @@ -975,36 +1097,19 @@ case class Concat(children: Seq[Expression]) extends Expression {
}

private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
val arrayName = ctx.freshName("array")
val arraySizeName = ctx.freshName("size")
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

val unsafeArraySizeInBytes = s"""
|long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
| $numElemName,
| ${elementType.defaultSize});
|if ($arraySizeName > $MAX_ARRAY_LENGTH) {
| throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
| " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
| " for UnsafeArrayData.");
|}
""".stripMargin
val baseOffset = Platform.BYTE_ARRAY_OFFSET
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

s"""
|new Object() {
| public ArrayData concat($javaType[] args) {
| ${nullArgumentProtection()}
| $numElemCode
| $unsafeArraySizeInBytes
| byte[] $arrayName = new byte[(int)$arraySizeName];
| UnsafeArrayData $arrayData = new UnsafeArrayData();
| Platform.putLong($arrayName, $baseOffset, $numElemName);
| $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
Expand Down Expand Up @@ -1156,34 +1261,16 @@ case class Flatten(child: Expression) extends UnaryExpression {
ctx: CodegenContext,
childVariableName: String,
arrayDataName: String): String = {
val arrayName = ctx.freshName("array")
val arraySizeName = ctx.freshName("size")
val counter = ctx.freshName("counter")
val tempArrayDataName = ctx.freshName("tempArrayData")

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)

val unsafeArraySizeInBytes = s"""
|long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
| $numElemName,
| ${elementType.defaultSize});
|if ($arraySizeName > $MAX_ARRAY_LENGTH) {
| throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
| $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" +
| " bytes for UnsafeArrayData.");
|}
""".stripMargin
val baseOffset = Platform.BYTE_ARRAY_OFFSET

val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

s"""
|$numElemCode
|$unsafeArraySizeInBytes
|byte[] $arrayName = new byte[(int)$arraySizeName];
|UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
|Platform.putLong($arrayName, $baseOffset, $numElemName);
|$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
|${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" $prettyName failed.")}
|int $counter = 0;
|for (int k = 0; k < $childVariableName.numElements(); k++) {
| ArrayData arr = $childVariableName.getArray(k);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,34 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("Slice") {
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType))
val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType))

checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2))
checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5))
checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6))
checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6))
checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)),
"Unexpected value for length")
checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)),
"Unexpected value for start")
checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int])
checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String])
checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null)
checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null)
checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)),
null)

checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b"))
checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null))
checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int])
checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String])
checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4))
}

test("ArrayJoin") {
def testArrays(
arrays: Seq[Expression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
}
}

protected def checkExceptionInExpression[T <: Throwable : ClassTag](
expression: => Expression,
expectedErrMsg: String): Unit = {
checkExceptionInExpression[T](expression, InternalRow.empty, expectedErrMsg)
}

protected def checkExceptionInExpression[T <: Throwable : ClassTag](
expression: => Expression,
inputRow: InternalRow,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal.fromObject(new java.util.LinkedList[Int]),
Map("nonexisting" -> Literal(1)))
checkExceptionInExpression[Exception](initializeWithNonexistingMethod,
InternalRow.fromSeq(Seq()),
"""A method named "nonexisting" is not declared in any enclosing class """ +
"nor any supertype")

Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3039,6 +3039,16 @@ object functions {
ArrayContains(column.expr, Literal(value))
}

/**
* Returns an array containing all the elements in `x` from index `start` (or starting from the
* end if `start` is negative) with the specified `length`.
* @group collection_funcs
* @since 2.4.0
*/
def slice(x: Column, start: Int, length: Int): Column = withExpr {
Slice(x.expr, Literal(start), Literal(length))
}

/**
* Concatenates the elements of `column` using the `delimiter`. Null values are replaced with
* `nullReplacement`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("slice function") {
val df = Seq(
Seq(1, 2, 3),
Seq(4, 5)
).toDF("x")

val answer = Seq(Row(Seq(2, 3)), Row(Seq(5)))

checkAnswer(df.select(slice(df("x"), 2, 2)), answer)
checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer)

val answerNegative = Seq(Row(Seq(3)), Row(Seq(5)))
checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative)
checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative)
}

test("array_join function") {
val df = Seq(
(Seq[String]("a", "b"), ","),
Expand Down