diff --git a/docs/additional-functionality/advanced_configs.md b/docs/additional-functionality/advanced_configs.md index 298fbbfae61..313ccb4b466 100644 --- a/docs/additional-functionality/advanced_configs.md +++ b/docs/additional-functionality/advanced_configs.md @@ -282,6 +282,8 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.MapKeys|`map_keys`|Returns an unordered array containing the keys of the map|true|None| spark.rapids.sql.expression.MapValues|`map_values`|Returns an unordered array containing the values of the map|true|None| spark.rapids.sql.expression.Md5|`md5`|MD5 hash operator|true|None| +spark.rapids.sql.expression.MicrosToTimestamp|`timestamp_micros`|Converts the number of microseconds from unix epoch to a timestamp|true|None| +spark.rapids.sql.expression.MillisToTimestamp|`timestamp_millis`|Converts the number of milliseconds from unix epoch to a timestamp|true|None| spark.rapids.sql.expression.Minute|`minute`|Returns the minute component of the string/timestamp|true|None| spark.rapids.sql.expression.MonotonicallyIncreasingID|`monotonically_increasing_id`|Returns monotonically increasing 64-bit integers|true|None| spark.rapids.sql.expression.Month|`month`|Returns the month from a date or timestamp|true|None| @@ -316,6 +318,7 @@ Name | SQL Function(s) | Description | Default Value | Notes spark.rapids.sql.expression.RowNumber|`row_number`|Window function that returns the index for the row within the aggregation window|true|None| spark.rapids.sql.expression.ScalaUDF| |User Defined Function, the UDF can choose to implement a RAPIDS accelerated interface to get better performance.|true|None| spark.rapids.sql.expression.Second|`second`|Returns the second component of the string/timestamp|true|None| +spark.rapids.sql.expression.SecondsToTimestamp|`timestamp_seconds`|Converts the number of seconds from unix epoch to a timestamp|true|None| spark.rapids.sql.expression.Sequence|`sequence`|Sequence|true|None| spark.rapids.sql.expression.ShiftLeft|`shiftleft`|Bitwise shift left (<<)|true|None| spark.rapids.sql.expression.ShiftRight|`shiftright`|Bitwise shift right (>>)|true|None| diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 4a3de7c8070..3b3604b3613 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -9654,6 +9654,100 @@ are limited. +MicrosToTimestamp +`timestamp_micros` +Converts the number of microseconds from unix epoch to a timestamp +None +project +input + +S +S +S +S + + + + + + + + + + + + + + + +result + + + + + + + + +PS
UTC is only supported TZ for TIMESTAMP
+ + + + + + + + + + + +MillisToTimestamp +`timestamp_millis` +Converts the number of milliseconds from unix epoch to a timestamp +None +project +input + +S +S +S +S + + + + + + + + + + + + + + + +result + + + + + + + + +PS
UTC is only supported TZ for TIMESTAMP
+ + + + + + + + + + + Minute `minute` Returns the minute component of the string/timestamp @@ -9906,6 +10000,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Murmur3Hash `hash` Murmur3 hash operator @@ -9953,32 +10073,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - NaNvl `nanvl` Evaluates to `left` iff left is not NaN, `right` otherwise @@ -10278,6 +10372,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Or `or` Logical OR @@ -10410,32 +10530,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PercentRank `percent_rank` Window function that returns the percent rank value within the aggregation window @@ -10730,6 +10824,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + PreciseTimestampConversion Expression used internally to convert the TimestampType to Long and back without losing precision, i.e. in microseconds. Used in time windowing @@ -10777,32 +10897,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - PromotePrecision PromotePrecision before arithmetic operations between DecimalType data @@ -12088,12 +12182,59 @@ are limited. -Sequence -`sequence` -Sequence -None -project -start +SecondsToTimestamp +`timestamp_seconds` +Converts the number of seconds from unix epoch to a timestamp +None +project +input + +S +S +S +S +S +S + + + +S + + + + + + + + + +result + + + + + + + + +PS
UTC is only supported TZ for TIMESTAMP
+ + + + + + + + + + + +Sequence +`sequence` +Sequence +None +project +start S S @@ -12102,7 +12243,7 @@ are limited. NS -NS +NS @@ -12313,6 +12454,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + ShiftRightUnsigned `shiftrightunsigned` Bitwise unsigned shift right (>>>) @@ -12381,32 +12548,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Signum `sign`, `signum` Returns -1.0, 0.0 or 1.0 as expr is negative, 0 or positive @@ -12681,6 +12822,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + SortArray `sort_array` Returns a sorted array with the input array and the ascending / descending order @@ -12749,32 +12916,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - SortOrder Sort order @@ -13074,6 +13215,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringInstr `instr` Instr string operator @@ -13142,32 +13309,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringLPad `lpad` Pad a string on the left @@ -13435,6 +13576,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringRepeat `repeat` StringRepeat operator that repeats the given strings with numbers of times given by repeatTimes @@ -13503,32 +13670,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringReplace `replace` StringReplace operator @@ -13796,6 +13937,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + StringTranslate `translate` StringTranslate operator @@ -13885,32 +14052,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - StringTrim `trim` StringTrim operator @@ -14204,6 +14345,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + SubstringIndex `substring_index` substring_index operator @@ -14293,32 +14460,6 @@ are limited. 
-Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - Subtract `-` Subtraction @@ -14631,6 +14772,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + TimeAdd Adds interval to timestamp @@ -14699,32 +14866,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - ToDegrees `degrees` Converts radians to degrees @@ -15023,6 +15164,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + UnaryMinus `negative` Negate a numeric value @@ -15113,32 +15280,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - UnaryPositive `positive` A numeric value with a + in front of it @@ -15396,6 +15537,32 @@ are limited. +Expression +SQL Functions(s) +Description +Notes +Context +Param/Output +BOOLEAN +BYTE +SHORT +INT +LONG +FLOAT +DOUBLE +DATE +TIMESTAMP +STRING +DECIMAL +NULL +BINARY +CALENDAR +ARRAY +MAP +STRUCT +UDT + + Upper `upper`, `ucase` String uppercase operator @@ -15490,32 +15657,6 @@ are limited. -Expression -SQL Functions(s) -Description -Notes -Context -Param/Output -BOOLEAN -BYTE -SHORT -INT -LONG -FLOAT -DOUBLE -DATE -TIMESTAMP -STRING -DECIMAL -NULL -BINARY -CALENDAR -ARRAY -MAP -STRUCT -UDT - - WindowExpression Calculates a return value for every input row of a table based on a group (or "window") of rows diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index 27c5a77a73f..15865906ba4 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -438,3 +438,51 @@ def do_join_cast(spark): .withColumnRenamed("filename", "r_filename") return left.join(right, left.monthly_reporting_period == right.r_monthly_reporting_period, how='inner') assert_gpu_and_cpu_are_equal_collect(do_join_cast) + +# (-62135510400, 253402214400) is the range of seconds that can be represented by timestamp_seconds +# considering the influence of time zone. 
+ts_float_gen = SetValuesGen(FloatType(), [0.0, -0.0, 1.0, -1.0, 1.234567, -1.234567, 16777215.0, float('inf'), float('-inf'), float('nan')]) +seconds_gens = [LongGen(min_val=-62135510400, max_val=253402214400), IntegerGen(), ShortGen(), ByteGen(), + DoubleGen(min_exp=0, max_exp=32), ts_float_gen, DecimalGen(16, 6), DecimalGen(13, 3), DecimalGen(10, 0), DecimalGen(7, -3), DecimalGen(6, 6)] +@pytest.mark.parametrize('data_gen', seconds_gens, ids=idfn) +def test_timestamp_seconds(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_seconds(a)")) + +def test_timestamp_seconds_long_overflow(): + assert_gpu_and_cpu_error( + lambda spark : unary_op_df(spark, long_gen).selectExpr("timestamp_seconds(a)").collect(), + conf={}, + error_message='long overflow') + +@pytest.mark.parametrize('data_gen', [DecimalGen(7, 7), DecimalGen(20, 7)], ids=idfn) +def test_timestamp_seconds_rounding_necessary(data_gen): + assert_gpu_and_cpu_error( + lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_seconds(a)").collect(), + conf={}, + error_message='Rounding necessary') + +@pytest.mark.parametrize('data_gen', [DecimalGen(19, 6), DecimalGen(20, 6)], ids=idfn) +def test_timestamp_seconds_decimal_overflow(data_gen): + assert_gpu_and_cpu_error( + lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_seconds(a)").collect(), + conf={}, + error_message='Overflow') + +millis_gens = [LongGen(min_val=-62135510400000, max_val=253402214400000), IntegerGen(), ShortGen(), ByteGen()] +@pytest.mark.parametrize('data_gen', millis_gens, ids=idfn) +def test_timestamp_millis(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_millis(a)")) + +def test_timestamp_millis_long_overflow(): + assert_gpu_and_cpu_error( + lambda spark : unary_op_df(spark, long_gen).selectExpr("timestamp_millis(a)").collect(), + conf={}, + error_message='long overflow') + +micros_gens = [LongGen(min_val=-62135510400000000, max_val=253402214400000000), IntegerGen(), ShortGen(), ByteGen()] +@pytest.mark.parametrize('data_gen', micros_gens, ids=idfn) +def test_timestamp_micros(data_gen): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).selectExpr("timestamp_micros(a)")) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 709abacfcae..c4a52eee08f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -1124,6 +1124,30 @@ object GpuOverrides extends Logging { (a, conf, p, r) => new UnaryExprMeta[DayOfYear](a, conf, p, r) { override def convertToGpu(child: Expression): GpuExpression = GpuDayOfYear(child) }), + expr[SecondsToTimestamp]( + "Converts the number of seconds from unix epoch to a timestamp", + ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + TypeSig.gpuNumeric, TypeSig.cpuNumeric), + (a, conf, p, r) => new UnaryExprMeta[SecondsToTimestamp](a, conf, p, r) { + override def convertToGpu(child: Expression): GpuExpression = + GpuSecondsToTimestamp(child) + }), + expr[MillisToTimestamp]( + "Converts the number of milliseconds from unix epoch to a timestamp", + ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + TypeSig.integral, TypeSig.integral), + (a, conf, p, r) => new UnaryExprMeta[MillisToTimestamp](a, conf, p, r) { + override def 
convertToGpu(child: Expression): GpuExpression = + GpuMillisToTimestamp(child) + }), + expr[MicrosToTimestamp]( + "Converts the number of microseconds from unix epoch to a timestamp", + ExprChecks.unaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + TypeSig.integral, TypeSig.integral), + (a, conf, p, r) => new UnaryExprMeta[MicrosToTimestamp](a, conf, p, r) { + override def convertToGpu(child: Expression): GpuExpression = + GpuMicrosToTimestamp(child) + }), expr[Acos]( "Inverse cosine", ExprChecks.mathUnaryWithAst, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index e4d1069a1d1..4d9daacb575 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -671,7 +671,7 @@ object TypeSig { val integral: TypeSig = BYTE + SHORT + INT + LONG /** - * All numeric types fp + integral + DECIMAL_64 + * All numeric types fp + integral + DECIMAL_128 */ val gpuNumeric: TypeSig = integral + fp + DECIMAL_128 diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index f845e5458c6..1591e15a863 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -20,13 +20,14 @@ import java.time.ZoneId import java.util.concurrent.TimeUnit import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DType, RegexProgram, Scalar} -import com.nvidia.spark.rapids.{BinaryExprMeta, BoolUtils, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuExpressionsUtils, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.{BinaryExprMeta, BoolUtils, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuCast, GpuColumnVector, GpuExpression, GpuExpressionsUtils, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta} import com.nvidia.spark.rapids.Arm._ import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.shims.ShimBinaryExpression import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, FromUTCTimestamp, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.util.DateTimeConstants import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -390,6 +391,176 @@ abstract class UnixTimeExprMeta[A <: BinaryExpression with TimeZoneAwareExpressi } } +trait GpuNumberToTimestampUnaryExpression extends GpuUnaryExpression { + + override def dataType: DataType = TimestampType + override def outputTypeOverride: DType = DType.TIMESTAMP_MICROSECONDS + + /** + * Test whether input * multiplier would overflow a Long. Math.multiplyExact throws + * an ArithmeticException ("long overflow") if the multiplication overflows, so it is + * sufficient to check only the column minimum and maximum. + */ + def checkLongMultiplicationOverflow(input: ColumnVector, multiplier: Long): Unit = { + withResource(input.max()) { maxValue => + if (maxValue.isValid) { + Math.multiplyExact(maxValue.getLong, multiplier) + } + } + withResource(input.min()) { minValue => + if (minValue.isValid) { + Math.multiplyExact(minValue.getLong, multiplier) + } + } + } + + protected val convertTo : GpuColumnVector => ColumnVector + + override def doColumnar(input: GpuColumnVector): ColumnVector = { + convertTo(input) + } +} + +case class GpuSecondsToTimestamp(child: Expression) extends GpuNumberToTimestampUnaryExpression { + + override def nullable: Boolean = child.dataType match { + case _: FloatType | _: DoubleType => true + case _ => child.nullable + } + + private def checkRoundingNecessary(input: ColumnVector, dt: DecimalType): Unit = { + // SecondsToTimestamp supports decimals with a scale of 6 or less, which can be represented + // as microseconds. An exception will be thrown if the scale is more than 6. + val decimalTypeSupported = DType.create(DType.DTypeEnum.DECIMAL128, -6) + if (dt.scale > 6) { + // Match the behavior of `BigDecimal.longValueExact()`: if the decimal equals the value + // cast to decimal128 with scale 6, then no rounding is necessary. + val decimalTypeAllScale = DType.create(DType.DTypeEnum.DECIMAL128, -dt.scale) + val castedAllScale = withResource(input.castTo(decimalTypeSupported)) { casted => + casted.castTo(decimalTypeAllScale) + } + val isEqual = withResource(castedAllScale) { _ => + withResource(input.castTo(decimalTypeAllScale)) { original => + original.equalTo(castedAllScale) + } + } + val roundingUnnecessary = withResource(isEqual) { _ => + withResource(isEqual.all()) { all => + all.isValid && all.getBoolean() + } + } + if (!roundingUnnecessary) { + throw new ArithmeticException("Rounding necessary") + } + } + } + + @transient + protected lazy val convertTo: GpuColumnVector => ColumnVector = child.dataType match { + case LongType => + (input: GpuColumnVector) => { + checkLongMultiplicationOverflow(input.getBase, DateTimeConstants.MICROS_PER_SECOND) + val mul = withResource(Scalar.fromLong(DateTimeConstants.MICROS_PER_SECOND)) { scalar => + input.getBase.mul(scalar) + } + withResource(mul) { _ => + mul.asTimestampMicroseconds() + } + } + case DoubleType | FloatType => + (input: GpuColumnVector) => { + GpuCast.doCast(input.getBase, input.dataType, TimestampType, false, false, false) + } + case dt: DecimalType => + (input: GpuColumnVector) => { + checkRoundingNecessary(input.getBase, dt) + // Cast to decimal128 to avoid overflow; a scale of 6 is enough after the rounding check.
+ val decimalTypeSupported = DType.create(DType.DTypeEnum.DECIMAL128, -6) + val mul = withResource(input.getBase.castTo(decimalTypeSupported)) { decimal => + withResource(Scalar.fromLong(DateTimeConstants.MICROS_PER_SECOND)) { scalar => + decimal.mul(scalar, decimalTypeSupported) + } + } + // Match the behavior of `BigDecimal.longValueExact()`: + closeOnExcept(mul) { _ => + val greaterThanMax = withResource(Scalar.fromLong(Long.MaxValue)) { longMax => + mul.greaterThan(longMax) + } + val largerThanLongMax = withResource(greaterThanMax) { greaterThanMax => + withResource(greaterThanMax.any()) { any => + any.isValid && any.getBoolean() + } + } + lazy val smallerThanLongMin: Boolean = { + val lessThanMin = withResource(Scalar.fromLong(Long.MinValue)) { longMin => + mul.lessThan(longMin) + } + withResource(lessThanMin) { lessThanMin => + withResource(lessThanMin.any()) { any => + any.isValid && any.getBoolean() + } + } + } + if (largerThanLongMax || smallerThanLongMin) { + throw new java.lang.ArithmeticException("Overflow") + } + } + val longs = withResource(mul) { _ => + mul.castTo(DType.INT64) + } + withResource(longs) { _ => + longs.asTimestampMicroseconds() + } + } + case IntegerType | ShortType | ByteType => + (input: GpuColumnVector) => + withResource(input.getBase.castTo(DType.INT64)) { longs => + // Not possible to overflow for Int, Short and Byte + longs.asTimestampSeconds() + } + case _ => + throw new UnsupportedOperationException(s"Unsupported type ${child.dataType} " + + s"for SecondsToTimestamp ") + } +} + +case class GpuMillisToTimestamp(child: Expression) extends GpuNumberToTimestampUnaryExpression { + protected lazy val convertTo: GpuColumnVector => ColumnVector = child.dataType match { + case LongType => + (input: GpuColumnVector) => { + checkLongMultiplicationOverflow(input.getBase, DateTimeConstants.MICROS_PER_MILLIS) + input.getBase.asTimestampMilliseconds() + } + case IntegerType | ShortType | ByteType => + (input: GpuColumnVector) => { + withResource(input.getBase.castTo(DType.INT64)) { longs => + checkLongMultiplicationOverflow(longs, DateTimeConstants.MICROS_PER_MILLIS) + longs.asTimestampMilliseconds() + } + } + case _ => + throw new UnsupportedOperationException(s"Unsupported type ${child.dataType} " + + s"for MillisToTimestamp ") + } +} + +case class GpuMicrosToTimestamp(child: Expression) extends GpuNumberToTimestampUnaryExpression { + protected lazy val convertTo: GpuColumnVector => ColumnVector = child.dataType match { + case LongType => + (input: GpuColumnVector) => { + input.getBase.asTimestampMicroseconds() + } + case IntegerType | ShortType | ByteType => + (input: GpuColumnVector) => { + withResource(input.getBase.castTo(DType.INT64)) { longs => + longs.asTimestampMicroseconds() + } + } + case _ => + throw new UnsupportedOperationException(s"Unsupported type ${child.dataType} " + + s"for MicrosToTimestamp ") + } +} + sealed trait TimeParserPolicy extends Serializable object LegacyTimeParserPolicy extends TimeParserPolicy object ExceptionTimeParserPolicy extends TimeParserPolicy diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index 68eb6c4227c..5c4138611ab 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -1294,6 +1294,66 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite { ).toDF("dates") } + + def
doubleTimestampSecondsDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + // some cases out of the range (-62135510400, 253402214400), which are not covered by IT + Seq[Double]( + 253402214400.000001d, + 269999999999.999999d, + -62135510400.000001d, + -79999999999.999999d + ).toDF("doubles") + } + + def decimalTimestampSecondsDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[BigDecimal]( + BigDecimal("253402214400.000001"), + BigDecimal("269999999999.999999"), + BigDecimal("-62135510400.000001"), + BigDecimal("-79999999999.999999") + ).toDF("decimals") + } + + def longTimestampSecondsDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[java.lang.Long]( + 253402214401L, + 269999999999L, + -62135510401L, + -79999999999L + ).toDF("longs") + } + + def longTimestampMillisDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[java.lang.Long]( + 253402214401000L, + 269999999999999L, + -62135510401000L, + -79999999999999L + ).toDF("longs") + } + + def longTimestampMicrosDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[java.lang.Long]( + 253402214401000000L, + 269999999999999999L, + -62135510401000000L, + -79999999999999999L, + Long.MaxValue + ).toDF("longs") + } + + def longTimestampMicrosLongOverflowDf(session: SparkSession): DataFrame = { + import session.sqlContext.implicits._ + Seq[java.lang.Long]( + Long.MinValue, + -9223183700000000000L + ).toDF("longs") + } + def datesPostEpochDf(session: SparkSession): DataFrame = { import session.sqlContext.implicits._ Seq( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TimeOperatorsSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TimeOperatorsSuite.scala index 09958a2c6b4..18eae5bf70b 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/TimeOperatorsSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/TimeOperatorsSuite.scala @@ -16,6 +16,8 @@ package com.nvidia.spark.rapids +import java.lang.RuntimeException + import org.apache.spark.SparkConf import org.apache.spark.sql.functions._ @@ -34,4 +36,36 @@ class TimeOperatorsSuite extends SparkQueryCompareTestSuite { frame => frame.select(from_unixtime(col("dates"),"dd/LL/yy HH:mm:ss.SSSSSS")) } + // some cases for timestamp_seconds not covered by integration tests + testSparkResultsAreEqual( + "Test timestamp_seconds from Large Double type", doubleTimestampSecondsDf) { + frame => frame.select(timestamp_seconds(col("doubles"))) + } + + testSparkResultsAreEqual( + "Test timestamp_seconds from Large Decimal type", decimalTimestampSecondsDf) { + frame => frame.select(timestamp_seconds(col("decimals"))) + } + + testSparkResultsAreEqual( + "Test timestamp_seconds from large Long type", longTimestampSecondsDf) { + frame => frame.select(timestamp_seconds(col("longs"))) + } + + testSparkResultsAreEqual( + "Test timestamp_millis from large Long type", longTimestampMillisDf) { + frame => frame.selectExpr("timestamp_millis(longs)") + } + + testSparkResultsAreEqual( + "Test timestamp_micros from large Long type", longTimestampMicrosDf) { + frame => frame.selectExpr("timestamp_micros(longs)") + } + + testBothCpuGpuExpectedException[RuntimeException]( + "Test timestamp_micros from long near Long.minValue: long overflow", + e => e.getMessage.contains("ArithmeticException"), + longTimestampMicrosLongOverflowDf) { + frame => frame.selectExpr("timestamp_micros(longs)") + } } diff --git 
a/tools/generated_files/operatorsScore.csv b/tools/generated_files/operatorsScore.csv index 502d36f3ff9..532ec2d9b02 100644 --- a/tools/generated_files/operatorsScore.csv +++ b/tools/generated_files/operatorsScore.csv @@ -163,6 +163,8 @@ MapKeys,4 MapValues,4 Max,4 Md5,4 +MicrosToTimestamp,4 +MillisToTimestamp,4 Min,4 Minute,4 MonotonicallyIncreasingID,4 @@ -201,6 +203,7 @@ RowNumber,4 ScalaUDF,4 ScalarSubquery,4 Second,4 +SecondsToTimestamp,4 Sequence,4 ShiftLeft,4 ShiftRight,4 diff --git a/tools/generated_files/supportedExprs.csv b/tools/generated_files/supportedExprs.csv index 09331b4f985..5bc63ac003d 100644 --- a/tools/generated_files/supportedExprs.csv +++ b/tools/generated_files/supportedExprs.csv @@ -337,6 +337,10 @@ MapValues,S,`map_values`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA, MapValues,S,`map_values`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA Md5,S,`md5`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA Md5,S,`md5`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA +MicrosToTimestamp,S,`timestamp_micros`,None,project,input,NA,S,S,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +MicrosToTimestamp,S,`timestamp_micros`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA +MillisToTimestamp,S,`timestamp_millis`,None,project,input,NA,S,S,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +MillisToTimestamp,S,`timestamp_millis`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA Minute,S,`minute`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA Minute,S,`minute`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA MonotonicallyIncreasingID,S,`monotonically_increasing_id`,None,project,result,NA,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA @@ -437,6 +441,8 @@ ScalaUDF,S, ,None,project,param,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,NS ScalaUDF,S, ,None,project,result,S,S,S,S,S,S,S,S,PS,S,S,S,S,S,PS,PS,PS,NS Second,S,`second`,None,project,input,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA Second,S,`second`,None,project,result,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA +SecondsToTimestamp,S,`timestamp_seconds`,None,project,input,NA,S,S,S,S,S,S,NA,NA,NA,S,NA,NA,NA,NA,NA,NA,NA +SecondsToTimestamp,S,`timestamp_seconds`,None,project,result,NA,NA,NA,NA,NA,NA,NA,NA,PS,NA,NA,NA,NA,NA,NA,NA,NA,NA Sequence,S,`sequence`,None,project,start,NA,S,S,S,S,NA,NA,NS,NS,NA,NA,NA,NA,NA,NA,NA,NA,NA Sequence,S,`sequence`,None,project,stop,NA,S,S,S,S,NA,NA,NS,NS,NA,NA,NA,NA,NA,NA,NA,NA,NA Sequence,S,`sequence`,None,project,step,NA,S,S,S,S,NA,NA,NA,NA,NA,NA,NA,NA,NS,NA,NA,NA,NA
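
As a reviewer aid, the sketch below (not part of this change) shows the CPU semantics the three new GPU overrides must match. It assumes a stock Spark 3.1+ spark-shell with no RAPIDS plugin loaded; the input value 1230219000 and the UTC session time zone are illustrative choices.

    // CPU reference behavior for timestamp_seconds / timestamp_millis / timestamp_micros.
    // Runnable in a plain spark-shell; `spark` is the session the shell provides.
    spark.conf.set("spark.sql.session.timeZone", "UTC")
    import spark.implicits._

    Seq(1230219000L).toDF("a").selectExpr(
      "timestamp_seconds(a)",          // 2008-12-25 15:30:00, from seconds since the epoch
      "timestamp_millis(a * 1000)",    // the same instant, from milliseconds
      "timestamp_micros(a * 1000000)"  // the same instant, from microseconds
    ).show(false)

    // timestamp_seconds also accepts fractional input down to microsecond precision.
    Seq(1.5d).toDF("a").selectExpr("timestamp_seconds(a)").show(false)
    // 1970-01-01 00:00:01.5

    // Integral inputs overflow through Math.multiplyExact on the CPU, which is why
    // the GPU path checks the column min/max before multiplying.
    Seq(Long.MaxValue).toDF("a").selectExpr("timestamp_seconds(a)").collect()
    // => java.lang.ArithmeticException: long overflow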
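
The `Rounding necessary` and `Overflow` messages asserted by the new decimal tests come from matching `java.math.BigDecimal.longValueExact()`, which the comments in GpuSecondsToTimestamp cite as the reference behavior. A JVM-only illustration (again not PR code; the literals are arbitrary):

    import java.math.BigDecimal

    // Seconds -> microseconds: a scale-7 decimal has no exact long representation.
    new BigDecimal("1.2345678")
      .multiply(new BigDecimal(1000000))
      .longValueExact()  // throws ArithmeticException: "Rounding necessary"

    // A product beyond Long.MaxValue fails the range check instead.
    new BigDecimal("19999999999999.999999")
      .multiply(new BigDecimal(1000000))
      .longValueExact()  // throws ArithmeticException: "Overflow"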