diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md
index 298fbbfae61..313ccb4b466 100644
--- a/docs/additional-functionality/advanced_configs.md
+++ b/docs/additional-functionality/advanced_configs.md
@@ -282,6 +282,8 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.MapKeys|`map_keys`|Returns an unordered array containing the keys of the map|true|None|
spark.rapids.sql.expression.MapValues|`map_values`|Returns an unordered array containing the values of the map|true|None|
spark.rapids.sql.expression.Md5|`md5`|MD5 hash operator|true|None|
+spark.rapids.sql.expression.MicrosToTimestamp|`timestamp_micros`|Converts the number of microseconds from unix epoch to a timestamp|true|None|
+spark.rapids.sql.expression.MillisToTimestamp|`timestamp_millis`|Converts the number of milliseconds from unix epoch to a timestamp|true|None|
spark.rapids.sql.expression.Minute|`minute`|Returns the minute component of the string/timestamp|true|None|
spark.rapids.sql.expression.MonotonicallyIncreasingID|`monotonically_increasing_id`|Returns monotonically increasing 64-bit integers|true|None|
spark.rapids.sql.expression.Month|`month`|Returns the month from a date or timestamp|true|None|
@@ -316,6 +318,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.RowNumber|`row_number`|Window function that returns the index for the row within the aggregation window|true|None|
spark.rapids.sql.expression.ScalaUDF| |User Defined Function, the UDF can choose to implement a RAPIDS accelerated interface to get better performance.|true|None|
spark.rapids.sql.expression.Second|`second`|Returns the second component of the string/timestamp|true|None|
+spark.rapids.sql.expression.SecondsToTimestamp|`timestamp_seconds`|Converts the number of seconds from unix epoch to a timestamp|true|None|
spark.rapids.sql.expression.Sequence|`sequence`|Sequence|true|None|
spark.rapids.sql.expression.ShiftLeft|`shiftleft`|Bitwise shift left (<<)|true|None|
spark.rapids.sql.expression.ShiftRight|`shiftright`|Bitwise shift right (>>)|true|None|
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 4a3de7c8070..3b3604b3613 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -9654,6 +9654,100 @@ are limited.
|
+| MicrosToTimestamp |
+`timestamp_micros` |
+Converts the number of microseconds from unix epoch to a timestamp |
+None |
+project |
+input |
+ |
+S |
+S |
+S |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+| result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS UTC is only supported TZ for TIMESTAMP |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+| MillisToTimestamp |
+`timestamp_millis` |
+Converts the number of milliseconds from unix epoch to a timestamp |
+None |
+project |
+input |
+ |
+S |
+S |
+S |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+| result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS UTC is only supported TZ for TIMESTAMP |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
| Minute |
`minute` |
Returns the minute component of the string/timestamp |
@@ -9906,6 +10000,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| Murmur3Hash |
`hash` |
Murmur3 hash operator |
@@ -9953,32 +10073,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| NaNvl |
`nanvl` |
Evaluates to `left` iff left is not NaN, `right` otherwise |
@@ -10278,6 +10372,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| Or |
`or` |
Logical OR |
@@ -10410,32 +10530,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| PercentRank |
`percent_rank` |
Window function that returns the percent rank value within the aggregation window |
@@ -10730,6 +10824,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| PreciseTimestampConversion |
|
Expression used internally to convert the TimestampType to Long and back without losing precision, i.e. in microseconds. Used in time windowing |
@@ -10777,32 +10897,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| PromotePrecision |
|
PromotePrecision before arithmetic operations between DecimalType data |
@@ -12088,12 +12182,59 @@ are limited.
|
-| Sequence |
-`sequence` |
-Sequence |
-None |
-project |
-start |
+SecondsToTimestamp |
+`timestamp_seconds` |
+Converts the number of seconds from unix epoch to a timestamp |
+None |
+project |
+input |
+ |
+S |
+S |
+S |
+S |
+S |
+S |
+ |
+ |
+ |
+S |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+| result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS UTC is only supported TZ for TIMESTAMP |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+
+
+| Sequence |
+`sequence` |
+Sequence |
+None |
+project |
+start |
|
S |
S |
@@ -12102,7 +12243,7 @@ are limited.
|
|
NS |
-NS |
+NS |
|
|
|
@@ -12313,6 +12454,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| ShiftRightUnsigned |
`shiftrightunsigned` |
Bitwise unsigned shift right (>>>) |
@@ -12381,32 +12548,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| Signum |
`sign`, `signum` |
Returns -1.0, 0.0 or 1.0 as expr is negative, 0 or positive |
@@ -12681,6 +12822,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| SortArray |
`sort_array` |
Returns a sorted array with the input array and the ascending / descending order |
@@ -12749,32 +12916,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| SortOrder |
|
Sort order |
@@ -13074,6 +13215,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| StringInstr |
`instr` |
Instr string operator |
@@ -13142,32 +13309,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| StringLPad |
`lpad` |
Pad a string on the left |
@@ -13435,6 +13576,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| StringRepeat |
`repeat` |
StringRepeat operator that repeats the given strings with numbers of times given by repeatTimes |
@@ -13503,32 +13670,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| StringReplace |
`replace` |
StringReplace operator |
@@ -13796,6 +13937,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| StringTranslate |
`translate` |
StringTranslate operator |
@@ -13885,32 +14052,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| StringTrim |
`trim` |
StringTrim operator |
@@ -14204,6 +14345,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| SubstringIndex |
`substring_index` |
substring_index operator |
@@ -14293,32 +14460,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| Subtract |
`-` |
Subtraction |
@@ -14631,6 +14772,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| TimeAdd |
|
Adds interval to timestamp |
@@ -14699,32 +14866,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| ToDegrees |
`degrees` |
Converts radians to degrees |
@@ -15023,6 +15164,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| UnaryMinus |
`negative` |
Negate a numeric value |
@@ -15113,32 +15280,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| UnaryPositive |
`positive` |
A numeric value with a + in front of it |
@@ -15396,6 +15537,32 @@ are limited.
|
+| Expression |
+SQL Functions(s) |
+Description |
+Notes |
+Context |
+Param/Output |
+BOOLEAN |
+BYTE |
+SHORT |
+INT |
+LONG |
+FLOAT |
+DOUBLE |
+DATE |
+TIMESTAMP |
+STRING |
+DECIMAL |
+NULL |
+BINARY |
+CALENDAR |
+ARRAY |
+MAP |
+STRUCT |
+UDT |
+
+
| Upper |
`upper`, `ucase` |
String uppercase operator |
@@ -15490,32 +15657,6 @@ are limited.
|
-| Expression |
-SQL Functions(s) |
-Description |
-Notes |
-Context |
-Param/Output |
-BOOLEAN |
-BYTE |
-SHORT |
-INT |
-LONG |
-FLOAT |
-DOUBLE |
-DATE |
-TIMESTAMP |
-STRING |
-DECIMAL |
-NULL |
-BINARY |
-CALENDAR |
-ARRAY |
-MAP |
-STRUCT |
-UDT |
-
-
| WindowExpression |
|
Calculates a return value for every input row of a table based on a group (or "window") of rows |
diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py
index 27c5a77a73f..15865906ba4 100644
--- a/integration_tests/src/main/python/date_time_test.py
+++ b/integration_tests/src/main/python/date_time_test.py
@@ -438,3 +438,51 @@ def do_join_cast(spark):
.withColumnRenamed("filename", "r_filename")
return left.join(right, left.monthly_reporting_period == right.r_monthly_reporting_period, how='inner')
assert_gpu_and_cpu_are_equal_collect(do_join_cast)
+
+# (-62135510400, 253402214400) is the range of seconds that can be represented by timestamp_seconds
+# considering the influence of time zone.
+ts_float_gen = SetValuesGen(FloatType(), [0.0, -0.0, 1.0, -1.0, 1.234567, -1.234567, 16777215.0, float('inf'), float('-inf'), float('nan')])
+seconds_gens = [LongGen(min_val=-62135510400, max_val=253402214400), IntegerGen(), ShortGen(), ByteGen(),
+ DoubleGen(min_exp=0, max_exp=32), ts_float_gen, DecimalGen(16, 6), DecimalGen(13, 3), DecimalGen(10, 0), DecimalGen(7, -3), DecimalGen(6, 6)]
+@pytest.mark.parametrize('data_gen', seconds_gens, ids=idfn)
+def test_timestamp_seconds(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_seconds(a)"))
+
+def test_timestamp_seconds_long_overflow():
+ assert_gpu_and_cpu_error(
+ lambda spark : unary_op_df(spark, long_gen).selectExpr("timestamp_seconds(a)").collect(),
+ conf={},
+ error_message='long overflow')
+
+@pytest.mark.parametrize('data_gen', [DecimalGen(7, 7), DecimalGen(20, 7)], ids=idfn)
+def test_timestamp_seconds_rounding_necessary(data_gen):
+ assert_gpu_and_cpu_error(
+ lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_seconds(a)").collect(),
+ conf={},
+ error_message='Rounding necessary')
+
+@pytest.mark.parametrize('data_gen', [DecimalGen(19, 6), DecimalGen(20, 6)], ids=idfn)
+def test_timestamp_seconds_decimal_overflow(data_gen):
+ assert_gpu_and_cpu_error(
+ lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_seconds(a)").collect(),
+ conf={},
+ error_message='Overflow')
+
+millis_gens = [LongGen(min_val=-62135510400000, max_val=253402214400000), IntegerGen(), ShortGen(), ByteGen()]
+@pytest.mark.parametrize('data_gen', millis_gens, ids=idfn)
+def test_timestamp_millis(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_millis(a)"))
+
+def test_timestamp_millis_long_overflow():
+ assert_gpu_and_cpu_error(
+ lambda spark : unary_op_df(spark, long_gen).selectExpr("timestamp_millis(a)").collect(),
+ conf={},
+ error_message='long overflow')
+
+micros_gens = [LongGen(min_val=-62135510400000000, max_val=253402214400000000), IntegerGen(), ShortGen(), ByteGen()]
+@pytest.mark.parametrize('data_gen', micros_gens, ids=idfn)
+def test_timestamp_micros(data_gen):
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_micros(a)"))
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 709abacfcae..c4a52eee08f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1124,6 +1124,30 @@ object GpuOverrides extends Logging {
(a, conf, p, r) => new UnaryExprMeta[DayOfYear](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuDayOfYear(child)
}),
+ expr[SecondsToTimestamp](
+ "Converts the number of seconds from unix epoch to a timestamp",
+ ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
+ TypeSig.gpuNumeric, TypeSig.cpuNumeric),
+ (a, conf, p, r) => new UnaryExprMeta[SecondsToTimestamp](a, conf, p, r) {
+ override def convertToGpu(child: Expression): GpuExpression =
+ GpuSecondsToTimestamp(child)
+ }),
+ expr[MillisToTimestamp](
+ "Converts the number of milliseconds from unix epoch to a timestamp",
+ ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
+ TypeSig.integral, TypeSig.integral),
+ (a, conf, p, r) => new UnaryExprMeta[MillisToTimestamp](a, conf, p, r) {
+ override def convertToGpu(child: Expression): GpuExpression =
+ GpuMillisToTimestamp(child)
+ }),
+ expr[MicrosToTimestamp](
+ "Converts the number of microseconds from unix epoch to a timestamp",
+ ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP,
+ TypeSig.integral, TypeSig.integral),
+ (a, conf, p, r) => new UnaryExprMeta[MicrosToTimestamp](a, conf, p, r) {
+ override def convertToGpu(child: Expression): GpuExpression =
+ GpuMicrosToTimestamp(child)
+ }),
expr[Acos](
"Inverse cosine",
ExprChecks.mathUnaryWithAst,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index e4d1069a1d1..4d9daacb575 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -671,7 +671,7 @@ object TypeSig {
val integral: TypeSig = BYTE + SHORT + INT + LONG
/**
- * All numeric types fp + integral + DECIMAL_64
+ * All numeric types fp + integral + DECIMAL_128
*/
val gpuNumeric: TypeSig = integral + fp + DECIMAL_128
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
index f845e5458c6..1591e15a863 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -20,13 +20,14 @@ import java.time.ZoneId
import java.util.concurrent.TimeUnit
import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DType, RegexProgram, Scalar}
-import com.nvidia.spark.rapids.{BinaryExprMeta, BoolUtils, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuExpressionsUtils, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
+import com.nvidia.spark.rapids.{BinaryExprMeta, BoolUtils, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuCast, GpuColumnVector, GpuExpression, GpuExpressionsUtils, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.Arm._
import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.ShimBinaryExpression
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -390,6 +391,176 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpressi
}
}
+trait GpuNumberToTimestampUnaryExpression extends GpuUnaryExpression {
+
+ override def dataType: DataType = TimestampType
+ override def outputTypeOverride: DType = DType.TIMESTAMP_MICROSECONDS
+
+ /**
+ * Test whether input * multiplier would overflow a Long. Math.multiplyExact throws an
+ * ArithmeticException ("long overflow") when the product does not fit in a Long.
+ */
+ def checkLongMultiplicationOverflow(input: ColumnVector, multiplier: Long): Unit = {
+ withResource(input.max()) { maxValue =>
+ if (maxValue.isValid) {
+ Math.multiplyExact(maxValue.getLong, multiplier)
+ }
+ }
+ withResource(input.min()) { minValue =>
+ if (minValue.isValid) {
+ Math.multiplyExact(minValue.getLong, multiplier)
+ }
+ }
+ }
+
+ protected val convertTo : GpuColumnVector => ColumnVector
+
+ override def doColumnar(input: GpuColumnVector): ColumnVector = {
+ convertTo(input)
+ }
+}
+
+case class GpuSecondsToTimestamp(child: Expression) extends GpuNumberToTimestampUnaryExpression {
+
+ override def nullable: Boolean = child.dataType match {
+ case _: FloatType | _: DoubleType => true
+ case _ => child.nullable
+ }
+
+ private def checkRoundingNecessary(input: ColumnVector, dt: DecimalType): Unit = {
+ // SecondsToTimestamp supports decimals with a scale of 6 or less, which can be represented
+ // as microseconds. For a larger scale, an exception is thrown if any value needs rounding.
+ val decimalTypeSupported = DType.create(DType.DTypeEnum.DECIMAL128, -6)
+ if (dt.scale > 6) {
+ // Match the behavior of `BigDecimal.longValueExact()`: if every decimal is equal to its
+ // value cast to decimal128 with scale 6, then no rounding is necessary.
+ val decimalTypeAllScale = DType.create(DType.DTypeEnum.DECIMAL128, -dt.scale)
+ val castedAllScale = withResource(input.castTo(decimalTypeSupported)) { casted =>
+ casted.castTo(decimalTypeAllScale)
+ }
+ val isEqual = withResource(castedAllScale) { _ =>
+ withResource(input.castTo(decimalTypeAllScale)) { original =>
+ original.equalTo(castedAllScale)
+ }
+ }
+ val roundUnnecessity = withResource(isEqual) { _ =>
+ withResource(isEqual.all()) { all =>
+ all.isValid && all.getBoolean()
+ }
+ }
+ if (!roundUnnecessity) {
+ throw new ArithmeticException("Rounding necessary")
+ }
+ }
+ }
+
+ @transient
+ protected lazy val convertTo: GpuColumnVector => ColumnVector = child.dataType match {
+ case LongType =>
+ (input: GpuColumnVector) => {
+ checkLongMultiplicationOverflow(input.getBase, DateTimeConstants.MICROS_PER_SECOND)
+ val mul = withResource(Scalar.fromLong(DateTimeConstants.MICROS_PER_SECOND)) { scalar =>
+ input.getBase.mul(scalar)
+ }
+ withResource(mul) { _ =>
+ mul.asTimestampMicroseconds()
+ }
+ }
+ case DoubleType | FloatType =>
+ (input: GpuColumnVector) => {
+ GpuCast.doCast(input.getBase, input.dataType, TimestampType, false, false, false)
+ }
+ case dt: DecimalType =>
+ (input: GpuColumnVector) => {
+ checkRoundingNecessary(input.getBase, dt)
+ // Cast to decimal128 to avoid overflow; a scale of 6 is enough after the rounding check.
+ val decimalTypeSupported = DType.create(DType.DTypeEnum.DECIMAL128, -6)
+ val mul = withResource(input.getBase.castTo(decimalTypeSupported)) { decimal =>
+ withResource(Scalar.fromLong(DateTimeConstants.MICROS_PER_SECOND)) { scalar =>
+ decimal.mul(scalar, decimalTypeSupported)
+ }
+ }
+ // Match the behavior of `BigDecimal.longValueExact()`:
+ closeOnExcept(mul) { _ =>
+ val greaterThanMax = withResource(Scalar.fromLong(Long.MaxValue)) { longMax =>
+ mul.greaterThan(longMax)
+ }
+ val largerThanLongMax = withResource(greaterThanMax) { greaterThanMax =>
+ withResource(greaterThanMax.any()) { any =>
+ any.isValid && any.getBoolean()
+ }
+ }
+ lazy val smallerThanLongMin: Boolean = {
+ val lessThanMin = withResource(Scalar.fromLong(Long.MinValue)) { longMin =>
+ mul.lessThan(longMin)
+ }
+ withResource(lessThanMin) { lessThanMin =>
+ withResource(lessThanMin.any()) { any =>
+ any.isValid && any.getBoolean()
+ }
+ }
+ }
+ if (largerThanLongMax || smallerThanLongMin) {
+ throw new java.lang.ArithmeticException("Overflow")
+ }
+ }
+ val longs = withResource(mul) { _ =>
+ mul.castTo(DType.INT64)
+ }
+ withResource(longs) { _ =>
+ longs.asTimestampMicroseconds()
+ }
+ }
+ case IntegerType | ShortType | ByteType =>
+ (input: GpuColumnVector) =>
+ withResource(input.getBase.castTo(DType.INT64)) { longs =>
+ // Not possible to overflow for Int, Short and Byte
+ longs.asTimestampSeconds()
+ }
+ case _ =>
+ throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " +
+ s"for SecondsToTimestamp ")
+ }
+}
+
+case class GpuMillisToTimestamp(child: Expression) extends GpuNumberToTimestampUnaryExpression {
+ protected lazy val convertTo: GpuColumnVector => ColumnVector = child.dataType match {
+ case LongType =>
+ (input: GpuColumnVector) => {
+ checkLongMultiplicationOverflow(input.getBase, DateTimeConstants.MICROS_PER_MILLIS)
+ input.getBase.asTimestampMilliseconds()
+ }
+ case IntegerType | ShortType | ByteType =>
+ (input: GpuColumnVector) => {
+ withResource(input.getBase.castTo(DType.INT64)) { longs =>
+ checkLongMultiplicationOverflow(longs, DateTimeConstants.MICROS_PER_MILLIS)
+ longs.asTimestampMilliseconds()
+ }
+ }
+ case _ =>
+ throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " +
+ s"for MillisToTimestamp ")
+ }
+}
+
+case class GpuMicrosToTimestamp(child: Expression) extends GpuNumberToTimestampUnaryExpression {
+ protected lazy val convertTo: GpuColumnVector => ColumnVector = child.dataType match {
+ case LongType =>
+ (input: GpuColumnVector) => {
+ input.getBase.asTimestampMicroseconds()
+ }
+ case IntegerType | ShortType | ByteType =>
+ (input: GpuColumnVector) => {
+ withResource(input.getBase.castTo(DType.INT64)) { longs =>
+ longs.asTimestampMicroseconds()
+ }
+ }
+ case _ =>
+ throw new UnsupportedOperationException(s"Unsupport type ${child.dataType} " +
+ s"for MicrosToTimestamp ")
+ }
+}
+
sealed trait TimeParserPolicy extends Serializable
object LegacyTimeParserPolicy extends TimeParserPolicy
object ExceptionTimeParserPolicy extends TimeParserPolicy
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index 68eb6c4227c..5c4138611ab 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -1294,6 +1294,66 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite {
).toDF("dates")
}
+ def doubleTimestampSecondsDf(session: SparkSession): DataFrame = {
+ import session.sqlContext.implicits._
+ // some cases out of the range (-62135510400, 253402214400), not covered by integration tests
+ Seq[Double](
+ 253402214400.000001d,
+ 269999999999.999999d,
+ -62135510400.000001d,
+ -79999999999.999999d
+ ).toDF("doubles")
+ }
+
+ def decimalTimestampSecondsDf(session: SparkSession): DataFrame = {
+ import session.sqlContext.implicits._
+ Seq[BigDecimal](
+ BigDecimal("253402214400.000001"),
+ BigDecimal("269999999999.999999"),
+ BigDecimal("-62135510400.000001"),
+ BigDecimal("-79999999999.999999")
+ ).toDF("decimals")
+ }
+
+ def longTimestampSecondsDf(session: SparkSession): DataFrame = {
+ import session.sqlContext.implicits._
+ Seq[java.lang.Long](
+ 253402214401L,
+ 269999999999L,
+ -62135510401L,
+ -79999999999L
+ ).toDF("longs")
+ }
+
+ def longTimestampMillisDf(session: SparkSession): DataFrame = {
+ import session.sqlContext.implicits._
+ Seq[java.lang.Long](
+ 253402214401000L,
+ 269999999999999L,
+ -62135510401000L,
+ -79999999999999L
+ ).toDF("longs")
+ }
+
+ def longTimestampMicrosDf(session: SparkSession): DataFrame = {
+ import session.sqlContext.implicits._
+ Seq[java.lang.Long](
+ 253402214401000000L,
+ 269999999999999999L,
+ -62135510401000000L,
+ -79999999999999999L,
+ Long.MaxValue
+ ).toDF("longs")
+ }
+
+ def longTimestampMicrosLongOverflowDf(session: SparkSession): DataFrame = {
+ import session.sqlContext.implicits._
+ Seq[java.lang.Long](
+ Long.MinValue,
+ -9223183700000000000L
+ ).toDF("longs")
+ }
+
def datesPostEpochDf(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq(
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TimeOperatorsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TimeOperatorsSuite.scala
index 09958a2c6b4..18eae5bf70b 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/TimeOperatorsSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/TimeOperatorsSuite.scala
@@ -16,6 +16,8 @@
package com.nvidia.spark.rapids
+import java.lang.RuntimeException
+
import org.apache.spark.SparkConf
import org.apache.spark.sql.functions._
@@ -34,4 +36,36 @@ class TimeOperatorsSuite extends SparkQueryCompareTestSuite {
frame => frame.select(from_unixtime(col("dates"),"dd/LL/yy HH:mm:ss.SSSSSS"))
}
+ // some cases for timestamp_seconds/millis/micros not covered by integration tests
+ testSparkResultsAreEqual(
+ "Test timestamp_seconds from Large Double type", doubleTimestampSecondsDf) {
+ frame => frame.select(timestamp_seconds(col("doubles")))
+ }
+
+ testSparkResultsAreEqual(
+ "Test timestamp_seconds from Large Decimal type", decimalTimestampSecondsDf) {
+ frame => frame.select(timestamp_seconds(col("decimals")))
+ }
+
+ testSparkResultsAreEqual(
+ "Test timestamp_seconds from large Long type", longTimestampSecondsDf) {
+ frame => frame.select(timestamp_seconds(col("longs")))
+ }
+
+ testSparkResultsAreEqual(
+ "Test timestamp_millis from large Long type", longTimestampMillisDf) {
+ frame => frame.selectExpr("timestamp_millis(longs)")
+ }
+
+ testSparkResultsAreEqual(
+ "Test timestamp_micros from large Long type", longTimestampMicrosDf) {
+ frame => frame.selectExpr("timestamp_micros(longs)")
+ }
+
+ testBothCpuGpuExpectedException[RuntimeException](
+ "Test timestamp_micros from long near Long.minValue: long overflow",
+ e => e.getMessage.contains("ArithmeticException"),
+ longTimestampMicrosLongOverflowDf) {
+ frame => frame.selectExpr("timestamp_micros(longs)")
+ }
}
diff --git a/tools/generated_files/operatorsScore.csv b/tools/generated_files/operatorsScore.csv
index 502d36f3ff9..532ec2d9b02 100644
--- a/tools/generated_files/operatorsScore.csv
+++ b/tools/generated_files/operatorsScore.csv
@@ -163,6 +163,8 @@ MapKeys,4
MapValues,4
Max,4
Md5,4
+MicrosToTimestamp,4
+MillisToTimestamp,4
Min,4
Minute,4
MonotonicallyIncreasingID,4
@@ -201,6 +203,7 @@ RowNumber,4
ScalaUDF,4
ScalarSubquery,4
Second,4
+SecondsToTimestamp,4
Sequence,4
ShiftLeft,4
ShiftRight,4
diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv
index 09331b4f985..5bc63ac003d 100644
--- a/tools/generated_files/supportedExprs.csv
+++ b/tools/generated_files/supportedExprs.csv
@@ -337,6 +337,10 @@ MapValues,S,`map_values`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,
MapValues,S,`map_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA
Md5,S,`md5`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA
Md5,S,`md5`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA
+MicrosToTimestamp,S,`timestamp_micros`,None,project,input,NA,S,S,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+MicrosToTimestamp,S,`timestamp_micros`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA
+MillisToTimestamp,S,`timestamp_millis`,None,project,input,NA,S,S,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+MillisToTimestamp,S,`timestamp_millis`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA
Minute,S,`minute`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA
Minute,S,`minute`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
MonotonicallyIncreasingID,S,`monotonically_increasing_id`,None,project,result,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
@@ -437,6 +441,8 @@ ScalaUDF,S, ,None,project,param,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,NS
ScalaUDF,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,NS
Second,S,`second`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA
Second,S,`second`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
+SecondsToTimestamp,S,`timestamp_seconds`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA
+SecondsToTimestamp,S,`timestamp_seconds`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA
Sequence,S,`sequence`,None,project,start,NA,S,S,S,S,NA,NA,NS,NS,NA,NA,NA,NA,NA,NA,NA,NA,NA
Sequence,S,`sequence`,None,project,stop,NA,S,S,S,S,NA,NA,NS,NS,NA,NA,NA,NA,NA,NA,NA,NA,NA
Sequence,S,`sequence`,None,project,step,NA,S,S,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NS,NA,NA,NA,NA