Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[Stack]("stack"),
expression[ReplicateRows]("replicate_rows"),
expression[CaseWhen]("when"),

// math functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ object TypeCoercion {
new ImplicitTypeCasts(conf) ::
DateTimeOperations ::
WindowFrameCoercion ::
ReplicateRowsCoercion ::
Nil

// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
Expand Down Expand Up @@ -702,6 +703,21 @@ object TypeCoercion {
}
}

/**
* Coerces first argument in ReplicateRows expression and introduces a cast to Long
* if necessary.
*/
object ReplicateRowsCoercion extends TypeCoercionRule {
private val acceptedTypes = Seq(LongType, IntegerType, ShortType, ByteType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: LongType seems not necessary be here. Can avoid re-entering the following pattern matching if it is already long type.

override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case s @ ReplicateRows(children)
if s.childrenResolved && acceptedTypes.contains(s.children.head.dataType) =>
val numRowExpr = s.children.head
val castedExpr = ImplicitTypeCasts.implicitCast(numRowExpr, LongType).getOrElse(numRowExpr)
ReplicateRows(Seq(castedExpr) ++ s.children.tail)
}
}

/**
* Coerces the types of [[Concat]] children to expected ones.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not need to introduce this breaking line.

/**
* An expression that produces zero or more rows given a single input row.
*
Expand Down Expand Up @@ -222,6 +223,51 @@ case class Stack(children: Seq[Expression]) extends Generator {
}
}

/**
* Replicate the row based N times. N is specified as the first argument to the function.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Replicate N times the row.?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, using n to match following expression description?

* {{{
* SELECT replicate_rows(2, "val1", "val2") ->
* 2 val1 val2
* 2 val1 val2
* }}}
*/
@ExpressionDescription(
usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `expr1`, ..., `exprk` into `n` rows.",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replicates `n`, `expr1`, ..., `exprk` into `n` rows.?

examples = """
Examples:
> SELECT _FUNC_(2, "val1", "val2");
2 val1 val2
2 val1 val2
""")
case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be easily implemented in codegen so we don't need CodegenFallback. We can deal with it in follow-up if you want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya If you don't mind, i would like to do it in a follow-up.

override def checkInputDataTypes(): TypeCheckResult = {
if (children.length < 2) {
TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
} else if (children.head.dataType != LongType) {
TypeCheckResult.TypeCheckFailure("The number of rows must be a positive long value.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this message? The first argument type must be byte, short, int, or long, but ${children.head.dataType} found. BTW, it seems we don't reject negative values? (The current message says the number must be positive though...?)

} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def elementSchema: StructType =
StructType(children.zipWithIndex.map {
case (e, index) => StructField(s"col$index", e.dataType)
})

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val numRows = children.head.eval(input).asInstanceOf[Long]
val values = children.map(_.eval(input)).toArray
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

children.head seems getting evaluated twice here, can we avoid it?

Range.Long(0, numRows, 1).map { i =>
val fields = new Array[Any](children.length)
for (col <- 0 until children.length) {
fields.update(col, values(col))
}
InternalRow(fields: _*)
}
}
}

/**
* Wrapper around another generator to specify outer behavior. This is used to implement functions
* such as explode_outer. This expression gets replaced during analysis.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,31 @@ class TypeCoercionSuite extends AnalysisTest {
SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing))
)
}

test("type coercion for ReplicateRows") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this tests into sql-tests/inputs/typeCoercion/native?

val rule = TypeCoercion.ReplicateRowsCoercion
// Cast is setup to promote the first expression to Long
// for numeric types.
ruleTest(rule,
ReplicateRows(Seq(1.toShort, Literal("rowdata"))),
ReplicateRows(Seq(Cast(1.toShort, LongType), Literal("rowdata"))))
ruleTest(rule,
ReplicateRows(Seq(1, Literal("rowdata"))),
ReplicateRows(Seq(Cast(1, LongType), Literal("rowdata"))))
ruleTest(rule,
ReplicateRows(Seq(1.toByte, Literal("rowdata"))),
ReplicateRows(Seq(Cast(1.toByte, LongType), Literal("rowdata"))))

// No cast here since the expected type is Long.
ruleTest(rule,
ReplicateRows(Seq(1L, Literal("rowdata"))),
ReplicateRows(Seq(1L, Literal("rowdata"))))

// No type coercion when first expression is a non numeric type.
ruleTest(rule,
ReplicateRows(Seq(Literal("invalid"), Literal("rowdata"))),
ReplicateRows(Seq(Literal("invalid"), Literal("rowdata"))))
}
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
(1, 'row1', 1.1),
(2, 'row2', 2.2),
(0, 'row3', 3.3),
(-1,'row4', 4.4),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current behaviour of the negative value case is the same with the hive one?

(null,'row5', 5.5),
(3, 'row6', null)
AS tab1(c1, c2, c3);

-- c1, c2 replicated c1 times
SELECT replicate_rows(c1, c2) FROM tab1;

-- c1, c2, c2 repeated replicated c1 times
SELECT replicate_rows(c1, c2, c2) FROM tab1;

-- c1, c2, c2, c3 replicated c1 times
SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1;

-- Used as a derived table in FROM clause.
SELECT c2, c1
FROM (
SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1
);

-- column expression.
SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1;

-- Clean-up
DROP VIEW IF EXISTS tab1;
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 7


-- !query 0
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
(1, 'row1', 1.1),
(2, 'row2', 2.2),
(0, 'row3', 3.3),
(-1,'row4', 4.4),
(null,'row5', 5.5),
(3, 'row6', null)
AS tab1(c1, c2, c3)
-- !query 0 schema
struct<>
-- !query 0 output



-- !query 1
SELECT replicate_rows(c1, c2) FROM tab1
-- !query 1 schema
struct<col0:bigint,col1:string>
-- !query 1 output
1 row1
2 row2
2 row2
3 row6
3 row6
3 row6


-- !query 2
SELECT replicate_rows(c1, c2, c2) FROM tab1
-- !query 2 schema
struct<col0:bigint,col1:string,col2:string>
-- !query 2 output
1 row1 row1
2 row2 row2
2 row2 row2
3 row6 row6
3 row6 row6
3 row6 row6


-- !query 3
SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1
-- !query 3 schema
struct<col0:bigint,col1:string,col2:string,col3:string,col4:decimal(2,1)>
-- !query 3 output
1 row1 row1 row1 1.1
2 row2 row2 row2 2.2
2 row2 row2 row2 2.2
3 row6 row6 row6 NULL
3 row6 row6 row6 NULL
3 row6 row6 row6 NULL


-- !query 4
SELECT c2, c1
FROM (
SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1
)
-- !query 4 schema
struct<c2:string,c1:bigint>
-- !query 4 output
row1 1
row2 2
row2 2
row6 3
row6 3
row6 3


-- !query 5
SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1
-- !query 5 schema
struct<col0:bigint,col1:string,col2:string>
-- !query 5 output
1 row1... row1
2 row2... row2
2 row2... row2
3 row6... row6
3 row6... row6
3 row6... row6


-- !query 6
DROP VIEW IF EXISTS tab1
-- !query 6 schema
struct<>
-- !query 6 output

Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,37 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
sql("select * from values 1, 2 lateral view outer empty_gen() a as b"),
Row(1, null) :: Row(2, null) :: Nil)
}

test("ReplicateRows generator") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate tests? I feel udtf_replicate_rows.sql is enough for tests.

val df = spark.range(1)

// Empty DataFrame suppress the result generation
checkAnswer(spark.emptyDataFrame.selectExpr("replicate_rows(1, 1, 2, 3)"), Nil)

checkAnswer(df.selectExpr("replicate_rows(1, 2.5)"), Row(1, 2.5) :: Nil)
checkAnswer(df.selectExpr("replicate_rows(1, null)"), Row(1, null) :: Nil)
checkAnswer(df.selectExpr("replicate_rows(3, 'row1')"),
Row(3, "row1") :: Row(3, "row1") :: Row(3, "row1") :: Nil)
checkAnswer(df.selectExpr("replicate_rows(-1, 2.5)"), Nil)

// The data for the same column should have the same type.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This copied comment can be removed.

val msg1 = intercept[AnalysisException] {
df.selectExpr("replicate_rows(1)")
}.getMessage
assert(msg1.contains("requires at least 2 arguments"))

// The data for the same column should have the same type.
val msg2 = intercept[AnalysisException] {
df.selectExpr("replicate_rows('a', 1)")
}.getMessage
assert(msg2.contains("The number of rows must be a positive long value."))

val msg3 = intercept[AnalysisException] {
df.selectExpr("replicate_rows(null, 1)")
}.getMessage
assert(msg3.contains("The number of rows must be a positive long value."))

}
}

case class EmptyGenerator() extends Generator {
Expand Down