Skip to content
26 changes: 22 additions & 4 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2154,20 +2154,38 @@ def array_max(col):
def sort_array(col, asc=True):
"""
Collection function: sorts the input array in ascending or descending order according
to the natural ordering of the array elements.
to the natural ordering of the array elements. Null elements will be placed at the beginning
of the returned array in ascending order or at the end of the returned array in descending
order.

:param col: name of column or expression

>>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
>>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
>>> df.select(sort_array(df.data).alias('r')).collect()
[Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
[Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
>>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
[Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
[Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))


@since(2.4)
def array_sort(col):
"""
Collection function: sorts the input array in ascending order. The elements of the input array
must be orderable. Null elements will be placed at the end of the returned array.

:param col: name of column or expression

>>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
>>> df.select(array_sort(df.data).alias('r')).collect()
[Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_sort(_to_java_column(col)))


@since(1.5)
@ignore_unicode_prefix
def reverse(col):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ object FunctionRegistry {
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.util.Comparator

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.ArraySortUtil.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -117,47 +118,16 @@ case class MapValues(child: Expression)
}

/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
* Common base class for [[SortArray]] and [[ArraySort]].
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', 'c', 'a'), true);
["a","b","c","d"]
""")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {

def this(e: Expression) = this(e, Literal(true))
trait ArraySortUtil extends ExpectsInputTypes {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about ArraySortLike?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, thank you for your review while it is a long-holiday week in Japan.

protected def arrayExpression: Expression

override def left: Expression = base
override def right: Expression = ascendingOrder
override def dataType: DataType = base.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)

override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
ascendingOrder match {
case Literal(_: Boolean, BooleanType) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
"Sort order in second argument requires a boolean literal.")
}
case ArrayType(dt, _) =>
TypeCheckResult.TypeCheckFailure(
s"$prettyName does not support sorting array of type ${dt.simpleString}")
case _ =>
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
}
protected def nullOrder: NullOrder

@transient
private lazy val lt: Comparator[Any] = {
val ordering = base.dataType match {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
Expand All @@ -168,9 +138,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
-1
nullOrder
} else if (o2 == null) {
1
-nullOrder
} else {
ordering.compare(o1, o2)
}
Expand All @@ -180,7 +150,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)

@transient
private lazy val gt: Comparator[Any] = {
val ordering = base.dataType match {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
Expand All @@ -191,28 +161,205 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
1
-nullOrder
} else if (o2 == null) {
-1
nullOrder
} else {
-ordering.compare(o1, o2)
}
}
}
}

override def nullSafeEval(array: Any, ascending: Any): Any = {
val elementType = base.dataType.asInstanceOf[ArrayType].elementType
def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType

def sortEval(array: Any, ascending: Boolean): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementType != NullType) {
java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt)
java.util.Arrays.sort(data, if (ascending) lt else gt)
}
new GenericArrayData(data.asInstanceOf[Array[Any]])
}

def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = {
val arrayData = classOf[ArrayData].getName
val genericArrayData = classOf[GenericArrayData].getName
val array = ctx.freshName("array")
val c = ctx.freshName("c")
val sort = if (elementType == NullType) "" else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about just copying the original array if elementType == NullType?

val sortOrder = ctx.freshName("sortOrder")
val o1 = ctx.freshName("o1")
val o2 = ctx.freshName("o2")
val jt = CodeGenerator.javaType(elementType)
val comp = if (CodeGenerator.isPrimitiveType(elementType)) {
val bt = CodeGenerator.boxedType(elementType)
val v1 = ctx.freshName("v1")
val v2 = ctx.freshName("v2")
s"""
|$jt $v1 = (($bt) $o1).${jt}Value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to enforce the boxing? An why do we need to cast to the java type in the non primitive scenario?

Copy link
Member Author

@kiszk kiszk May 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this is because compare() in java.util.Arrays.sort accepts two Object arguments. Thus, we do boxing here.
Now, I realized java.util.Arrays.sort has sort() method only for ascending. Let me use them for ascending and non-null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I see, thanks.

|$jt $v2 = (($bt) $o2).${jt}Value();
|int $c = ${ctx.genComp(elementType, v1, v2)};
""".stripMargin
} else {
s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};"
}
s"""
|final int $sortOrder = $order ? 1 : -1;
|java.util.Arrays.sort($array, new java.util.Comparator() {
| @Override public int compare(Object $o1, Object $o2) {
| if ($o1 == null && $o2 == null) {
| return 0;
| } else if ($o1 == null) {
| return $sortOrder * $nullOrder;
| } else if ($o2 == null) {
| return -$sortOrder * $nullOrder;
| }
| $comp
| return $sortOrder * $c;
| }
|});
""".stripMargin
}
val dataTypes = elementType match {
case DecimalType.Fixed(p, s) =>
s"org.apache.spark.sql.types.DataTypes.createDecimalType($p, $s)"
case ArrayType(et, cn) =>
s"org.apache.spark.sql.types.DataTypes.createArrayType($et, $cn)"
case MapType(kt, vt, cn) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need for MapType because MapType is not orderable?

s"org.apache.spark.sql.types.DataTypes.createMapType($kt, $vt, $cn)"
case StructType(f) =>
"org.apache.spark.sql.types.StructType$.MODULE$." +
s"apply(new java.util.ArrayList(${f.length}))"
case _ =>
s"org.apache.spark.sql.types.DataTypes.$elementType"
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this will work for all complex types, e.g. ArrayType(ArrayType(IntegerType))?
How about using reference object of elementType?

s"""
|Object[] $array = (Object[]) (($arrayData) $base).toArray(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using toObjectArray which doesn't need ClassTag?

| $dataTypes, scala.reflect.ClassTag$$.MODULE$$.AnyRef());
|$sort
|${ev.value} = new $genericArrayData($array);
""".stripMargin
}

}

object ArraySortUtil {
type NullOrder = Int
// Least: place null element at the first of the array for ascending order
// Greatest: place null element at the end of the array for ascending order
object NullOrder {
val Least: NullOrder = -1
val Greatest: NullOrder = 1
}
}

/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order
according to the natural ordering of the array elements. Null elements will be placed
at the beginning of the returned array in ascending order or at the end of the returned
array in descending order.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true);
[null,"a","b","c","d"]
""")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ArraySortUtil {

def this(e: Expression) = this(e, Literal(true))

override def left: Expression = base
override def right: Expression = ascendingOrder
override def dataType: DataType = base.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)

override def arrayExpression: Expression = base
override def nullOrder: NullOrder = NullOrder.Least

override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
ascendingOrder match {
case Literal(_: Boolean, BooleanType) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
"Sort order in second argument requires a boolean literal.")
}
case ArrayType(dt, _) =>
val dtSimple = dt.simpleString
TypeCheckResult.TypeCheckFailure(
s"$prettyName does not support sorting array of type $dtSimple which is not orderable")
case _ =>
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
}

override def nullSafeEval(array: Any, ascending: Any): Any = {
sortEval(array, ascending.asInstanceOf[Boolean])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order))
}

override def prettyName: String = "sort_array"
}


/**
* Sorts the input array in ascending order according to the natural ordering of
* the array elements and returns it.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must
be orderable. Null elements will be placed at the end of the returned array.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
["a","b","c","d",null]
""",
since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraySort(child: Expression) extends UnaryExpression with ArraySortUtil {

override def dataType: DataType = child.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

override def arrayExpression: Expression = child
override def nullOrder: NullOrder = NullOrder.Greatest

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
TypeCheckResult.TypeCheckSuccess
case ArrayType(dt, _) =>
val dtSimple = dt.simpleString
TypeCheckResult.TypeCheckFailure(
s"$prettyName does not support sorting array of type $dtSimple which is not orderable")
case _ =>
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
}

override def nullSafeEval(array: Any): Any = {
sortEval(array, true)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true"))
}

override def prettyName: String = "array_sort"
}

/**
* Returns a reversed string or an array with reverse order of elements.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,28 +61,42 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
val a4 = Literal.create(Seq(null, null), ArrayType(NullType))
val d1 = new Decimal().set(10)
val d2 = new Decimal().set(100)
val a4 = Literal.create(Seq(d2, d1), ArrayType(DecimalType(10, 0)))
val a5 = Literal.create(Seq(null, null), ArrayType(NullType))

checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
checkEvaluation(new SortArray(a1), Seq[Integer]())
checkEvaluation(new SortArray(a2), Seq("a", "b"))
checkEvaluation(new SortArray(a3), Seq(null, "a", "b"))
checkEvaluation(new SortArray(a4), Seq(d1, d2))
checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3))
checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]())
checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b"))
checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b"))
checkEvaluation(SortArray(a4, Literal(true)), Seq(d1, d2))
checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1))
checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]())
checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a"))
checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
checkEvaluation(SortArray(a4, Literal(false)), Seq(d2, d1))

checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
checkEvaluation(new SortArray(a4), Seq(null, null))
checkEvaluation(new SortArray(a5), Seq(null, null))

val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)

checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2)))

checkEvaluation(ArraySort(a0), Seq(1, 2, 3))
checkEvaluation(ArraySort(a1), Seq[Integer]())
checkEvaluation(ArraySort(a2), Seq("a", "b"))
checkEvaluation(ArraySort(a3), Seq("a", "b", null))
checkEvaluation(ArraySort(a4), Seq(d1, d2))
checkEvaluation(ArraySort(a5), Seq(null, null))
checkEvaluation(ArraySort(arrayStruct), Seq(create_row(1), create_row(2)))
}

test("Array contains") {
Expand Down
Loading