Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import java.time.LocalDate

import scala.collection.mutable.ListBuffer

import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids.VersionUtils.isSpark320OrLater

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateToDays

Expand Down Expand Up @@ -57,7 +60,9 @@ object DateUtils {
val YESTERDAY = "yesterday"
val TOMORROW = "tomorrow"

def specialDatesDays: Map[String, Int] = {
def specialDatesDays: Map[String, Int] = if (isSpark320OrLater) {
Map.empty
} else {
val today = currentDate()
Map(
EPOCH -> 0,
Expand All @@ -68,7 +73,9 @@ object DateUtils {
)
}

def specialDatesSeconds: Map[String, Long] = {
def specialDatesSeconds: Map[String, Long] = if (isSpark320OrLater) {
Map.empty
} else {
val today = currentDate()
val now = DateTimeUtils.currentTimestamp()
Map(
Expand All @@ -80,7 +87,9 @@ object DateUtils {
)
}

def specialDatesMicros: Map[String, Long] = {
def specialDatesMicros: Map[String, Long] = if (isSpark320OrLater) {
Map.empty
} else {
val today = currentDate()
val now = DateTimeUtils.currentTimestamp()
Map(
Expand All @@ -92,6 +101,23 @@ object DateUtils {
)
}

def fetchSpecialDates(unit: DType): Map[String, () => Scalar] = unit match {
case DType.TIMESTAMP_DAYS =>
DateUtils.specialDatesDays.map { case (k, v) =>
k -> (() => Scalar.timestampDaysFromInt(v))
}
case DType.TIMESTAMP_SECONDS =>
DateUtils.specialDatesSeconds.map { case (k, v) =>
k -> (() => Scalar.timestampFromLong(unit, v))
}
case DType.TIMESTAMP_MICROSECONDS =>
DateUtils.specialDatesMicros.map { case (k, v) =>
k -> (() => Scalar.timestampFromLong(unit, v))
}
case _ =>
throw new IllegalArgumentException(s"unsupported DType: $unit")
}

def currentDate(): Int = localDateToDays(LocalDate.now())

case class FormatKeywordToReplace(word: String, startIndex: Int, endIndex: Int)
Expand Down
69 changes: 20 additions & 49 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
import java.text.SimpleDateFormat
import java.time.DateTimeException

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
Expand All @@ -27,6 +28,7 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuToTimestamp.replaceSpecialDates
import org.apache.spark.sql.rapids.RegexReplace
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -788,27 +790,6 @@ object GpuCast extends Arm {
}
}

/**
* Replace special date strings such as "now" with timestampDays. This method does not
* close the `input` ColumnVector.
*/
def specialDateOr(
input: ColumnVector,
special: String,
value: Int,
orColumnVector: ColumnVector): ColumnVector = {

withResource(orColumnVector) { other =>
withResource(Scalar.fromString(special)) { str =>
withResource(input.equalTo(str)) { isStr =>
withResource(Scalar.timestampDaysFromInt(value)) { date =>
isStr.ifElse(date, other)
}
}
}
}
}

/** This method does not close the `input` ColumnVector. */
def convertDateOrNull(
input: ColumnVector,
Expand Down Expand Up @@ -884,16 +865,21 @@ object GpuCast extends Arm {
*/
private def castStringToDate(sanitizedInput: ColumnVector): ColumnVector = {

val specialDates = DateUtils.specialDatesDays

// convert dates that are in valid formats yyyy, yyyy-mm, yyyy-mm-dd
val converted = convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM_DD, "%Y-%m-%d",
convertDateOr(sanitizedInput, DATE_REGEX_YYYY_MM, "%Y-%m",
convertDateOrNull(sanitizedInput, DATE_REGEX_YYYY, "%Y")))

// handle special dates like "epoch", "now", etc.
specialDates.foldLeft(converted)((prev, specialDate) =>
specialDateOr(sanitizedInput, specialDate._1, specialDate._2, prev))
closeOnExcept(converted) { tsVector =>
DateUtils.fetchSpecialDates(DType.TIMESTAMP_DAYS) match {
case specialDates if specialDates.nonEmpty =>
// `tsVector` will be closed in replaceSpecialDates
replaceSpecialDates(sanitizedInput, tsVector, specialDates)
case _ =>
tsVector
}
}
}

private def castStringToDateAnsi(input: ColumnVector, ansiMode: Boolean): ColumnVector = {
Expand All @@ -908,27 +894,6 @@ object GpuCast extends Arm {
}
}

/**
* Replace special date strings such as "now" with timestampMicros. This method does not
* close the `input` ColumnVector.
*/
private def specialTimestampOr(
input: ColumnVector,
special: String,
value: Long,
orColumnVector: ColumnVector): ColumnVector = {

withResource(orColumnVector) { other =>
withResource(Scalar.fromString(special)) { str =>
withResource(input.equalTo(str)) { isStr =>
withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, value)) { date =>
isStr.ifElse(date, other)
}
}
}
}
}

/** This method does not close the `input` ColumnVector. */
private def convertTimestampOrNull(
input: ColumnVector,
Expand Down Expand Up @@ -1009,7 +974,6 @@ object GpuCast extends Arm {
val today = DateUtils.currentDate()
val todayStr = new SimpleDateFormat("yyyy-MM-dd")
.format(today * DateUtils.ONE_DAY_SECONDS * 1000L)
val specialDates = DateUtils.specialDatesMicros

var sanitizedInput = input.incRefCount()

Expand All @@ -1027,8 +991,15 @@ object GpuCast extends Arm {
convertTimestampOrNull(sanitizedInput, TIMESTAMP_REGEX_YYYY, "%Y"))))

// handle special dates like "epoch", "now", etc.
val finalResult = specialDates.foldLeft(converted)((prev, specialDate) =>
specialTimestampOr(sanitizedInput, specialDate._1, specialDate._2, prev))
val finalResult = closeOnExcept(converted) { tsVector =>
DateUtils.fetchSpecialDates(DType.TIMESTAMP_MICROSECONDS) match {
case specialDates if specialDates.nonEmpty =>
// `tsVector` will be closed in replaceSpecialDates.
replaceSpecialDates(sanitizedInput, tsVector, specialDates)
case _ =>
tsVector
}
}

if (ansiMode) {
// When ANSI mode is enabled, we need to throw an exception if any values could not be
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

object VersionUtils {

lazy val isSpark301OrLater: Boolean = cmpSparkVersion(3, 0, 1) >= 0

lazy val isSpark311OrLater: Boolean = cmpSparkVersion(3, 1, 1) >= 0

lazy val isSpark320OrLater: Boolean = cmpSparkVersion(3, 2, 0) >= 0

lazy val isSpark: Boolean = {
ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[SparkShimVersion]
}

lazy val isDataBricks: Boolean = {
ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[DatabricksShimVersion]
}

lazy val isCloudera: Boolean = {
ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[ClouderaShimVersion]
}

lazy val isEMR: Boolean = {
ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[EMRShimVersion]
}

def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = {
val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion
val (sparkMajor, sparkMinor, sparkBugfix) = sparkShimVersion match {
case SparkShimVersion(a, b, c) => (a, b, c)
case DatabricksShimVersion(a, b, c) => (a, b, c)
case ClouderaShimVersion(a, b, c, _) => (a, b, c)
case EMRShimVersion(a, b, c) => (a, b, c)
}
val fullVersion = ((major.toLong * 1000) + minor) * 1000 + bugfix
val sparkFullVersion = ((sparkMajor.toLong * 1000) + sparkMinor) * 1000 + sparkBugfix
sparkFullVersion.compareTo(fullVersion)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ package org.apache.spark.sql.rapids

import java.util.concurrent.TimeUnit

import scala.collection.mutable

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuOverrides, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta, ShimLoader}
import com.nvidia.spark.rapids.DateUtils.TimestampFormatConversionException
import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
Expand Down Expand Up @@ -508,20 +510,31 @@ object GpuToTimestamp extends Arm {
val REMOVE_WHITESPACE_FROM_MONTH_DAY: RegexReplace =
RegexReplace(raw"(\A\d+)-([ \t]*)(\d+)-([ \t]*)(\d+)", raw"\1-\3-\5")

def daysScalarSeconds(name: String): Scalar = {
Scalar.timestampFromLong(DType.TIMESTAMP_SECONDS, DateUtils.specialDatesSeconds(name))
}

def daysScalarMicros(name: String): Scalar = {
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, DateUtils.specialDatesMicros(name))
}

def daysEqual(col: ColumnVector, name: String): ColumnVector = {
withResource(Scalar.fromString(name)) { scalarName =>
col.equalTo(scalarName)
}
}

/**
* Replace special date strings such as "now" with timestampDays. This method does not
* close the `stringVector`.
*/
def replaceSpecialDates(
stringVector: ColumnVector,
chronoVector: ColumnVector,
specialDates: Map[String, () => Scalar]): ColumnVector = {
specialDates.foldLeft(chronoVector) { case (buffer, (name, scalarBuilder)) =>
withResource(buffer) { bufVector =>
withResource(daysEqual(stringVector, name)) { isMatch =>
withResource(scalarBuilder()) { scalar =>
isMatch.ifElse(scalar, bufVector)
}
}
}
}
}

def isTimestamp(col: ColumnVector, sparkFormat: String, strfFormat: String) : ColumnVector = {
CORRECTED_COMPATIBLE_FORMATS.get(sparkFormat) match {
case Some(fmt) =>
Expand All @@ -546,50 +559,27 @@ object GpuToTimestamp extends Arm {
lhs: GpuColumnVector,
sparkFormat: String,
strfFormat: String,
dtype: DType,
daysScalar: String => Scalar,
asTimestamp: (ColumnVector, String) => ColumnVector): ColumnVector = {
dtype: DType): ColumnVector = {

// `tsVector` will be closed in replaceSpecialDates
val tsVector = withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTs =>
withResource(Scalar.fromNull(dtype)) { nullValue =>
withResource(lhs.getBase.asTimestamp(dtype, strfFormat)) { tsVec =>
isTs.ifElse(tsVec, nullValue)
}
}
}

// in addition to date/timestamp strings, we also need to check for special dates and null
// values, since anything else is invalid and should throw an error or be converted to null
// depending on the policy
withResource(isTimestamp(lhs.getBase, sparkFormat, strfFormat)) { isTimestamp =>
withResource(daysEqual(lhs.getBase, DateUtils.EPOCH)) { isEpoch =>
withResource(daysEqual(lhs.getBase, DateUtils.NOW)) { isNow =>
withResource(daysEqual(lhs.getBase, DateUtils.TODAY)) { isToday =>
withResource(daysEqual(lhs.getBase, DateUtils.YESTERDAY)) { isYesterday =>
withResource(daysEqual(lhs.getBase, DateUtils.TOMORROW)) { isTomorrow =>
withResource(lhs.getBase.isNull) { _ =>
withResource(Scalar.fromNull(dtype)) { nullValue =>
withResource(asTimestamp(lhs.getBase, strfFormat)) { converted =>
withResource(daysScalar(DateUtils.EPOCH)) { epoch =>
withResource(daysScalar(DateUtils.NOW)) { now =>
withResource(daysScalar(DateUtils.TODAY)) { today =>
withResource(daysScalar(DateUtils.YESTERDAY)) { yesterday =>
withResource(daysScalar(DateUtils.TOMORROW)) { tomorrow =>
withResource(isTomorrow.ifElse(tomorrow, nullValue)) { a =>
withResource(isYesterday.ifElse(yesterday, a)) { b =>
withResource(isToday.ifElse(today, b)) { c =>
withResource(isNow.ifElse(now, c)) { d =>
withResource(isEpoch.ifElse(epoch, d)) { e =>
isTimestamp.ifElse(converted, e)
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
closeOnExcept(tsVector) { tsVector =>
DateUtils.fetchSpecialDates(dtype) match {
case specialDates if specialDates.nonEmpty =>
// `tsVector` will be closed in replaceSpecialDates
replaceSpecialDates(lhs.getBase, tsVector, specialDates)
case _ =>
tsVector
}
}
}
Expand Down Expand Up @@ -728,9 +718,7 @@ abstract class GpuToTimestamp
lhs,
sparkFormat,
strfFormat,
DType.TIMESTAMP_MICROSECONDS,
daysScalarMicros,
(col, strfFormat) => col.asTimestampMicroseconds(strfFormat))
DType.TIMESTAMP_MICROSECONDS)
}
} else { // Timestamp or DateType
lhs.getBase.asTimestampMicroseconds()
Expand Down Expand Up @@ -779,9 +767,7 @@ abstract class GpuToTimestampImproved extends GpuToTimestamp {
lhs,
sparkFormat,
strfFormat,
DType.TIMESTAMP_SECONDS,
daysScalarSeconds,
(col, strfFormat) => col.asTimestampSeconds(strfFormat))
DType.TIMESTAMP_SECONDS)
}
} else if (lhs.dataType() == DateType){
lhs.getBase.asTimestampSeconds()
Expand Down
Loading