From 33637dd5b1c0b13b5829828654242204a6ffd266 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 10 Sep 2021 10:55:15 +0800 Subject: [PATCH 1/9] init Signed-off-by: sperlingxx --- .../com/nvidia/spark/rapids/GpuCast.scala | 86 +++++++++++++------ .../sql/rapids/datetimeExpressions.scala | 23 +++-- .../spark/rapids/ParseDateTimeSuite.scala | 5 +- 3 files changed, 78 insertions(+), 36 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 6b2e43140ef..70d0ce492c8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -27,6 +27,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.GpuToTimestamp.{daysEqual, daysScalarDays, daysScalarMicros} import org.apache.spark.sql.rapids.RegexReplace import org.apache.spark.sql.types._ @@ -805,27 +806,6 @@ object GpuCast extends Arm { } } - /** - * Replace special date strings such as "now" with timestampDays. This method does not - * close the `input` ColumnVector. - */ - def specialDateOr( - input: ColumnVector, - special: String, - value: Int, - orColumnVector: ColumnVector): ColumnVector = { - - withResource(orColumnVector) { other => - withResource(Scalar.fromString(special)) { str => - withResource(input.equalTo(str)) { isStr => - withResource(Scalar.timestampDaysFromInt(value)) { date => - isStr.ifElse(date, other) - } - } - } - } - } - /** * Parse dates that match the provided length and format. This method does not * close the `input` ColumnVector. @@ -928,8 +908,6 @@ object GpuCast extends Arm { */ private def castStringToDate(input: ColumnVector): ColumnVector = { - val specialDates = DateUtils.specialDatesDays - withResource(sanitizeStringToDate(input)) { sanitizedInput => // convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd @@ -938,8 +916,35 @@ object GpuCast extends Arm { convertFixedLenDateOrNull(sanitizedInput, 4, "%Y"))) // handle special dates like "epoch", "now", etc. 
- specialDates.foldLeft(converted)((prev, specialDate) => - specialDateOr(sanitizedInput, specialDate._1, specialDate._2, prev)) + withResource(daysEqual(sanitizedInput, DateUtils.EPOCH)) { isEpoch => + withResource(daysEqual(sanitizedInput, DateUtils.NOW)) { isNow => + withResource(daysEqual(sanitizedInput, DateUtils.TODAY)) { isToday => + withResource(daysEqual(sanitizedInput, DateUtils.YESTERDAY)) { isYesterday => + withResource(daysEqual(sanitizedInput, DateUtils.TOMORROW)) { isTomorrow => + withResource(daysScalarDays(DateUtils.EPOCH)) { epoch => + withResource(daysScalarDays(DateUtils.NOW)) { now => + withResource(daysScalarDays(DateUtils.TODAY)) { today => + withResource(daysScalarDays(DateUtils.YESTERDAY)) { yesterday => + withResource(daysScalarDays(DateUtils.TOMORROW)) { tomorrow => + withResource(isTomorrow.ifElse(tomorrow, converted)) { a => + withResource(isYesterday.ifElse(yesterday, a)) { b => + withResource(isToday.ifElse(today, b)) { c => + withResource(isNow.ifElse(now, c)) { d => + isEpoch.ifElse(epoch, d) + } + } + } + } + } + } + } + } + } + } + } + } + } + } } } @@ -1093,8 +1098,35 @@ object GpuCast extends Arm { convertFixedLenTimestampOrNull(sanitizedInput, 4, "%Y")))) // handle special dates like "epoch", "now", etc. - val finalResult = specialDates.foldLeft(converted)((prev, specialDate) => - specialTimestampOr(sanitizedInput, specialDate._1, specialDate._2, prev)) + val finalResult = withResource(daysEqual(sanitizedInput, DateUtils.EPOCH)) { isEpoch => + withResource(daysEqual(sanitizedInput, DateUtils.NOW)) { isNow => + withResource(daysEqual(sanitizedInput, DateUtils.TODAY)) { isToday => + withResource(daysEqual(sanitizedInput, DateUtils.YESTERDAY)) { isYesterday => + withResource(daysEqual(sanitizedInput, DateUtils.TOMORROW)) { isTomorrow => + withResource(daysScalarMicros(DateUtils.EPOCH)) { epoch => + withResource(daysScalarMicros(DateUtils.NOW)) { now => + withResource(daysScalarMicros(DateUtils.TODAY)) { today => + withResource(daysScalarMicros(DateUtils.YESTERDAY)) { yesterday => + withResource(daysScalarMicros(DateUtils.TOMORROW)) { tomorrow => + withResource(isTomorrow.ifElse(tomorrow, converted)) { a => + withResource(isYesterday.ifElse(yesterday, a)) { b => + withResource(isToday.ifElse(today, b)) { c => + withResource(isNow.ifElse(now, c)) { d => + isEpoch.ifElse(epoch, d) + } + } + } + } + } + } + } + } + } + } + } + } + } + } if (ansiMode) { // When ANSI mode is enabled, we need to throw an exception if any values could not be diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index e86116d59db..d7cd1138c62 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids import java.util.concurrent.TimeUnit import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar} -import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta} +import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta, ShimLoader} import 
com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy} import com.nvidia.spark.rapids.RapidsPluginImplicits._ @@ -531,12 +531,25 @@ object GpuToTimestamp extends Arm { FIX_SINGLE_DIGIT_SECOND ) - def daysScalarSeconds(name: String): Scalar = { - Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name)) + def daysScalarDays(name: String): Scalar = ShimLoader.getSparkVersion match { + case version if version.startsWith("3.2") => + Scalar.fromNull(DType.TIMESTAMP_DAYS) + case _ => + Scalar.timestampFromLong(DType.TIMESTAMP_DAYS, DateUtils.specialDatesDays(name)) } - def daysScalarMicros(name: String): Scalar = { - Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name)) + def daysScalarSeconds(name: String): Scalar = ShimLoader.getSparkVersion match { + case version if version.startsWith("3.2") => + Scalar.fromNull(DType.TIMESTAMP_SECONDS) + case _ => + Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name)) + } + + def daysScalarMicros(name: String): Scalar = ShimLoader.getSparkVersion match { + case version if version.startsWith("3.2") => + Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS) + case _ => + Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name)) } def daysEqual(col: ColumnVector, name: String): ColumnVector = { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala index 63d2c8d6b6f..c203a6e0892 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala @@ -221,18 +221,15 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterE assert(!planStr.contains(RapidsConf.INCOMPATIBLE_DATE_FORMATS.key)) } - test("parse now") { + test("parse now", org.scalatest.Tag("111")) { def now(spark: SparkSession) = { import spark.implicits._ Seq("now").toDF("c0") .repartition(2) .withColumn("c1", unix_timestamp(col("c0"), "yyyy-MM-dd HH:mm:ss")) } - val startTimeSeconds = System.currentTimeMillis()/1000L val cpuNowSeconds = withCpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long] val gpuNowSeconds = withGpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long] - assert(cpuNowSeconds >= startTimeSeconds) - assert(gpuNowSeconds >= startTimeSeconds) // CPU ran first so cannot have a greater value than the GPU run (but could be the same second) assert(cpuNowSeconds <= gpuNowSeconds) } From 000ce393fd549dd5a09fed679ac82dc039902b90 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 10 Sep 2021 13:34:57 +0800 Subject: [PATCH 2/9] update Signed-off-by: sperlingxx --- .../spark/sql/rapids/datetimeExpressions.scala | 12 +++++++++--- .../com/nvidia/spark/rapids/ParseDateTimeSuite.scala | 3 +-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index d7cd1138c62..914262ec506 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -532,21 +532,27 @@ object GpuToTimestamp extends Arm { ) def daysScalarDays(name: String): Scalar = ShimLoader.getSparkVersion 
match { - case version if version.startsWith("3.2") => + // In Spark 3.2, special datetime values such as `epoch`, `today`, `yesterday`, `tomorrow`, + // and `now` are supported in typed literals only + case version if version >= "3.2" => Scalar.fromNull(DType.TIMESTAMP_DAYS) case _ => Scalar.timestampFromLong(DType.TIMESTAMP_DAYS, DateUtils.specialDatesDays(name)) } def daysScalarSeconds(name: String): Scalar = ShimLoader.getSparkVersion match { - case version if version.startsWith("3.2") => + // In Spark 3.2, special datetime values such as `epoch`, `today`, `yesterday`, `tomorrow`, + // and `now` are supported in typed literals only + case version if version >= "3.2" => Scalar.fromNull(DType.TIMESTAMP_SECONDS) case _ => Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name)) } def daysScalarMicros(name: String): Scalar = ShimLoader.getSparkVersion match { - case version if version.startsWith("3.2") => + // In Spark 3.2, special datetime values such as `epoch`, `today`, `yesterday`, `tomorrow`, + // and `now` are supported in typed literals only + case version if version >= "3.2" => Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS) case _ => Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name)) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala index c203a6e0892..85d80c66eee 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala @@ -221,7 +221,7 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterE assert(!planStr.contains(RapidsConf.INCOMPATIBLE_DATE_FORMATS.key)) } - test("parse now", org.scalatest.Tag("111")) { + test("parse now") { def now(spark: SparkSession) = { import spark.implicits._ Seq("now").toDF("c0") @@ -518,4 +518,3 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterE ) } - From d87804eec7e26b843569f021f7fc64de4fba2c00 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 10 Sep 2021 13:37:25 +0800 Subject: [PATCH 3/9] update Signed-off-by: sperlingxx --- .../com/nvidia/spark/rapids/GpuCast.scala | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 70d0ce492c8..43070f25e86 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -960,27 +960,6 @@ object GpuCast extends Arm { } } - /** - * Replace special date strings such as "now" with timestampMicros. This method does not - * close the `input` ColumnVector. - */ - private def specialTimestampOr( - input: ColumnVector, - special: String, - value: Long, - orColumnVector: ColumnVector): ColumnVector = { - - withResource(orColumnVector) { other => - withResource(Scalar.fromString(special)) { str => - withResource(input.equalTo(str)) { isStr => - withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, value)) { date => - isStr.ifElse(date, other) - } - } - } - } - } - /** * Parse dates that match the the provided regex. This method does not close the `input` * ColumnVector. 
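A note between patches 3 and 4: up to this point the series has replaced the data-driven foldLeft over the DateUtils.specialDates* maps with hand-unrolled withResource pyramids, so that the special-date scalars can be version-gated (per the comments in patch 2, Spark 3.2 supports "epoch", "now", "today", "yesterday" and "tomorrow" in typed literals only). Patch 4 below folds the pyramids back into a single generic helper, replaceSpecialDates. The following is a minimal, self-contained sketch of that fold pattern; Vec and this withResource are simplified stand-ins for cuDF's ColumnVector and the plugin's Arm loan helpers, not the real APIs:

object ReplaceSpecialDatesSketch {
  // Toy stand-in for a GPU column; a real ColumnVector frees device memory in close().
  final case class Vec[T](cells: Vector[T]) extends AutoCloseable {
    def close(): Unit = ()
  }

  // Arm-style loan pattern: the resource is closed even if body throws.
  def withResource[R <: AutoCloseable, A](r: R)(body: R => A): A =
    try body(r) finally r.close()

  // Fold (specialName, replacementDay) pairs over an already-parsed column:
  // wherever the raw string equals the special name, substitute the value.
  // The accumulator is closed at every step, mirroring the plugin's fold.
  def replaceSpecialDates(
      raw: Vec[String],
      parsed: Vec[Option[Int]],
      specials: Seq[(String, Int)]): Vec[Option[Int]] =
    specials.foldLeft(parsed) { case (acc, (name, value)) =>
      withResource(acc) { prev =>
        Vec(raw.cells.zip(prev.cells).map {
          case (s, _) if s == name => Some(value)
          case (_, v) => v
        })
      }
    }

  def main(args: Array[String]): Unit = {
    // Pretend "today" is 2021-09-10, i.e. epoch day 18880.
    val raw = Vec(Vector("2021-09-10", "epoch", "tomorrow"))
    val parsed = Vec(Vector(Some(18880), None, None))
    val out = replaceSpecialDates(raw, parsed, Seq("epoch" -> 0, "tomorrow" -> 18881))
    println(out.cells) // Vector(Some(18880), Some(0), Some(18881))
  }
}

In the plugin the equality test and ifElse run on the GPU and every intermediate vector owns device memory, which is why the real fold closes its accumulator on each step, and why the call sites in patches 4 and 5 wrap the already-converted vector in closeOnExcept before handing it over.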
From a8fc23456f0ccbffe479c67d1f562b2a80f2fd35 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Tue, 14 Sep 2021 15:59:15 +0800 Subject: [PATCH 4/9] refactored Signed-off-by: sperlingxx --- .../spark/rapids/shims/v2/Spark30XShims.scala | 14 ++- .../spark/rapids/shims/v2/Spark32XShims.scala | 7 ++ .../com/nvidia/spark/rapids/GpuCast.scala | 79 ++++-------- .../com/nvidia/spark/rapids/SparkShims.scala | 3 + .../sql/rapids/datetimeExpressions.scala | 114 +++++++----------- 5 files changed, 87 insertions(+), 130 deletions(-) diff --git a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala index 884e63e0eb5..e32d6f58a2b 100644 --- a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala +++ b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala @@ -16,7 +16,8 @@ package com.nvidia.spark.rapids.shims.v2 -import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig} +import ai.rapids.cudf.{DType, Scalar} +import com.nvidia.spark.rapids.{DateUtils, ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig} import com.nvidia.spark.rapids.GpuOverrides.exec import org.apache.hadoop.fs.FileStatus @@ -103,4 +104,15 @@ trait Spark30XShims extends SparkShims { ss.sparkContext.defaultParallelism } + override def getSpecialDate(name: String, unit: DType): Scalar = unit match { + case DType.TIMESTAMP_DAYS => + Scalar.timestampDaysFromInt(DateUtils.specialDatesDays(name)) + case DType.TIMESTAMP_SECONDS => + Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name)) + case DType.TIMESTAMP_MICROSECONDS => + Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name)) + case _ => + throw new IllegalArgumentException(s"unsupported DType: $unit") + } + } diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala index 3e5d74f5ee0..77e838621ed 100644 --- a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala @@ -16,6 +16,7 @@ package com.nvidia.spark.rapids.shims.v2 +import ai.rapids.cudf.{DType, Scalar} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.GpuOverrides.exec import org.apache.hadoop.fs.FileStatus @@ -137,6 +138,12 @@ trait Spark32XShims extends SparkShims { Spark32XShimsUtils.leafNodeDefaultParallelism(ss) } + override def getSpecialDate(name: String, unit: DType): Scalar = unit match { + case DType.TIMESTAMP_DAYS => Scalar.fromNull(DType.TIMESTAMP_DAYS) + case DType.TIMESTAMP_SECONDS => Scalar.fromNull(DType.TIMESTAMP_SECONDS) + case DType.TIMESTAMP_MICROSECONDS => Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS) + case _ => throw new IllegalArgumentException(s"unsupported DType: $unit") + } } // TODO dedupe utils inside shims diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 414e9c8523a..861215027e2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids import java.text.SimpleDateFormat import 
java.time.DateTimeException +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar} @@ -27,7 +28,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.rapids.GpuToTimestamp.{daysEqual, daysScalarDays, daysScalarMicros} +import org.apache.spark.sql.rapids.GpuToTimestamp.replaceSpecialDates import org.apache.spark.sql.rapids.RegexReplace import org.apache.spark.sql.types._ @@ -916,33 +917,16 @@ object GpuCast extends Arm { convertFixedLenDateOrNull(sanitizedInput, 4, "%Y"))) // handle special dates like "epoch", "now", etc. - withResource(daysEqual(sanitizedInput, DateUtils.EPOCH)) { isEpoch => - withResource(daysEqual(sanitizedInput, DateUtils.NOW)) { isNow => - withResource(daysEqual(sanitizedInput, DateUtils.TODAY)) { isToday => - withResource(daysEqual(sanitizedInput, DateUtils.YESTERDAY)) { isYesterday => - withResource(daysEqual(sanitizedInput, DateUtils.TOMORROW)) { isTomorrow => - withResource(daysScalarDays(DateUtils.EPOCH)) { epoch => - withResource(daysScalarDays(DateUtils.NOW)) { now => - withResource(daysScalarDays(DateUtils.TODAY)) { today => - withResource(daysScalarDays(DateUtils.YESTERDAY)) { yesterday => - withResource(daysScalarDays(DateUtils.TOMORROW)) { tomorrow => - withResource(isTomorrow.ifElse(tomorrow, converted)) { a => - withResource(isYesterday.ifElse(yesterday, a)) { b => - withResource(isToday.ifElse(today, b)) { c => - withResource(isNow.ifElse(now, c)) { d => - isEpoch.ifElse(epoch, d) - } - } - } - } - } - } - } - } - } - } - } - } + // `converted` will be closed in replaceSpecialDates. We wrap it with closeOnExcept in case + // of exception before replaceSpecialDates. + closeOnExcept(converted) { timeStampVector => + val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY, + DateUtils.YESTERDAY, DateUtils.TOMORROW) + val specialValues = mutable.ListBuffer.empty[Scalar] + withResource(specialValues) { _ => + specialDates.foreach( + specialValues += ShimLoader.getSparkShims.getSpecialDate(_, DType.TIMESTAMP_DAYS)) + replaceSpecialDates(sanitizedInput, timeStampVector, specialDates, specialValues) } } } @@ -1043,7 +1027,6 @@ object GpuCast extends Arm { val today = DateUtils.currentDate() val todayStr = new SimpleDateFormat("yyyy-MM-dd") .format(today * DateUtils.ONE_DAY_SECONDS * 1000L) - val specialDates = DateUtils.specialDatesMicros var sanitizedInput = input.incRefCount() @@ -1077,33 +1060,17 @@ object GpuCast extends Arm { convertFixedLenTimestampOrNull(sanitizedInput, 4, "%Y")))) // handle special dates like "epoch", "now", etc. 
- val finalResult = withResource(daysEqual(sanitizedInput, DateUtils.EPOCH)) { isEpoch => - withResource(daysEqual(sanitizedInput, DateUtils.NOW)) { isNow => - withResource(daysEqual(sanitizedInput, DateUtils.TODAY)) { isToday => - withResource(daysEqual(sanitizedInput, DateUtils.YESTERDAY)) { isYesterday => - withResource(daysEqual(sanitizedInput, DateUtils.TOMORROW)) { isTomorrow => - withResource(daysScalarMicros(DateUtils.EPOCH)) { epoch => - withResource(daysScalarMicros(DateUtils.NOW)) { now => - withResource(daysScalarMicros(DateUtils.TODAY)) { today => - withResource(daysScalarMicros(DateUtils.YESTERDAY)) { yesterday => - withResource(daysScalarMicros(DateUtils.TOMORROW)) { tomorrow => - withResource(isTomorrow.ifElse(tomorrow, converted)) { a => - withResource(isYesterday.ifElse(yesterday, a)) { b => - withResource(isToday.ifElse(today, b)) { c => - withResource(isNow.ifElse(now, c)) { d => - isEpoch.ifElse(epoch, d) - } - } - } - } - } - } - } - } - } - } - } - } + // `converted` will be closed in replaceSpecialDates. We wrap it with closeOnExcept in case + // of exception before replaceSpecialDates. + val finalResult = closeOnExcept(converted) { timeStampVector => + val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY, + DateUtils.YESTERDAY, DateUtils.TOMORROW) + val specialValues = mutable.ListBuffer.empty[Scalar] + withResource(specialValues) { _ => + specialDates.foreach( + specialValues += + ShimLoader.getSparkShims.getSpecialDate(_, DType.TIMESTAMP_MICROSECONDS)) + replaceSpecialDates(sanitizedInput, timeStampVector, specialDates, specialValues) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 5c751883a17..262a90fc971 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids import java.net.URI import java.nio.ByteBuffer +import ai.rapids.cudf.{DType, Scalar} import org.apache.arrow.memory.ReferenceManager import org.apache.arrow.vector.ValueVector import org.apache.hadoop.fs.{FileStatus, Path} @@ -256,6 +257,8 @@ trait SparkShims { def aqeShuffleReaderExec: ExecRule[_ <: SparkPlan] def leafNodeDefaultParallelism(ss: SparkSession): Int + + def getSpecialDate(name: String, unit: DType): Scalar } abstract class SparkCommonShims extends SparkShims { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 914262ec506..c6db354d1e5 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.rapids import java.util.concurrent.TimeUnit +import scala.collection.mutable + import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar} import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta, ShimLoader} import com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException @@ -531,39 +533,30 @@ object GpuToTimestamp extends Arm { FIX_SINGLE_DIGIT_SECOND ) - def daysScalarDays(name: String): Scalar = ShimLoader.getSparkVersion match { - // In Spark 3.2, special 
datetime values such as `epoch`, `today`, `yesterday`, `tomorrow`, - // and `now` are supported in typed literals only - case version if version >= "3.2" => - Scalar.fromNull(DType.TIMESTAMP_DAYS) - case _ => - Scalar.timestampFromLong(DType.TIMESTAMP_DAYS, DateUtils.specialDatesDays(name)) - } - - def daysScalarSeconds(name: String): Scalar = ShimLoader.getSparkVersion match { - // In Spark 3.2, special datetime values such as `epoch`, `today`, `yesterday`, `tomorrow`, - // and `now` are supported in typed literals only - case version if version >= "3.2" => - Scalar.fromNull(DType.TIMESTAMP_SECONDS) - case _ => - Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name)) - } - - def daysScalarMicros(name: String): Scalar = ShimLoader.getSparkVersion match { - // In Spark 3.2, special datetime values such as `epoch`, `today`, `yesterday`, `tomorrow`, - // and `now` are supported in typed literals only - case version if version >= "3.2" => - Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS) - case _ => - Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name)) - } - def daysEqual(col: ColumnVector, name: String): ColumnVector = { withResource(Scalar.fromString(name)) { scalarName => col.equalTo(scalarName) } } + /** + * Replace special date strings such as "now" with timestampDays. This method does not + * close the `stringVector` and `specialValues`. + */ + def replaceSpecialDates( + stringVector: ColumnVector, + chronoVector: ColumnVector, + specialNames: Seq[String], + specialValues: Seq[Scalar]): ColumnVector = { + specialValues.zip(specialNames).foldLeft(chronoVector) { case (buffer, (scalar, name)) => + withResource(buffer) { bufVector => + withResource(daysEqual(stringVector, name)) { isMatch => + isMatch.ifElse(scalar, bufVector) + } + } + } + } + def isTimestamp(col: ColumnVector, sparkFormat: String, strfFormat: String) : ColumnVector = { if (CORRECTED_COMPATIBLE_FORMATS.contains(sparkFormat)) { // the cuDF `is_timestamp` function is less restrictive than Spark's behavior for UnixTime @@ -591,50 +584,29 @@ object GpuToTimestamp extends Arm { lhs: GpuColumnVector, sparkFormat: String, strfFormat: String, - dtype: DType, - daysScalar: String => Scalar, - asTimestamp: (ColumnVector, String) => ColumnVector): ColumnVector = { + dtype: DType): ColumnVector = { + + // `tsVector` will be closed in replaceSpecialDates + val tsVector = withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTs => + withResource(Scalar.fromNull(dtype)) { nullValue => + withResource(lhs.getBase.asTimestamp(dtype, strfFormat)) { tsVec => + isTs.ifElse(tsVec, nullValue) + } + } + } // in addition to date/timestamp strings, we also need to check for special dates and null // values, since anything else is invalid and should throw an error or be converted to null // depending on the policy - withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTimestamp => - withResource(daysEqual(lhs.getBase, DateUtils.EPOCH)) { isEpoch => - withResource(daysEqual(lhs.getBase, DateUtils.NOW)) { isNow => - withResource(daysEqual(lhs.getBase, DateUtils.TODAY)) { isToday => - withResource(daysEqual(lhs.getBase, DateUtils.YESTERDAY)) { isYesterday => - withResource(daysEqual(lhs.getBase, DateUtils.TOMORROW)) { isTomorrow => - withResource(lhs.getBase.isNull) { _ => - withResource(Scalar.fromNull(dtype)) { nullValue => - withResource(asTimestamp(lhs.getBase, strfFormat)) { converted => - withResource(daysScalar(DateUtils.EPOCH)) { epoch => - 
withResource(daysScalar(DateUtils.NOW)) { now => - withResource(daysScalar(DateUtils.TODAY)) { today => - withResource(daysScalar(DateUtils.YESTERDAY)) { yesterday => - withResource(daysScalar(DateUtils.TOMORROW)) { tomorrow => - withResource(isTomorrow.ifElse(tomorrow, nullValue)) { a => - withResource(isYesterday.ifElse(yesterday, a)) { b => - withResource(isToday.ifElse(today, b)) { c => - withResource(isNow.ifElse(now, c)) { d => - withResource(isEpoch.ifElse(epoch, d)) { e => - isTimestamp.ifElse(converted, e) - } - } - } - } - } - } - } - } - } - } - } - } - } - } - } - } - } + val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY, + DateUtils.YESTERDAY, DateUtils.TOMORROW) + val specialValues = mutable.ListBuffer.empty[Scalar] + + withResource(specialValues) { _ => + closeOnExcept(tsVector) { _ => + specialDates.foreach( + specialValues += ShimLoader.getSparkShims.getSpecialDate(_, dtype)) + replaceSpecialDates(lhs.getBase, tsVector, specialDates, specialValues) } } } @@ -779,9 +751,7 @@ abstract class GpuToTimestamp lhs, sparkFormat, strfFormat, - DType.TIMESTAMP_MICROSECONDS, - daysScalarMicros, - (col, strfFormat) => col.asTimestampMicroseconds(strfFormat)) + DType.TIMESTAMP_MICROSECONDS) } } else { // Timestamp or DateType lhs.getBase.asTimestampMicroseconds() @@ -830,9 +800,7 @@ abstract class GpuToTimestampImproved extends GpuToTimestamp { lhs, sparkFormat, strfFormat, - DType.TIMESTAMP_SECONDS, - daysScalarSeconds, - (col, strfFormat) => col.asTimestampSeconds(strfFormat)) + DType.TIMESTAMP_SECONDS) } } else if (lhs.dataType() == DateType){ lhs.getBase.asTimestampSeconds() From fec39d52a8e17afe6c306d7a65dbf8ca7d6b9c86 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Thu, 16 Sep 2021 20:49:27 +0800 Subject: [PATCH 5/9] fix Signed-off-by: sperlingxx --- .../spark/rapids/shims/v2/Spark30XShims.scala | 11 -- .../spark/rapids/shims/v2/Spark32XShims.scala | 6 - .../com/nvidia/spark/rapids/DateUtils.scala | 29 ++++- .../com/nvidia/spark/rapids/GpuCast.scala | 103 ++++-------------- .../com/nvidia/spark/rapids/SparkShims.scala | 6 - .../sql/rapids/datetimeExpressions.scala | 19 ++-- 6 files changed, 58 insertions(+), 116 deletions(-) diff --git a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala index e49a5843b1c..670748110f9 100644 --- a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala +++ b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala @@ -127,15 +127,4 @@ trait Spark30XShims extends SparkShims { ss.sparkContext.defaultParallelism } - override def getSpecialDate(name: String, unit: DType): Scalar = unit match { - case DType.TIMESTAMP_DAYS => - Scalar.timestampDaysFromInt(DateUtils.specialDatesDays(name)) - case DType.TIMESTAMP_SECONDS => - Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name)) - case DType.TIMESTAMP_MICROSECONDS => - Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name)) - case _ => - throw new IllegalArgumentException(s"unsupported DType: $unit") - } - } diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala index 77e838621ed..ed33b91470d 100644 --- 
a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala @@ -138,12 +138,6 @@ trait Spark32XShims extends SparkShims { Spark32XShimsUtils.leafNodeDefaultParallelism(ss) } - override def getSpecialDate(name: String, unit: DType): Scalar = unit match { - case DType.TIMESTAMP_DAYS => Scalar.fromNull(DType.TIMESTAMP_DAYS) - case DType.TIMESTAMP_SECONDS => Scalar.fromNull(DType.TIMESTAMP_SECONDS) - case DType.TIMESTAMP_MICROSECONDS => Scalar.fromNull(DType.TIMESTAMP_MICROSECONDS) - case _ => throw new IllegalArgumentException(s"unsupported DType: $unit") - } } // TODO dedupe utils inside shims diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala index 018f59bbfa8..85a14814a3b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala @@ -20,6 +20,8 @@ import java.time.LocalDate import scala.collection.mutable.ListBuffer +import ai.rapids.cudf.{DType, Scalar} + import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateToDays @@ -57,7 +59,13 @@ object DateUtils { val YESTERDAY = "yesterday" val TOMORROW = "tomorrow" - def specialDatesDays: Map[String, Int] = { + private lazy val isSpark320OrLater: Boolean = { + ShimLoader.getSparkShims.getSparkShimVersion.toString >= "3.2" + } + + def specialDatesDays: Map[String, Int] = if (isSpark320OrLater) { + Map.empty + } else { val today = currentDate() Map( EPOCH -> 0, @@ -68,7 +76,9 @@ object DateUtils { ) } - def specialDatesSeconds: Map[String, Long] = { + def specialDatesSeconds: Map[String, Long] = if (isSpark320OrLater) { + Map.empty + } else { val today = currentDate() val now = DateTimeUtils.currentTimestamp() Map( @@ -80,7 +90,9 @@ object DateUtils { ) } - def specialDatesMicros: Map[String, Long] = { + def specialDatesMicros: Map[String, Long] = if (isSpark320OrLater) { + Map.empty + } else { val today = currentDate() val now = DateTimeUtils.currentTimestamp() Map( @@ -92,6 +104,17 @@ object DateUtils { ) } + def fetchSpecialDates(unit: DType): Map[String, Scalar] = unit match { + case DType.TIMESTAMP_DAYS => + DateUtils.specialDatesDays.map { case (k, v) => k -> Scalar.timestampDaysFromInt(v) } + case DType.TIMESTAMP_SECONDS => + DateUtils.specialDatesSeconds.map { case (k, v) => k -> Scalar.timestampFromLong(unit, v) } + case DType.TIMESTAMP_MICROSECONDS => + DateUtils.specialDatesMicros.map { case (k, v) => k -> Scalar.timestampFromLong(unit, v) } + case _ => + throw new IllegalArgumentException(s"unsupported DType: $unit") + } + def currentDate(): Int = localDateToDays(LocalDate.now()) case class FormatKeywordToReplace(word: String, startIndex: Int, endIndex: Int) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 59ecc873d98..c3b3028909a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -790,27 +790,6 @@ object GpuCast extends Arm { } } - /** - * Replace special date strings such as "now" with timestampDays. This method does not - * close the `input` ColumnVector. 
- */ - def specialDateOr( - input: ColumnVector, - special: String, - value: Int, - orColumnVector: ColumnVector): ColumnVector = { - - withResource(orColumnVector) { other => - withResource(Scalar.fromString(special)) { str => - withResource(input.equalTo(str)) { isStr => - withResource(Scalar.timestampDaysFromInt(value)) { date => - isStr.ifElse(date, other) - } - } - } - } - } - /** This method does not close the `input` ColumnVector. */ def convertDateOrNull( input: ColumnVector, @@ -886,40 +865,24 @@ object GpuCast extends Arm { */ private def castStringToDate(sanitizedInput: ColumnVector): ColumnVector = { -/* - withResource(sanitizeStringToDate(input)) { sanitizedInput => - - // convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd - val converted = convertVarLenDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d", - convertFixedLenDateOr(sanitizedInput, 7, "%Y-%m", - convertFixedLenDateOrNull(sanitizedInput, 4, "%Y"))) - - // handle special dates like "epoch", "now", etc. - // `converted` will be closed in replaceSpecialDates. We wrap it with closeOnExcept in case - // of exception before replaceSpecialDates. - closeOnExcept(converted) { timeStampVector => - val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY, - DateUtils.YESTERDAY, DateUtils.TOMORROW) - val specialValues = mutable.ListBuffer.empty[Scalar] - withResource(specialValues) { _ => - specialDates.foreach( - specialValues += ShimLoader.getSparkShims.getSpecialDate(_, DType.TIMESTAMP_DAYS)) - replaceSpecialDates(sanitizedInput, timeStampVector, specialDates, specialValues) - } - } - } -*/ - - val specialDates = DateUtils.specialDatesDays - // convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd val converted = convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d", convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM, "%Y-%m", convertDateOrNull(sanitizedInput, DATE_REGEX_YYYY, "%Y"))) // handle special dates like "epoch", "now", etc. - specialDates.foldLeft(converted)((prev, specialDate) => - specialDateOr(sanitizedInput, specialDate._1, specialDate._2, prev)) + closeOnExcept(converted) { tsVector => + DateUtils.fetchSpecialDates(DType.TIMESTAMP_DAYS) match { + case dates if dates.nonEmpty => + // `tsVector` will be closed in replaceSpecialDates + val (specialNames, specialValues) = dates.unzip + withResource(specialValues.toList) { scalars => + replaceSpecialDates(sanitizedInput, tsVector, specialNames.toList, scalars) + } + case _ => + tsVector + } + } } private def castStringToDateAnsi(input: ColumnVector, ansiMode: Boolean): ColumnVector = { @@ -934,27 +897,6 @@ object GpuCast extends Arm { } } - /** - * Replace special date strings such as "now" with timestampMicros. This method does not - * close the `input` ColumnVector. - */ - private def specialTimestampOr( - input: ColumnVector, - special: String, - value: Long, - orColumnVector: ColumnVector): ColumnVector = { - - withResource(orColumnVector) { other => - withResource(Scalar.fromString(special)) { str => - withResource(input.equalTo(str)) { isStr => - withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, value)) { date => - isStr.ifElse(date, other) - } - } - } - } - } - /** This method does not close the `input` ColumnVector. */ private def convertTimestampOrNull( input: ColumnVector, @@ -1052,17 +994,16 @@ object GpuCast extends Arm { convertTimestampOrNull(sanitizedInput, TIMESTAMP_REGEX_YYYY, "%Y")))) // handle special dates like "epoch", "now", etc. 
- // `converted` will be closed in replaceSpecialDates. We wrap it with closeOnExcept in case - // of exception before replaceSpecialDates. - val finalResult = closeOnExcept(converted) { timeStampVector => - val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY, - DateUtils.YESTERDAY, DateUtils.TOMORROW) - val specialValues = mutable.ListBuffer.empty[Scalar] - withResource(specialValues) { _ => - specialDates.foreach( - specialValues += - ShimLoader.getSparkShims.getSpecialDate(_, DType.TIMESTAMP_MICROSECONDS)) - replaceSpecialDates(sanitizedInput, timeStampVector, specialDates, specialValues) + val finalResult = closeOnExcept(converted) { tsVector => + DateUtils.fetchSpecialDates(DType.TIMESTAMP_MICROSECONDS) match { + case dates if dates.nonEmpty => + // `tsVector` will be closed in replaceSpecialDates. + val (specialNames, specialValues) = dates.unzip + withResource(specialValues.toList) { scalars => + replaceSpecialDates(sanitizedInput, tsVector, specialNames.toList, scalars) + } + case _ => + tsVector } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 49e9db92821..2f846b3b3bf 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -19,11 +19,7 @@ package com.nvidia.spark.rapids import java.net.URI import java.nio.ByteBuffer -<<<<<<< HEAD -import ai.rapids.cudf.{DType, Scalar} -======= import com.esotericsoftware.kryo.Kryo ->>>>>>> origin/branch-21.10 import org.apache.arrow.memory.ReferenceManager import org.apache.arrow.vector.ValueVector import org.apache.hadoop.fs.{FileStatus, Path} @@ -272,8 +268,6 @@ trait SparkShims { def leafNodeDefaultParallelism(ss: SparkSession): Int - def getSpecialDate(name: String, unit: DType): Scalar - def registerKryoClasses(kryo: Kryo): Unit } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index c6db354d1e5..abdb476afa6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -598,15 +598,16 @@ object GpuToTimestamp extends Arm { // in addition to date/timestamp strings, we also need to check for special dates and null // values, since anything else is invalid and should throw an error or be converted to null // depending on the policy - val specialDates = Seq(DateUtils.EPOCH, DateUtils.NOW, DateUtils.TODAY, - DateUtils.YESTERDAY, DateUtils.TOMORROW) - val specialValues = mutable.ListBuffer.empty[Scalar] - - withResource(specialValues) { _ => - closeOnExcept(tsVector) { _ => - specialDates.foreach( - specialValues += ShimLoader.getSparkShims.getSpecialDate(_, dtype)) - replaceSpecialDates(lhs.getBase, tsVector, specialDates, specialValues) + closeOnExcept(tsVector) { tsVector => + DateUtils.fetchSpecialDates(dtype) match { + case dates if dates.nonEmpty => + // `tsVector` will be closed in replaceSpecialDates + val (specialNames, specialValues) = dates.unzip + withResource(specialValues.toList) { scalars => + replaceSpecialDates(lhs.getBase, tsVector, specialNames.toList, scalars) + } + case _ => + tsVector } } } From b2f35d3a062f486034422e9e07e0085fe87b36d5 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Thu, 16 Sep 2021 20:55:13 +0800 Subject: [PATCH 6/9] fix Signed-off-by: 
sperlingxx --- .../scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala | 3 +-- .../scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala index 670748110f9..2ff021f6193 100644 --- a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala +++ b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala @@ -18,8 +18,7 @@ package com.nvidia.spark.rapids.shims.v2 import scala.collection.mutable.ListBuffer -import ai.rapids.cudf.{DType, Scalar} -import com.nvidia.spark.rapids.{DateUtils, ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig} +import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig} import com.nvidia.spark.rapids.GpuOverrides.exec import org.apache.hadoop.fs.FileStatus diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala index ed33b91470d..3e5d74f5ee0 100644 --- a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala @@ -16,7 +16,6 @@ package com.nvidia.spark.rapids.shims.v2 -import ai.rapids.cudf.{DType, Scalar} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.GpuOverrides.exec import org.apache.hadoop.fs.FileStatus From 790eb44a1b3fd2d16a4832749f4d751233986323 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 17 Sep 2021 10:30:25 +0800 Subject: [PATCH 7/9] update Signed-off-by: sperlingxx --- .../sql/rapids/datetimeExpressions.scala | 41 ------------------- 1 file changed, 41 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index 48b97f61c52..ce59042eff6 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -510,47 +510,6 @@ object GpuToTimestamp extends Arm { val REMOVE_WHITESPACE_FROM_MONTH_DAY: RegexReplace = RegexReplace(raw"(\A\d+)-([ \t]*)(\d+)-([ \t]*)(\d+)", raw"\1-\3-\5") - /** Regex rule to replace "yyyy-m-" with "yyyy-mm-" */ - val FIX_SINGLE_DIGIT_MONTH: RegexReplace = - RegexReplace(raw"(\A\d+)-(\d{1}-)", raw"\1-0\2") - - /** Regex rule to replace "yyyy-mm-d" with "yyyy-mm-dd" */ - val FIX_SINGLE_DIGIT_DAY: RegexReplace = - RegexReplace(raw"(\A\d+-\d{2})-(\d{1})([\D\s]|\Z)", raw"\1-0\2\3") - - /** Regex rule to replace "yyyy-mm-dd[ T]h:" with "yyyy-mm-dd hh:" */ - val FIX_SINGLE_DIGIT_HOUR: RegexReplace = - RegexReplace(raw"(\A\d+-\d{2}-\d{2})[ T](\d{1}:)", raw"\1 0\2") - - /** Regex rule to replace "yyyy-mm-dd[ T]hh:m:" with "yyyy-mm-dd[ T]hh:mm:" */ - val FIX_SINGLE_DIGIT_MINUTE: RegexReplace = - RegexReplace(raw"(\A\d+-\d{2}-\d{2}[ T]\d{2}):(\d{1}:)", raw"\1:0\2") - - /** Regex rule to replace "yyyy-mm-dd[ T]hh:mm:s" with "yyyy-mm-dd[ T]hh:mm:ss" */ - val FIX_SINGLE_DIGIT_SECOND: RegexReplace = - RegexReplace(raw"(\A\d+-\d{2}-\d{2}[ T]\d{2}:\d{2}):(\d{1})([\D\s]|\Z)", raw"\1:0\2\3") - - /** Convert dates to standard format */ - val FIX_DATES 
= Seq( - REMOVE_WHITESPACE_FROM_MONTH_DAY, - FIX_SINGLE_DIGIT_MONTH, - FIX_SINGLE_DIGIT_DAY) - - /** Convert timestamps to standard format */ - val FIX_TIMESTAMPS = Seq( - FIX_SINGLE_DIGIT_HOUR, - FIX_SINGLE_DIGIT_MINUTE, - FIX_SINGLE_DIGIT_SECOND - ) - - def daysScalarSeconds(name: String): Scalar = { - Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name)) - } - - def daysScalarMicros(name: String): Scalar = { - Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name)) - } - def daysEqual(col: ColumnVector, name: String): ColumnVector = { withResource(Scalar.fromString(name)) { scalarName => col.equalTo(scalarName) From 76eaa443c70999e0d6db2119afe652094f8999ec Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 17 Sep 2021 14:17:41 +0800 Subject: [PATCH 8/9] wrap up Signed-off-by: sperlingxx --- .../com/nvidia/spark/rapids/DateUtils.scala | 19 ++++--- .../com/nvidia/spark/rapids/GpuCast.scala | 14 ++--- .../nvidia/spark/rapids/VersionUtils.scala | 55 +++++++++++++++++++ .../sql/rapids/datetimeExpressions.scala | 18 +++--- .../spark/rapids/ScalarSubquerySuite.scala | 2 +- .../rapids/SparkQueryCompareTestSuite.scala | 12 ++-- 6 files changed, 83 insertions(+), 37 deletions(-) create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala index 85a14814a3b..7834bdeaed6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala @@ -21,6 +21,7 @@ import java.time.LocalDate import scala.collection.mutable.ListBuffer import ai.rapids.cudf.{DType, Scalar} +import com.nvidia.spark.rapids.VersionUtils.isSpark320OrLater import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateToDays @@ -59,10 +60,6 @@ object DateUtils { val YESTERDAY = "yesterday" val TOMORROW = "tomorrow" - private lazy val isSpark320OrLater: Boolean = { - ShimLoader.getSparkShims.getSparkShimVersion.toString >= "3.2" - } - def specialDatesDays: Map[String, Int] = if (isSpark320OrLater) { Map.empty } else { @@ -104,13 +101,19 @@ object DateUtils { ) } - def fetchSpecialDates(unit: DType): Map[String, Scalar] = unit match { + def fetchSpecialDates(unit: DType): Map[String, () => Scalar] = unit match { case DType.TIMESTAMP_DAYS => - DateUtils.specialDatesDays.map { case (k, v) => k -> Scalar.timestampDaysFromInt(v) } + DateUtils.specialDatesDays.map { case (k, v) => + k -> (() => Scalar.timestampDaysFromInt(v)) + } case DType.TIMESTAMP_SECONDS => - DateUtils.specialDatesSeconds.map { case (k, v) => k -> Scalar.timestampFromLong(unit, v) } + DateUtils.specialDatesSeconds.map { case (k, v) => + k -> (() => Scalar.timestampFromLong(unit, v)) + } case DType.TIMESTAMP_MICROSECONDS => - DateUtils.specialDatesMicros.map { case (k, v) => k -> Scalar.timestampFromLong(unit, v) } + DateUtils.specialDatesMicros.map { case (k, v) => + k -> (() => Scalar.timestampFromLong(unit, v)) + } case _ => throw new IllegalArgumentException(s"unsupported DType: $unit") } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index c3b3028909a..9209125c43f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ 
b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -873,12 +873,9 @@ object GpuCast extends Arm { // handle special dates like "epoch", "now", etc. closeOnExcept(converted) { tsVector => DateUtils.fetchSpecialDates(DType.TIMESTAMP_DAYS) match { - case dates if dates.nonEmpty => + case specialDates if specialDates.nonEmpty => // `tsVector` will be closed in replaceSpecialDates - val (specialNames, specialValues) = dates.unzip - withResource(specialValues.toList) { scalars => - replaceSpecialDates(sanitizedInput, tsVector, specialNames.toList, scalars) - } + replaceSpecialDates(sanitizedInput, tsVector, specialDates) case _ => tsVector } @@ -996,12 +993,9 @@ object GpuCast extends Arm { // handle special dates like "epoch", "now", etc. val finalResult = closeOnExcept(converted) { tsVector => DateUtils.fetchSpecialDates(DType.TIMESTAMP_MICROSECONDS) match { - case dates if dates.nonEmpty => + case specialDates if specialDates.nonEmpty => // `tsVector` will be closed in replaceSpecialDates. - val (specialNames, specialValues) = dates.unzip - withResource(specialValues.toList) { scalars => - replaceSpecialDates(sanitizedInput, tsVector, specialNames.toList, scalars) - } + replaceSpecialDates(sanitizedInput, tsVector, specialDates) case _ => tsVector } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala new file mode 100644 index 00000000000..ec71c1405a1 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids + +object VersionUtils { + + lazy val isSpark301OrLater: Boolean = cmpSparkVersion(3, 0, 1) >= 0 + + lazy val isSpark311OrLater: Boolean = cmpSparkVersion(3, 1, 1) >= 0 + + lazy val isSpark320OrLater: Boolean = cmpSparkVersion(3, 2, 0) >= 0 + + lazy val isSpark: Boolean = { + ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[SparkShimVersion] + } + + lazy val isDataBricks: Boolean = { + ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[DatabricksShimVersion] + } + + lazy val isCloudera: Boolean = { + ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[ClouderaShimVersion] + } + + lazy val isEMR: Boolean = { + ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[EMRShimVersion] + } + + def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = { + val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion + val (sparkMajor, sparkMinor, sparkBugfix) = sparkShimVersion match { + case SparkShimVersion(a, b, c) => (a, b, c) + case DatabricksShimVersion(a, b, c) => (a, b, c) + case ClouderaShimVersion(a, b, c, _) => (a, b, c) + case EMRShimVersion(a, b, c) => (a, b, c) + } + val fullVersion = ((major.toLong * 1000) + minor) * 1000 + bugfix + val sparkFullVersion = ((sparkMajor.toLong * 1000) + sparkMinor) * 1000 + sparkBugfix + sparkFullVersion.compareTo(fullVersion) + } +} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index ce59042eff6..bf7b31816cf 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -518,17 +518,18 @@ object GpuToTimestamp extends Arm { /** * Replace special date strings such as "now" with timestampDays. This method does not - * close the `stringVector` and `specialValues`. + * close the `stringVector`. 
*/ def replaceSpecialDates( stringVector: ColumnVector, chronoVector: ColumnVector, - specialNames: Seq[String], - specialValues: Seq[Scalar]): ColumnVector = { - specialValues.zip(specialNames).foldLeft(chronoVector) { case (buffer, (scalar, name)) => + specialDates: Map[String, () => Scalar]): ColumnVector = { + specialDates.foldLeft(chronoVector) { case (buffer, (name, scalarBuilder)) => withResource(buffer) { bufVector => withResource(daysEqual(stringVector, name)) { isMatch => - isMatch.ifElse(scalar, bufVector) + withResource(scalarBuilder()) { scalar => + isMatch.ifElse(scalar, bufVector) + } } } } @@ -574,12 +575,9 @@ object GpuToTimestamp extends Arm { // depending on the policy closeOnExcept(tsVector) { tsVector => DateUtils.fetchSpecialDates(dtype) match { - case dates if dates.nonEmpty => + case specialDates if specialDates.nonEmpty => // `tsVector` will be closed in replaceSpecialDates - val (specialNames, specialValues) = dates.unzip - withResource(specialValues.toList) { scalars => - replaceSpecialDates(lhs.getBase, tsVector, specialNames.toList, scalars) - } + replaceSpecialDates(lhs.getBase, tsVector, specialDates) case _ => tsVector } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ScalarSubquerySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ScalarSubquerySuite.scala index 48b04ee1e54..0522e359d66 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ScalarSubquerySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ScalarSubquerySuite.scala @@ -35,7 +35,7 @@ class ScalarSubquerySuite extends SparkQueryCompareTestSuite { // In Spark 3.2.0+ canonicalization for ScalarSubquery was fixed. Because it is a bug fix // we have fixed it on our end for all versions of Spark, but the canonicalization check // only works if both have the fix or both don't have it. 
- skipCanonicalizationCheck = isPriorToSpark320) { + skipCanonicalizationCheck = !VersionUtils.isSpark320OrLater) { frame => { frame.createOrReplaceTempView("table") val ret = frame.sparkSession.sql( diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index f1922389ddc..e94fddca8fa 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -1837,20 +1837,16 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm { /** most of the AQE tests requires Spark 3.0.1 or later */ def assumeSpark301orLater: Assertion = - assume(cmpSparkVersion(3, 0, 1) >= 0, "Spark version not 3.0.1+") + assume(VersionUtils.isSpark301OrLater, "Spark version not 3.0.1+") def assumeSpark311orLater: Assertion = - assume(cmpSparkVersion(3, 1, 1) >= 0, "Spark version not 3.1.1+") + assume(VersionUtils.isSpark311OrLater, "Spark version not 3.1.1+") def assumePriorToSpark320: Assertion = - assume(isPriorToSpark320, "Spark version not before 3.2.0") - - def isPriorToSpark320: Boolean = cmpSparkVersion(3, 2, 0) < 0 + assume(!VersionUtils.isSpark320OrLater, "Spark version not before 3.2.0") def assumeSpark320orLater: Assertion = - assume(isSpark320OrLater, "Spark version not 3.2.0+") - - def isSpark320OrLater: Boolean = cmpSparkVersion(3, 2, 0) >= 0 + assume(VersionUtils.isSpark320OrLater, "Spark version not 3.2.0+") def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = { val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion From 50e85eff3becef5f13168c5fd4998960c829a099 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 17 Sep 2021 14:32:09 +0800 Subject: [PATCH 9/9] update Signed-off-by: sperlingxx --- .../scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala index 64e33b8d1e8..d82b790fdb7 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParseDateTimeSuite.scala @@ -229,8 +229,14 @@ class ParseDateTimeSuite extends SparkQueryCompareTestSuite with BeforeAndAfterE .repartition(2) .withColumn("c1", unix_timestamp(col("c0"), "yyyy-MM-dd HH:mm:ss")) } + val startTimeSeconds = System.currentTimeMillis() / 1000L val cpuNowSeconds = withCpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long] val gpuNowSeconds = withGpuSparkSession(now).collect().head.toSeq(1).asInstanceOf[Long] + // For Spark 3.2+, "now" will NOT be parsed as the current time + if (!VersionUtils.isSpark320OrLater) { + assert(cpuNowSeconds >= startTimeSeconds) + assert(gpuNowSeconds >= startTimeSeconds) + } // CPU ran first so cannot have a greater value than the GPU run (but could be the same second) assert(cpuNowSeconds <= gpuNowSeconds) }
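Where the series lands, two design points are worth calling out.

First, patch 8 changes DateUtils.fetchSpecialDates to return Map[String, () => Scalar]: handing out builders instead of live Scalars means nothing is allocated until replaceSpecialDates asks for it, and each scalar is created and closed inside its own withResource, so a partially consumed map cannot leak. A toy model of that shape, again with a stand-in type (FakeScalar is illustrative, not the cuDF class):

object LazyScalarSketch {
  final class FakeScalar(val value: Long) extends AutoCloseable {
    def close(): Unit = () // a real cuDF Scalar frees device memory here
  }

  def withResource[R <: AutoCloseable, A](r: R)(body: R => A): A =
    try body(r) finally r.close()

  // Patch 8's shape: thunks, so the consumer controls each scalar's lifetime.
  def lazyBuilders(values: Map[String, Long]): Map[String, () => FakeScalar] =
    values.map { case (k, v) => k -> (() => new FakeScalar(v)) }

  def main(args: Array[String]): Unit = {
    // 1631232000 seconds = 2021-09-10 00:00:00 UTC, matching the patch dates.
    val builders = lazyBuilders(Map("epoch" -> 0L, "now" -> 1631232000L))
    builders.foreach { case (name, build) =>
      withResource(build()) { s => println(s"$name -> ${s.value}") }
    }
  }
}

Second, patches 2 and 5 briefly gated behavior on string comparisons (version >= "3.2" and getSparkShimVersion.toString >= "3.2"), which misorders two-digit components: lexicographically "3.10" < "3.2". Patch 8's VersionUtils.cmpSparkVersion avoids this by packing (major, minor, bugfix) into one Long. In the sketch below the encode formula is taken verbatim from that patch; the rest is illustrative:

object VersionCmpSketch {
  // Three decimal digits per component, so components up to 999 order correctly.
  def encode(major: Int, minor: Int, bugfix: Int): Long =
    ((major.toLong * 1000) + minor) * 1000 + bugfix

  // Compare a running Spark version against a threshold, numerically.
  def cmpSparkVersion(spark: (Int, Int, Int), wanted: (Int, Int, Int)): Int =
    encode(spark._1, spark._2, spark._3).compareTo(encode(wanted._1, wanted._2, wanted._3))

  def main(args: Array[String]): Unit = {
    assert("3.10" < "3.2") // the lexicographic pitfall
    assert(cmpSparkVersion((3, 10, 0), (3, 2, 0)) > 0) // numeric: 3.10 is later
    assert(cmpSparkVersion((3, 1, 1), (3, 2, 0)) < 0)
    println("numeric version comparison behaves as intended")
  }
}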