Skip to content
Prev Previous commit
Next Next commit
address review comments
  • Loading branch information
kiszk committed May 1, 2018
commit 04a3ae57ae755df70c2ab2141e5b3297fe229463
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,21 @@ trait ArraySortUtil extends ExpectsInputTypes {
val genericArrayData = classOf[GenericArrayData].getName
val array = ctx.freshName("array")
val c = ctx.freshName("c")
val sort = if (elementType == NullType) "" else {
val dataTypes = elementType match {
case DecimalType.Fixed(p, s) =>
s"org.apache.spark.sql.types.DataTypes.createDecimalType($p, $s)"
case ArrayType(et, cn) =>
val dt = s"org.apache.spark.sql.types.$et$$.MODULE$$"
s"org.apache.spark.sql.types.DataTypes.createArrayType($dt, $cn)"
case StructType(f) =>
"org.apache.spark.sql.types.StructType$.MODULE$." +
s"apply(new java.util.ArrayList(${f.length}))"
case _ =>
s"org.apache.spark.sql.types.$elementType$$.MODULE$$"
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still wondering whether this will work or not. What if elementType is ArrayType(ArrayType(IntegerType))?
Can't we use a reference object of elementType?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add some tests using ArrayType and StructType for elementType?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely, I added some complex test cases with nests.

if (elementType == NullType) {
s"${ev.value} = (($arrayData) $base).copy();"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: do we need cast base to ArrayData?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, done

} else {
val sortOrder = ctx.freshName("sortOrder")
val o1 = ctx.freshName("o1")
val o2 = ctx.freshName("o2")
Expand All @@ -204,6 +218,7 @@ trait ArraySortUtil extends ExpectsInputTypes {
s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};"
}
s"""
|Object[] $array = (Object[]) (($arrayData) $base).toObjectArray($dataTypes);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

|final int $sortOrder = $order ? 1 : -1;
|java.util.Arrays.sort($array, new java.util.Comparator() {
| @Override public int compare(Object $o1, Object $o2) {
Expand All @@ -218,27 +233,9 @@ trait ArraySortUtil extends ExpectsInputTypes {
| return $sortOrder * $c;
| }
|});
|${ev.value} = new $genericArrayData($array);
""".stripMargin
}
val dataTypes = elementType match {
case DecimalType.Fixed(p, s) =>
s"org.apache.spark.sql.types.DataTypes.createDecimalType($p, $s)"
case ArrayType(et, cn) =>
s"org.apache.spark.sql.types.DataTypes.createArrayType($et, $cn)"
case MapType(kt, vt, cn) =>
s"org.apache.spark.sql.types.DataTypes.createMapType($kt, $vt, $cn)"
case StructType(f) =>
"org.apache.spark.sql.types.StructType$.MODULE$." +
s"apply(new java.util.ArrayList(${f.length}))"
case _ =>
s"org.apache.spark.sql.types.DataTypes.$elementType"
}
s"""
|Object[] $array = (Object[]) (($arrayData) $base).toArray(
| $dataTypes, scala.reflect.ClassTag$$.MODULE$$.AnyRef());
|$sort
|${ev.value} = new $genericArrayData($array);
""".stripMargin
}

}
Expand Down