Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[SPARK-23926][SQL] Adding more tests + fixing a bug in codegen.
  • Loading branch information
mn-mikke authored and mn-mikke committed Apr 11, 2018
commit 28ed66493fb795228d4c2b2cc445d77f6d95d161
Original file line number Diff line number Diff line change
Expand Up @@ -225,28 +225,18 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
> SELECT _FUNC_(array(2, 1, 4, 3));
[3, 4, 1, 2]
""",
since = "2.4.0")
since = "1.5.0",
note = "Reverse logic for arrays is available since 2.4.0."
)
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

val allowedTypes = Seq(StringType, ArrayType)
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))

override def dataType: DataType = child.dataType

lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

override def checkInputDataTypes(): TypeCheckResult = {
if (allowedTypes.exists(_.acceptsType(child.dataType))) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"The argument of function $prettyName should be StringType or ArrayType," +
s" but it's " + child.dataType.simpleString)
}
}

override def nullSafeEval(input: Any): Any = input match {
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
case s: UTF8String => s.reverse()
Expand All @@ -266,10 +256,19 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
val length = ctx.freshName("length")
val javaElementType = CodeGenerator.javaType(elementType)
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)

val initialization = if (isPrimitiveType) {
s"$childName.copy()"
} else {
s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
}

val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length

val swapAssigments = if (CodeGenerator.isPrimitiveType(elementType)) {
val swapAssigments = if (isPrimitiveType) {
val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
|boolean isNullAtL = ${ev.value}.isNullAt(l);
|if(!isNullAtK) {
Expand All @@ -285,19 +284,17 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
| ${ev.value}.setNullAt(l);
|}""".stripMargin
} else {
s"""|Object el = ${getCall("k")};
|${ev.value}.update(k, ${getCall("l")});
|${ev.value}.update(l, el);""".stripMargin
s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
}

s"""
|${ev.value} = $childName.copy();
|final int $length = ${ev.value}.numElements();
|for(int k = 0; k < $length / 2; k++) {
| int l = $length - k - 1;
| $swapAssigments
|}
""".stripMargin
|final int $length = $childName.numElements();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: more spaces

|${ev.value} = $initialization;
|for(int k = 0; k < $numberOfIterations; k++) {
| int l = $length - k - 1;
| $swapAssigments
|}
""".stripMargin
}

override def prettyName: String = "reverse"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
}

test("reverse function") {
val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on

// String test cases
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")

Expand All @@ -438,7 +440,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(null))
)


// Array test cases (primitive-type elements)
val idf = Seq(
Seq(1, 9, 8, 7),
Expand All @@ -451,6 +452,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
idf.select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.filter(dummyFilter('i)).select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.selectExpr("reverse(i)"),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
Expand All @@ -459,6 +464,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
oneRowDF.selectExpr("reverse(array(1, null, 2, null))"),
Seq(Row(Seq(null, 2, null, 1)))
)
checkAnswer(
oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"),
Seq(Row(Seq(null, 2, null, 1)))
)

// Array test cases (complex-type elements)
val sdf = Seq(
Expand All @@ -472,6 +481,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
sdf.select(reverse('s)),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
sdf.filter(dummyFilter('s)).select(reverse('s)),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
sdf.selectExpr("reverse(s)"),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
Expand All @@ -480,6 +493,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
)
checkAnswer(
oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
)

// Error test cases
intercept[AnalysisException] {
Expand Down