Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ object FunctionRegistry {
expression[Explode]("explode"),
expression[Greatest]("greatest"),
expression[If]("if"),
expression[Inline]("inline"),
expression[IsNaN]("isnan"),
expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -195,3 +195,42 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
// scalastyle:on line.size.limit
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)

/**
* Explodes an array of structs into a table.
*/
@ExpressionDescription(
usage = "_FUNC_(a) - Explodes an array of structs into a table.",
extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n[2,b]")
case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {

override def children: Seq[Expression] = child :: Nil

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(et, _) if et.isInstanceOf[StructType] =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function inline should be array of struct type, not ${child.dataType}")
}

override def elementSchema: StructType = child.dataType match {
case ArrayType(et : StructType, _) =>
StructType(et.fields.zipWithIndex.map {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, so it's just et now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. Currently, our type checker ensures that homogeneous StructType array.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not return et directly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, my god. I was too naive, here.
Thank you!

case (field, index) => StructField(field.name, field.dataType, nullable = field.nullable)
})
}

private lazy val ncol = elementSchema.fields.length
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to name it numFields


override def eval(input: InternalRow): TraversableOnce[InternalRow] = child.dataType match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we pattern match here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, really it's useless. It was for NullType before. I'll remove this.

case ArrayType(et : StructType, _) =>
val inputArray = child.eval(input).asInstanceOf[ArrayData]
if (inputArray == null) {
Nil
} else {
for (i <- 0 until inputArray.numElements())
yield inputArray.getStruct(i, ncol)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
assert(actual.eval(null).toSeq === expected)
private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = {
assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected)
}

private final val int_array = Seq(1, 2, 3)
Expand Down Expand Up @@ -68,4 +69,23 @@ class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
PosExplode(CreateArray(str_array.map(Literal(_)))),
str_correct_answer.map(InternalRow.fromSeq(_)))
}

test("inline") {
val correct_answer = Seq(
Seq(0, UTF8String.fromString("a")),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can create a row directly in test: call create_row(...)

Copy link
Contributor

@cloud-fan cloud-fan Jul 1, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and it can help us convert string to UTF8String, then we don't need to do it manually

Seq(1, UTF8String.fromString("b")),
Seq(2, UTF8String.fromString("c")))

checkTuple(
Inline(Literal.create(Array(), ArrayType(StructType(Seq(StructField("id1", LongType)))))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we usually use new StructType().add("id", LongType) to create struct type

Seq.empty)

checkTuple(
Inline(CreateArray(Seq(
CreateStruct(Seq(Literal(0), Literal("a"))),
CreateStruct(Seq(Literal(1), Literal("b"))),
CreateStruct(Seq(Literal(2), Literal("c")))
))),
correct_answer.map(InternalRow.fromSeq(_)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -89,4 +91,30 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
Row(3) :: Nil)
}

test("inline with empty table or empty array") {
Copy link
Contributor

@cloud-fan cloud-fan Jul 1, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the test name is misleading: we do allow empty array, the problem is array() returns an array of null type, which fails the type check.

checkAnswer(
spark.range(0).selectExpr("inline(array(struct(10, 100)))"),
Nil)

val m = intercept[AnalysisException] {
spark.range(2).selectExpr("inline(array())")
}.getMessage
assert(m.contains("data type mismatch"))
}

test("inline on literal") {
checkAnswer(
spark.range(2).selectExpr("inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"),
Row(10, 100) :: Row(20, 200) :: Row(30, 300) ::
Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: Nil)
}

test("inline on column") {
val df = Seq((1, 2)).toDF("a", "b")

checkAnswer(
df.selectExpr("inline(array(struct(a, b), struct(a, b)))"),
Row(1, 2) :: Row(1, 2) :: Nil)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,6 @@ private[sql] class HiveSessionCatalog(
"map_keys", "map_values",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string",

// table generating function
"inline"
"xpath_number", "xpath_short", "xpath_string"
)
}