diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
index 018f59bbfa8..7834bdeaed6 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
@@ -20,6 +20,9 @@ import java.time.LocalDate
 
 import scala.collection.mutable.ListBuffer
 
+import ai.rapids.cudf.{DType, Scalar}
+import com.nvidia.spark.rapids.VersionUtils.isSpark320OrLater
+
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateToDays
 
@@ -57,7 +60,9 @@ object DateUtils {
   val YESTERDAY = "yesterday"
   val TOMORROW = "tomorrow"
 
-  def specialDatesDays: Map[String, Int] = {
+  def specialDatesDays: Map[String, Int] = if (isSpark320OrLater) {
+    Map.empty
+  } else {
     val today = currentDate()
     Map(
       EPOCH -> 0,
@@ -68,7 +73,9 @@ object DateUtils {
     )
   }
 
-  def specialDatesSeconds: Map[String, Long] = {
+  def specialDatesSeconds: Map[String, Long] = if (isSpark320OrLater) {
+    Map.empty
+  } else {
     val today = currentDate()
     val now = DateTimeUtils.currentTimestamp()
     Map(
@@ -80,7 +87,9 @@ object DateUtils {
     )
   }
 
-  def specialDatesMicros: Map[String, Long] = {
+  def specialDatesMicros: Map[String, Long] = if (isSpark320OrLater) {
+    Map.empty
+  } else {
     val today = currentDate()
     val now = DateTimeUtils.currentTimestamp()
     Map(
@@ -92,6 +101,23 @@ object DateUtils {
     )
   }
 
+  def fetchSpecialDates(unit: DType): Map[String, () => Scalar] = unit match {
+    case DType.TIMESTAMP_DAYS =>
+      DateUtils.specialDatesDays.map { case (k, v) =>
+        k -> (() => Scalar.timestampDaysFromInt(v))
+      }
+    case DType.TIMESTAMP_SECONDS =>
+      DateUtils.specialDatesSeconds.map { case (k, v) =>
+        k -> (() => Scalar.timestampFromLong(unit, v))
+      }
+    case DType.TIMESTAMP_MICROSECONDS =>
+      DateUtils.specialDatesMicros.map { case (k, v) =>
+        k -> (() => Scalar.timestampFromLong(unit, v))
+      }
+    case _ =>
+      throw new IllegalArgumentException(s"unsupported DType: $unit")
+  }
+
   def currentDate(): Int = localDateToDays(LocalDate.now())
 
   case class FormatKeywordToReplace(word: String, startIndex: Int, endIndex: Int)
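
Reviewer note: fetchSpecialDates returns deferred builders (Map[String, () => Scalar]) rather than ready-made Scalar values, because cudf scalars are native resources that must be closed exactly once; the builder lets each caller construct the scalar inside its own resource scope. A minimal sketch of the intended consumption, with illustrative names that are not part of this patch:

    val specials: Map[String, () => Scalar] =
      DateUtils.fetchSpecialDates(DType.TIMESTAMP_DAYS)
    specials.foreach { case (name, makeScalar) =>
      withResource(makeScalar()) { scalar =>
        // compare matches for `name` against `scalar` here; the scalar is
        // closed as soon as this block exits
      }
    }

On Spark 3.2.0+ the specialDates* maps above are empty, so fetchSpecialDates returns an empty map and callers skip the replacement pass entirely.
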
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index ac2d79dac0f..9209125c43f 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
 import java.text.SimpleDateFormat
 import java.time.DateTimeException
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
@@ -27,6 +28,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.rapids.GpuToTimestamp.replaceSpecialDates
 import org.apache.spark.sql.rapids.RegexReplace
 import org.apache.spark.sql.types._
 
@@ -788,27 +790,6 @@ object GpuCast extends Arm {
     }
   }
 
-  /**
-   * Replace special date strings such as "now" with timestampDays. This method does not
-   * close the `input` ColumnVector.
-   */
-  def specialDateOr(
-      input: ColumnVector,
-      special: String,
-      value: Int,
-      orColumnVector: ColumnVector): ColumnVector = {
-
-    withResource(orColumnVector) { other =>
-      withResource(Scalar.fromString(special)) { str =>
-        withResource(input.equalTo(str)) { isStr =>
-          withResource(Scalar.timestampDaysFromInt(value)) { date =>
-            isStr.ifElse(date, other)
-          }
-        }
-      }
-    }
-  }
-
   /** This method does not close the `input` ColumnVector. */
   def convertDateOrNull(
       input: ColumnVector,
@@ -884,16 +865,21 @@ object GpuCast extends Arm {
    */
   private def castStringToDate(sanitizedInput: ColumnVector): ColumnVector = {
-    val specialDates = DateUtils.specialDatesDays
-
     // convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd
     val converted = convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d",
       convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM, "%Y-%m",
        convertDateOrNull(sanitizedInput, DATE_REGEX_YYYY, "%Y")))
 
     // handle special dates like "epoch", "now", etc.
-    specialDates.foldLeft(converted)((prev, specialDate) =>
-      specialDateOr(sanitizedInput, specialDate._1, specialDate._2, prev))
+    closeOnExcept(converted) { tsVector =>
+      DateUtils.fetchSpecialDates(DType.TIMESTAMP_DAYS) match {
+        case specialDates if specialDates.nonEmpty =>
+          // `tsVector` will be closed in replaceSpecialDates
+          replaceSpecialDates(sanitizedInput, tsVector, specialDates)
+        case _ =>
+          tsVector
+      }
+    }
   }
 
   private def castStringToDateAnsi(input: ColumnVector, ansiMode: Boolean): ColumnVector = {
@@ -908,27 +894,6 @@
     }
   }
 
-  /**
-   * Replace special date strings such as "now" with timestampMicros. This method does not
-   * close the `input` ColumnVector.
-   */
-  private def specialTimestampOr(
-      input: ColumnVector,
-      special: String,
-      value: Long,
-      orColumnVector: ColumnVector): ColumnVector = {
-
-    withResource(orColumnVector) { other =>
-      withResource(Scalar.fromString(special)) { str =>
-        withResource(input.equalTo(str)) { isStr =>
-          withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, value)) { date =>
-            isStr.ifElse(date, other)
-          }
-        }
-      }
-    }
-  }
-
   /** This method does not close the `input` ColumnVector. */
   private def convertTimestampOrNull(
       input: ColumnVector,
@@ -1009,7 +974,6 @@
     val today = DateUtils.currentDate()
     val todayStr = new SimpleDateFormat("yyyy-MM-dd")
       .format(today * DateUtils.ONE_DAY_SECONDS * 1000L)
-    val specialDates = DateUtils.specialDatesMicros
 
     var sanitizedInput = input.incRefCount()
 
@@ -1027,8 +991,15 @@
         convertTimestampOrNull(sanitizedInput, TIMESTAMP_REGEX_YYYY, "%Y"))))
 
     // handle special dates like "epoch", "now", etc.
-    val finalResult = specialDates.foldLeft(converted)((prev, specialDate) =>
-      specialTimestampOr(sanitizedInput, specialDate._1, specialDate._2, prev))
+    val finalResult = closeOnExcept(converted) { tsVector =>
+      DateUtils.fetchSpecialDates(DType.TIMESTAMP_MICROSECONDS) match {
+        case specialDates if specialDates.nonEmpty =>
+          // `tsVector` will be closed in replaceSpecialDates
+          replaceSpecialDates(sanitizedInput, tsVector, specialDates)
+        case _ =>
+          tsVector
+      }
+    }
 
     if (ansiMode) {
       // When ANSI mode is enabled, we need to throw an exception if any values could not be
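
Reviewer note: closeOnExcept is the counterpart of withResource: it closes its argument only if the body throws. That is what lets `converted` be handed off to replaceSpecialDates, which takes ownership and closes it, without a double-close on the success path. The shape of the pattern, as a sketch with hypothetical helper names (`make`, `consume`) that are not part of this patch:

    // `make()` allocates a ColumnVector; `consume` closes whatever it is given
    val out = closeOnExcept(make()) { col =>
      consume(col) // success: ownership transferred, closeOnExcept does not close
    }              // failure: closeOnExcept closes `col` before rethrowing
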
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala
new file mode 100644
index 00000000000..ec71c1405a1
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+object VersionUtils {
+
+  lazy val isSpark301OrLater: Boolean = cmpSparkVersion(3, 0, 1) >= 0
+
+  lazy val isSpark311OrLater: Boolean = cmpSparkVersion(3, 1, 1) >= 0
+
+  lazy val isSpark320OrLater: Boolean = cmpSparkVersion(3, 2, 0) >= 0
+
+  lazy val isSpark: Boolean = {
+    ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[SparkShimVersion]
+  }
+
+  lazy val isDataBricks: Boolean = {
+    ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[DatabricksShimVersion]
+  }
+
+  lazy val isCloudera: Boolean = {
+    ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[ClouderaShimVersion]
+  }
+
+  lazy val isEMR: Boolean = {
+    ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[EMRShimVersion]
+  }
+
+  def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = {
+    val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion
+    val (sparkMajor, sparkMinor, sparkBugfix) = sparkShimVersion match {
+      case SparkShimVersion(a, b, c) => (a, b, c)
+      case DatabricksShimVersion(a, b, c) => (a, b, c)
+      case ClouderaShimVersion(a, b, c, _) => (a, b, c)
+      case EMRShimVersion(a, b, c) => (a, b, c)
+    }
+    val fullVersion = ((major.toLong * 1000) + minor) * 1000 + bugfix
+    val sparkFullVersion = ((sparkMajor.toLong * 1000) + sparkMinor) * 1000 + sparkBugfix
+    sparkFullVersion.compareTo(fullVersion)
+  }
+}
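
Reviewer note: cmpSparkVersion packs (major, minor, bugfix) into a single long so that ordinary numeric comparison matches version ordering, assuming each component stays below 1000. A worked example of the arithmetic:

    // Spark 3.1.2 running, compared against the 3.2.0 threshold:
    //   sparkFullVersion = ((3L * 1000) + 1) * 1000 + 2 = 3001002
    //   fullVersion      = ((3L * 1000) + 2) * 1000 + 0 = 3002000
    // 3001002.compareTo(3002000) < 0, so isSpark320OrLater is false on 3.1.2
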
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
index 29d1b6bfd85..bf7b31816cf 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -18,8 +18,10 @@ package org.apache.spark.sql.rapids
 
 import java.util.concurrent.TimeUnit
 
+import scala.collection.mutable
+
 import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
-import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
+import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta, ShimLoader}
 import com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException
 import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
@@ -508,20 +510,31 @@ object GpuToTimestamp extends Arm {
   val REMOVE_WHITESPACE_FROM_MONTH_DAY: RegexReplace =
     RegexReplace(raw"(\A\d+)-([ \t]*)(\d+)-([ \t]*)(\d+)", raw"\1-\3-\5")
 
-  def daysScalarSeconds(name: String): Scalar = {
-    Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name))
-  }
-
-  def daysScalarMicros(name: String): Scalar = {
-    Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name))
-  }
-
   def daysEqual(col: ColumnVector, name: String): ColumnVector = {
     withResource(Scalar.fromString(name)) { scalarName =>
       col.equalTo(scalarName)
     }
   }
 
+  /**
+   * Replace special date strings such as "now" with the corresponding timestamp scalars.
+   * This method does not close `stringVector`, but it does close `chronoVector`.
+   */
+  def replaceSpecialDates(
+      stringVector: ColumnVector,
+      chronoVector: ColumnVector,
+      specialDates: Map[String, () => Scalar]): ColumnVector = {
+    specialDates.foldLeft(chronoVector) { case (buffer, (name, scalarBuilder)) =>
+      withResource(buffer) { bufVector =>
+        withResource(daysEqual(stringVector, name)) { isMatch =>
+          withResource(scalarBuilder()) { scalar =>
+            isMatch.ifElse(scalar, bufVector)
+          }
+        }
+      }
+    }
+  }
+
   def isTimestamp(col: ColumnVector, sparkFormat: String, strfFormat: String) : ColumnVector = {
     CORRECTED_COMPATIBLE_FORMATS.get(sparkFormat) match {
       case Some(fmt) =>
@@ -546,50 +559,27 @@ object GpuToTimestamp extends Arm {
       lhs: GpuColumnVector,
       sparkFormat: String,
       strfFormat: String,
-      dtype: DType,
-      daysScalar: String => Scalar,
-      asTimestamp: (ColumnVector, String) => ColumnVector): ColumnVector = {
+      dtype: DType): ColumnVector = {
+
+    // `tsVector` will be closed in replaceSpecialDates
+    val tsVector = withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTs =>
+      withResource(Scalar.fromNull(dtype)) { nullValue =>
+        withResource(lhs.getBase.asTimestamp(dtype, strfFormat)) { tsVec =>
+          isTs.ifElse(tsVec, nullValue)
+        }
+      }
+    }
 
     // in addition to date/timestamp strings, we also need to check for special dates and null
     // values, since anything else is invalid and should throw an error or be converted to null
     // depending on the policy
-    withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTimestamp =>
-      withResource(daysEqual(lhs.getBase, DateUtils.EPOCH)) { isEpoch =>
-        withResource(daysEqual(lhs.getBase, DateUtils.NOW)) { isNow =>
-          withResource(daysEqual(lhs.getBase, DateUtils.TODAY)) { isToday =>
-            withResource(daysEqual(lhs.getBase, DateUtils.YESTERDAY)) { isYesterday =>
-              withResource(daysEqual(lhs.getBase, DateUtils.TOMORROW)) { isTomorrow =>
-                withResource(lhs.getBase.isNull) { _ =>
-                  withResource(Scalar.fromNull(dtype)) { nullValue =>
-                    withResource(asTimestamp(lhs.getBase, strfFormat)) { converted =>
-                      withResource(daysScalar(DateUtils.EPOCH)) { epoch =>
-                        withResource(daysScalar(DateUtils.NOW)) { now =>
-                          withResource(daysScalar(DateUtils.TODAY)) { today =>
-                            withResource(daysScalar(DateUtils.YESTERDAY)) { yesterday =>
-                              withResource(daysScalar(DateUtils.TOMORROW)) { tomorrow =>
-                                withResource(isTomorrow.ifElse(tomorrow, nullValue)) { a =>
-                                  withResource(isYesterday.ifElse(yesterday, a)) { b =>
-                                    withResource(isToday.ifElse(today, b)) { c =>
-                                      withResource(isNow.ifElse(now, c)) { d =>
-                                        withResource(isEpoch.ifElse(epoch, d)) { e =>
-                                          isTimestamp.ifElse(converted, e)
-                                        }
-                                      }
-                                    }
-                                  }
-                                }
-                              }
-                            }
-                          }
-                        }
-                      }
-                    }
-                  }
-                }
-              }
-            }
-          }
-        }
-      }
-    }
+    closeOnExcept(tsVector) { tsVector =>
+      DateUtils.fetchSpecialDates(dtype) match {
+        case specialDates if specialDates.nonEmpty =>
+          // `tsVector` will be closed in replaceSpecialDates
+          replaceSpecialDates(lhs.getBase, tsVector, specialDates)
+        case _ =>
+          tsVector
+      }
+    }
   }
@@ -728,9 +718,7 @@ abstract class GpuToTimestamp
         lhs,
         sparkFormat,
         strfFormat,
-        DType.TIMESTAMP_MICROSECONDS,
-        daysScalarMicros,
-        (col, strfFormat) => col.asTimestampMicroseconds(strfFormat))
+        DType.TIMESTAMP_MICROSECONDS)
       }
     } else { // Timestamp or DateType
       lhs.getBase.asTimestampMicroseconds()
@@ -779,9 +767,7 @@ abstract class GpuToTimestampImproved extends GpuToTimestamp {
         lhs,
         sparkFormat,
         strfFormat,
-        DType.TIMESTAMP_SECONDS,
-        daysScalarSeconds,
-        (col, strfFormat) => col.asTimestampSeconds(strfFormat))
+        DType.TIMESTAMP_SECONDS)
       }
     } else if (lhs.dataType() == DateType){
       lhs.getBase.asTimestampSeconds()
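
Reviewer note: the foldLeft in replaceSpecialDates is what collapses the 19-level withResource pyramid deleted above. Each fold step consumes the previous result vector and produces a new one with one more special date substituted. Expanded by hand for a single step (illustrative only; `epochBuilder` stands for a builder obtained from fetchSpecialDates and is not code from this patch):

    // one fold step, starting from the parsed `chronoVector`
    val next = withResource(chronoVector) { prev =>         // closes prev when done
      withResource(daysEqual(stringVector, "epoch")) { isMatch =>
        withResource(epochBuilder()) { scalar =>
          isMatch.ifElse(scalar, prev)                      // new vector, owned by caller
        }
      }
    }

The vector produced by the last step is the only resource left open, which is why callers treat replaceSpecialDates as taking ownership of `chronoVector`.
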
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala
index 2ec57f29c79..d82b790fdb7 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala
@@ -229,11 +229,14 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterE
         .repartition(2)
         .withColumn("c1", unix_timestamp(col("c0"), "yyyy-MM-dd HH:mm:ss"))
     }
-    val startTimeSeconds = System.currentTimeMillis()/1000L
+    val startTimeSeconds = System.currentTimeMillis() / 1000L
     val cpuNowSeconds = withCpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long]
     val gpuNowSeconds = withGpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long]
-    assert(cpuNowSeconds >= startTimeSeconds)
-    assert(gpuNowSeconds >= startTimeSeconds)
+    // For Spark 3.2+, "now" will NOT be parsed as the current time
+    if (!VersionUtils.isSpark320OrLater) {
+      assert(cpuNowSeconds >= startTimeSeconds)
+      assert(gpuNowSeconds >= startTimeSeconds)
+    }
     // CPU ran first so cannot have a greater value than the GPU run (but could be the same second)
     assert(cpuNowSeconds <= gpuNowSeconds)
   }
@@ -408,4 +411,3 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterE
     )
 }
 
-
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ScalarSubquerySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ScalarSubquerySuite.scala
index 48b04ee1e54..0522e359d66 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/ScalarSubquerySuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ScalarSubquerySuite.scala
@@ -35,7 +35,7 @@ class ScalarSubquerySuite extends SparkQueryCompareTestSuite {
     // In Spark 3.2.0+ canonicalization for ScalarSubquery was fixed. Because it is a bug fix
     // we have fixed it on our end for all versions of Spark, but the canonicalization check
     // only works if both have the fix or both don't have it.
-    skipCanonicalizationCheck = isPriorToSpark320) {
+    skipCanonicalizationCheck = !VersionUtils.isSpark320OrLater) {
     frame => {
       frame.createOrReplaceTempView("table")
       val ret = frame.sparkSession.sql(
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index f1922389ddc..e94fddca8fa 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -1837,20 +1837,16 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
 
   /** most of the AQE tests requires Spark 3.0.1 or later */
   def assumeSpark301orLater: Assertion =
-    assume(cmpSparkVersion(3, 0, 1) >= 0, "Spark version not 3.0.1+")
+    assume(VersionUtils.isSpark301OrLater, "Spark version not 3.0.1+")
 
   def assumeSpark311orLater: Assertion =
-    assume(cmpSparkVersion(3, 1, 1) >= 0, "Spark version not 3.1.1+")
+    assume(VersionUtils.isSpark311OrLater, "Spark version not 3.1.1+")
 
   def assumePriorToSpark320: Assertion =
-    assume(isPriorToSpark320, "Spark version not before 3.2.0")
-
-  def isPriorToSpark320: Boolean = cmpSparkVersion(3, 2, 0) < 0
+    assume(!VersionUtils.isSpark320OrLater, "Spark version not before 3.2.0")
 
   def assumeSpark320orLater: Assertion =
-    assume(isSpark320OrLater, "Spark version not 3.2.0+")
-
-  def isSpark320OrLater: Boolean = cmpSparkVersion(3, 2, 0) >= 0
+    assume(VersionUtils.isSpark320OrLater, "Spark version not 3.2.0+")
 
   def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = {
     val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion
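
Reviewer note: after this change, cmpSparkVersion in SparkQueryCompareTestSuite duplicates VersionUtils.cmpSparkVersion verbatim. A possible follow-up, not part of this patch, would be to have the test trait delegate to the shared implementation:

    def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int =
      VersionUtils.cmpSparkVersion(major, minor, bugfix)
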