Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Negate intervals
  • Loading branch information
MaxGekk committed Mar 9, 2021
commit a5fc2c1028061a0286ab69f086757081fc299435
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,19 @@ case class UnaryMinus(
val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
val method = if (failOnError) "negateExact" else "negate"
defineCodeGen(ctx, ev, c => s"$iu.$method($c)")
case DayTimeIntervalType | YearMonthIntervalType =>
nullSafeCodeGen(ctx, ev, eval => {
val mathClass = classOf[Math].getName
s"${ev.value} = $mathClass.negateExact($eval);"
})
}

protected override def nullSafeEval(input: Any): Any = dataType match {
case CalendarIntervalType if failOnError =>
IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval])
case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
case DayTimeIntervalType => Math.negateExact(input.asInstanceOf[Long])
case YearMonthIntervalType => Math.negateExact(input.asInstanceOf[Int])
case _ => numeric.negate(input)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ private[sql] object TypeCollection {
* Types that include numeric types and interval type. They are only used in unary_minus,
* unary_positive, add and subtract operations.
*/
val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)
val NumericAndInterval = TypeCollection(
NumericType,
CalendarIntervalType,
DayTimeIntervalType,
YearMonthIntervalType)

def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}
import java.time.Period

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -102,6 +103,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt)
checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong)
checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong)
checkExceptionInExpression[ArithmeticException](
UnaryMinus(Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType)),
"overflow")

Seq("true", "false").foreach { failOnError =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> failOnError) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ object DataTypeTestUtils {
/**
* Instances of all [[NumericType]]s and [[CalendarIntervalType]]
*/
val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType
val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal ++ Set(
CalendarIntervalType,
DayTimeIntervalType,
YearMonthIntervalType)

/**
* All the types that support ordering
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql

import java.sql.{Date, Timestamp}
import java.time.{Duration, Period}
import java.util.Locale

import org.apache.hadoop.io.{LongWritable, Text}
Expand Down Expand Up @@ -2375,4 +2376,12 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
assert(e2.getCause.isInstanceOf[RuntimeException])
assert(e2.getCause.getMessage == "hello")
}

test("negate year-month and day-time intervals") {
import testImplicits._
val df = Seq((Period.ofMonths(10), Duration.ofDays(10)))
.toDF("year-month", "day-time")
val negated = df.select(-$"year-month", -$"day-time")
checkAnswer(negated, Row(Period.ofMonths(-10), Duration.ofDays(-10)))
}
}