Skip to content
26 changes: 22 additions & 4 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2154,20 +2154,38 @@ def array_max(col):
def sort_array(col, asc=True):
"""
Collection function: sorts the input array in ascending or descending order according
to the natural ordering of the array elements.
to the natural ordering of the array elements. Null elements will be placed at the beginning
of the returned array in ascending order or at the end of the returned array in descending
order.

:param col: name of column or expression

>>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data'])
>>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
>>> df.select(sort_array(df.data).alias('r')).collect()
[Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
[Row(r=[None, 1, 2, 3]), Row(r=[1]), Row(r=[])]
>>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
[Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
[Row(r=[3, 2, 1, None]), Row(r=[1]), Row(r=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))


@since(2.4)
def array_sort(col):
"""
Collection function: sorts the input array in ascending order. The elements of the input array
must be orderable. Null elements will be placed at the end of the returned array.

:param col: name of column or expression

>>> df = spark.createDataFrame([([2, 1, None, 3],),([1],),([],)], ['data'])
>>> df.select(array_sort(df.data).alias('r')).collect()
[Row(r=[1, 2, 3, None]), Row(r=[1]), Row(r=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_sort(_to_java_column(col)))


@since(1.5)
@ignore_unicode_prefix
def reverse(col):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ object FunctionRegistry {
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.util.Comparator

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -117,47 +118,16 @@ case class MapValues(child: Expression)
}

/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
* Common base class for [[SortArray]] and [[ArraySort]].
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order according to the natural ordering of the array elements.",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', 'c', 'a'), true);
["a","b","c","d"]
""")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {

def this(e: Expression) = this(e, Literal(true))

override def left: Expression = base
override def right: Expression = ascendingOrder
override def dataType: DataType = base.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)
trait ArraySortLike extends ExpectsInputTypes {
protected def arrayExpression: Expression

override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
ascendingOrder match {
case Literal(_: Boolean, BooleanType) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
"Sort order in second argument requires a boolean literal.")
}
case ArrayType(dt, _) =>
TypeCheckResult.TypeCheckFailure(
s"$prettyName does not support sorting array of type ${dt.simpleString}")
case _ =>
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
}
protected def nullOrder: NullOrder

@transient
private lazy val lt: Comparator[Any] = {
val ordering = base.dataType match {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
Expand All @@ -168,9 +138,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
-1
nullOrder
} else if (o2 == null) {
1
-nullOrder
} else {
ordering.compare(o1, o2)
}
Expand All @@ -180,7 +150,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)

@transient
private lazy val gt: Comparator[Any] = {
val ordering = base.dataType match {
val ordering = arrayExpression.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
Expand All @@ -191,28 +161,210 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
if (o1 == null && o2 == null) {
0
} else if (o1 == null) {
1
-nullOrder
} else if (o2 == null) {
-1
nullOrder
} else {
-ordering.compare(o1, o2)
}
}
}
}

override def nullSafeEval(array: Any, ascending: Any): Any = {
val elementType = base.dataType.asInstanceOf[ArrayType].elementType
def elementType: DataType = arrayExpression.dataType.asInstanceOf[ArrayType].elementType
def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull

def sortEval(array: Any, ascending: Boolean): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementType != NullType) {
java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt)
java.util.Arrays.sort(data, if (ascending) lt else gt)
}
new GenericArrayData(data.asInstanceOf[Array[Any]])
}

def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = {
val arrayData = classOf[ArrayData].getName
val genericArrayData = classOf[GenericArrayData].getName
val unsafeArrayData = classOf[UnsafeArrayData].getName
val array = ctx.freshName("array")
val c = ctx.freshName("c")
if (elementType == NullType) {
s"${ev.value} = $base.copy();"
} else {
val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType)
val sortOrder = ctx.freshName("sortOrder")
val o1 = ctx.freshName("o1")
val o2 = ctx.freshName("o2")
val jt = CodeGenerator.javaType(elementType)
val comp = if (CodeGenerator.isPrimitiveType(elementType)) {
val bt = CodeGenerator.boxedType(elementType)
val v1 = ctx.freshName("v1")
val v2 = ctx.freshName("v2")
s"""
|$jt $v1 = (($bt) $o1).${jt}Value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to enforce the boxing? An why do we need to cast to the java type in the non primitive scenario?

Copy link
Member Author

@kiszk kiszk May 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this is because compare() in java.util.Arrays.sort accepts two Object arguments. Thus, we do boxing here.
Now, I realized java.util.Arrays.sort has sort() method only for ascending. Let me use them for ascending and non-null.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I see, thanks.

|$jt $v2 = (($bt) $o2).${jt}Value();
|int $c = ${ctx.genComp(elementType, v1, v2)};
""".stripMargin
} else {
s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};"
}
val nonNullPrimitiveAscendingSort =
if (CodeGenerator.isPrimitiveType(elementType) && !containsNull) {
val javaType = CodeGenerator.javaType(elementType)
val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
|if ($order) {
| $javaType[] $array = $base.to${primitiveTypeName}Array();
| java.util.Arrays.sort($array);
| ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array);
|} else
""".stripMargin
} else {
""
}
s"""
|$nonNullPrimitiveAscendingSort
|{
| Object[] $array = $base.toObjectArray($elementTypeTerm);
| final int $sortOrder = $order ? 1 : -1;
| java.util.Arrays.sort($array, new java.util.Comparator() {
| @Override public int compare(Object $o1, Object $o2) {
| if ($o1 == null && $o2 == null) {
| return 0;
| } else if ($o1 == null) {
| return $sortOrder * $nullOrder;
| } else if ($o2 == null) {
| return -$sortOrder * $nullOrder;
| }
| $comp
| return $sortOrder * $c;
| }
| });
| ${ev.value} = new $genericArrayData($array);
|}
""".stripMargin
}
}

}

object ArraySortLike {
type NullOrder = Int
// Least: place null element at the first of the array for ascending order
// Greatest: place null element at the end of the array for ascending order
object NullOrder {
val Least: NullOrder = -1
val Greatest: NullOrder = 1
}
}

/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order
according to the natural ordering of the array elements. Null elements will be placed
at the beginning of the returned array in ascending order or at the end of the returned
array in descending order.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true);
[null,"a","b","c","d"]
""")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ArraySortLike {

def this(e: Expression) = this(e, Literal(true))

override def left: Expression = base
override def right: Expression = ascendingOrder
override def dataType: DataType = base.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)

override def arrayExpression: Expression = base
override def nullOrder: NullOrder = NullOrder.Least

override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
ascendingOrder match {
case Literal(_: Boolean, BooleanType) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
"Sort order in second argument requires a boolean literal.")
}
case ArrayType(dt, _) =>
val dtSimple = dt.simpleString
TypeCheckResult.TypeCheckFailure(
s"$prettyName does not support sorting array of type $dtSimple which is not orderable")
case _ =>
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
}

override def nullSafeEval(array: Any, ascending: Any): Any = {
sortEval(array, ascending.asInstanceOf[Boolean])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order))
}

override def prettyName: String = "sort_array"
}


/**
* Sorts the input array in ascending order according to the natural ordering of
* the array elements and returns it.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must
be orderable. Null elements will be placed at the end of the returned array.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
["a","b","c","d",null]
""",
since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike {

override def dataType: DataType = child.dataType
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

override def arrayExpression: Expression = child
override def nullOrder: NullOrder = NullOrder.Greatest

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
TypeCheckResult.TypeCheckSuccess
case ArrayType(dt, _) =>
val dtSimple = dt.simpleString
TypeCheckResult.TypeCheckFailure(
s"$prettyName does not support sorting array of type $dtSimple which is not orderable")
case _ =>
TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
}

override def nullSafeEval(array: Any): Any = {
sortEval(array, true)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => sortCodegen(ctx, ev, c, "true"))
}

override def prettyName: String = "array_sort"
}

/**
* Returns a reversed string or an array with reverse order of elements.
*/
Expand Down
Loading