Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[Stack]("stack"),
expression[ReplicateRows]("replicate_rows"),
expression[CaseWhen]("when"),

// math functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ object TypeCoercion {
new ImplicitTypeCasts(conf) ::
DateTimeOperations ::
WindowFrameCoercion ::
ReplicateRowsCoercion ::
Nil

// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
Expand Down Expand Up @@ -702,6 +703,20 @@ object TypeCoercion {
}
}

/**
* Coerces first argument in ReplicateRows expression and introduces a cast to Long
* if necessary.
*/
object ReplicateRowsCoercion extends TypeCoercionRule {
private val acceptedTypes = Seq(IntegerType, ShortType, ByteType)
override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case s @ ReplicateRows(children) if s.children.nonEmpty && s.childrenResolved &&
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

children is not used. How about this?

      case s @ ReplicateRows(children) if children.nonEmpty && s.childrenResolved &&
          acceptedTypes.contains(children.head.dataType) =>
        ReplicateRows(Cast(children.head, LongType) +: children.tail)

acceptedTypes.contains(s.children.head.dataType) =>
val castedExpr = Cast(s.children.head, LongType)
ReplicateRows(Seq(castedExpr) ++ s.children.tail)
}
}

/**
* Coerces the types of [[Concat]] children to expected ones.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,54 @@ case class Stack(children: Seq[Expression]) extends Generator {
}
}

/**
* Replicate the row N times. N is specified as the first argument to the function.
* {{{
* SELECT replicate_rows(2, "val1", "val2") ->
* 2 val1 val2
* 2 val1 val2
* }}}
*/
@ExpressionDescription(
usage = "_FUNC_(n, expr1, ..., exprk) - Replicates `n`, `expr1`, ..., `exprk` into `n` rows.",
Copy link
Member

@viirya viirya May 6, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the design doc for INTERSECT ALL and EXCEPT ALL. It looks like n is always stripped and unused after the Generate operation. So why do we need to keep n in the ReplicateRows output? Could we do it like this:

> SELECT _FUNC_(2, "val1", "val2");
  val1  val2
  val1  val2

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya I did think about it, Simon. But then I decided to match the output with Hive.

examples = """
Examples:
> SELECT _FUNC_(2, "val1", "val2");
2 val1 val2
2 val1 val2
""")
case class ReplicateRows(children: Seq[Expression]) extends Generator with CodegenFallback {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be easily implemented in codegen so we don't need CodegenFallback. We can deal with it in follow-up if you want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya If you don't mind, i would like to do it in a follow-up.

// Total number of arguments, counting the leading row-count argument; computed on first use.
private lazy val numColumns = children.length

override def checkInputDataTypes(): TypeCheckResult = {
if (numColumns < 2) {
TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
} else if (children.head.dataType != LongType) {
TypeCheckResult.TypeCheckFailure("The number of rows must be a positive long value.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this message? The first argument type must be byte, short, int, or long, but ${children.head.dataType} found. BTW, it seems we don't reject negative values? (The current message says the number must be positive though...?)

} else {
TypeCheckResult.TypeCheckSuccess
}
}

/**
 * Schema of each generated row: one field per argument, in argument order, named
 * `col0`, `col1`, ... and typed after the corresponding child expression.
 */
override def elementSchema: StructType = {
  val fields = for ((child, position) <- children.zipWithIndex)
    yield StructField(s"col$position", child.dataType)
  StructType(fields)
}

/**
 * Produces `n` copies of the evaluated argument list, where `n` is the value of the
 * first child for the given input row. Each output row holds `n` in column 0 followed
 * by the values of the remaining children.
 *
 * Rows are produced lazily through an iterator instead of being materialized up front,
 * so a large `n` does not require holding `n` rows in memory at once. A zero or
 * negative `n` yields no rows; a null `n` unboxes to 0 and therefore also yields none.
 *
 * @param input the input row against which the child expressions are evaluated
 * @return a lazily evaluated sequence of `n` rows
 */
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
  // Evaluate all children eagerly, before any row is handed out: the returned iterator
  // may be consumed later and must not depend on `input` still holding the same data.
  val numRows = children.head.eval(input).asInstanceOf[Long]
  val values = children.tail.map(_.eval(input)).toArray
  Iterator.iterate(0L)(_ + 1L).takeWhile(_ < numRows).map { _ =>
    // Every output row gets its own backing array, as in the eager implementation.
    val fields = new Array[Any](numColumns)
    fields(0) = numRows
    Array.copy(values, 0, fields, 1, values.length)
    InternalRow(fields: _*)
  }
}
}

/**
* Wrapper around another generator to specify outer behavior. This is used to implement functions
* such as explode_outer. This expression gets replaced during analysis.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
SELECT replicate_rows(CAST(1 AS BYTE), 1);

SELECT replicate_rows(CAST(1 AS INT), 1);

SELECT replicate_rows(CAST(1 AS LONG), 1);

SELECT replicate_rows(CAST(1 AS SHORT), 1);

SELECT replicate_rows("abcd", 1);
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
(1, 'row1', 1.1),
(2, 'row2', 2.2),
(0, 'row3', 3.3),
(-1,'row4', 4.4),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the current behaviour for negative values the same as Hive's?

(null,'row5', 5.5),
(3, 'row6', null)
AS tab1(c1, c2, c3);

-- Requires 2 arguments at minimum.
SELECT replicate_rows() FROM tab1;

-- Requires 2 arguments at minimum.
SELECT replicate_rows(c1) FROM tab1;

-- First argument should be an integral type.
Copy link
Member

@maropu maropu May 8, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think numeric generally includes float and double, too. integral type?

SELECT replicate_rows("abcd", c2) FROM tab1;

-- untyped null first argument
SELECT replicate_rows(null, c2) FROM tab1;

-- c1, c2 replicated c1 times
SELECT replicate_rows(c1, c2) FROM tab1;

-- c1, c2, c2 replicated c1 times
SELECT replicate_rows(c1, c2, c2) FROM tab1;

-- c1, c2, c2, c2, c3 replicated c1 times
SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1;

-- Used as a derived table in FROM clause.
SELECT c2, c1
FROM (
SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1
);

-- column expression.
SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1;

-- Clean-up
DROP VIEW IF EXISTS tab1;
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 5


-- !query 0
SELECT replicate_rows(CAST(1 AS BYTE), 1)
-- !query 0 schema
struct<col0:bigint,col1:int>
-- !query 0 output
1 1


-- !query 1
SELECT replicate_rows(CAST(1 AS INT), 1)
-- !query 1 schema
struct<col0:bigint,col1:int>
-- !query 1 output
1 1


-- !query 2
SELECT replicate_rows(CAST(1 AS LONG), 1)
-- !query 2 schema
struct<col0:bigint,col1:int>
-- !query 2 output
1 1


-- !query 3
SELECT replicate_rows(CAST(1 AS SHORT), 1)
-- !query 3 schema
struct<col0:bigint,col1:int>
-- !query 3 output
1 1


-- !query 4
SELECT replicate_rows("abcd", 1)
-- !query 4 schema
struct<>
-- !query 4 output
org.apache.spark.sql.AnalysisException
cannot resolve 'replicaterows('abcd', 1)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 11


-- !query 0
CREATE TEMPORARY VIEW tab1 AS SELECT * FROM VALUES
(1, 'row1', 1.1),
(2, 'row2', 2.2),
(0, 'row3', 3.3),
(-1,'row4', 4.4),
(null,'row5', 5.5),
(3, 'row6', null)
AS tab1(c1, c2, c3)
-- !query 0 schema
struct<>
-- !query 0 output



-- !query 1
SELECT replicate_rows() FROM tab1
-- !query 1 schema
struct<>
-- !query 1 output
org.apache.spark.sql.AnalysisException
cannot resolve 'replicaterows()' due to data type mismatch: replicaterows requires at least 2 arguments.; line 1 pos 7


-- !query 2
SELECT replicate_rows(c1) FROM tab1
-- !query 2 schema
struct<>
-- !query 2 output
org.apache.spark.sql.AnalysisException
cannot resolve 'replicaterows(CAST(tab1.`c1` AS BIGINT))' due to data type mismatch: replicaterows requires at least 2 arguments.; line 1 pos 7


-- !query 3
SELECT replicate_rows("abcd", c2) FROM tab1
-- !query 3 schema
struct<>
-- !query 3 output
org.apache.spark.sql.AnalysisException
cannot resolve 'replicaterows('abcd', tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7


-- !query 4
SELECT replicate_rows(null, c2) FROM tab1
-- !query 4 schema
struct<>
-- !query 4 output
org.apache.spark.sql.AnalysisException
cannot resolve 'replicaterows(NULL, tab1.`c2`)' due to data type mismatch: The number of rows must be a positive long value.; line 1 pos 7


-- !query 5
SELECT replicate_rows(c1, c2) FROM tab1
-- !query 5 schema
struct<col0:bigint,col1:string>
-- !query 5 output
1 row1
2 row2
2 row2
3 row6
3 row6
3 row6


-- !query 6
SELECT replicate_rows(c1, c2, c2) FROM tab1
-- !query 6 schema
struct<col0:bigint,col1:string,col2:string>
-- !query 6 output
1 row1 row1
2 row2 row2
2 row2 row2
3 row6 row6
3 row6 row6
3 row6 row6


-- !query 7
SELECT replicate_rows(c1, c2, c2, c2, c3) FROM tab1
-- !query 7 schema
struct<col0:bigint,col1:string,col2:string,col3:string,col4:decimal(2,1)>
-- !query 7 output
1 row1 row1 row1 1.1
2 row2 row2 row2 2.2
2 row2 row2 row2 2.2
3 row6 row6 row6 NULL
3 row6 row6 row6 NULL
3 row6 row6 row6 NULL


-- !query 8
SELECT c2, c1
FROM (
SELECT replicate_rows(c1, c2) AS (c1, c2) FROM tab1
)
-- !query 8 schema
struct<c2:string,c1:bigint>
-- !query 8 output
row1 1
row2 2
row2 2
row6 3
row6 3
row6 3


-- !query 9
SELECT replicate_rows(c1, concat(c2, '...'), c2) FROM tab1
-- !query 9 schema
struct<col0:bigint,col1:string,col2:string>
-- !query 9 output
1 row1... row1
2 row2... row2
2 row2... row2
3 row6... row6
3 row6... row6
3 row6... row6


-- !query 10
DROP VIEW IF EXISTS tab1
-- !query 10 schema
struct<>
-- !query 10 output