From a5fc2c1028061a0286ab69f086757081fc299435 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Tue, 9 Mar 2021 20:44:52 +0300 Subject: [PATCH 1/7] Negate intervals --- .../spark/sql/catalyst/expressions/arithmetic.scala | 7 +++++++ .../org/apache/spark/sql/types/AbstractDataType.scala | 6 +++++- .../catalyst/expressions/ArithmeticExpressionSuite.scala | 4 ++++ .../org/apache/spark/sql/types/DataTypeTestUtils.scala | 5 ++++- .../org/apache/spark/sql/ColumnExpressionSuite.scala | 9 +++++++++ 5 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2ee68e62abd5..2c111eb8dca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -83,12 +83,19 @@ case class UnaryMinus( val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") val method = if (failOnError) "negateExact" else "negate" defineCodeGen(ctx, ev, c => s"$iu.$method($c)") + case DayTimeIntervalType | YearMonthIntervalType => + nullSafeCodeGen(ctx, ev, eval => { + val mathClass = classOf[Math].getName + s"${ev.value} = $mathClass.negateExact($eval);" + }) } protected override def nullSafeEval(input: Any): Any = dataType match { case CalendarIntervalType if failOnError => IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval]) case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval]) + case DayTimeIntervalType => Math.negateExact(input.asInstanceOf[Long]) + case YearMonthIntervalType => Math.negateExact(input.asInstanceOf[Int]) case _ => numeric.negate(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 21ac32adca6e..02c95b286a21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -82,7 +82,11 @@ private[sql] object TypeCollection { * Types that include numeric types and interval type. They are only used in unary_minus, * unary_positive, add and subtract operations. */ - val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType) + val NumericAndInterval = TypeCollection( + NumericType, + CalendarIntervalType, + DayTimeIntervalType, + YearMonthIntervalType) def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 14dd04afebe2..a91b5e3f889b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import java.time.Period import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow @@ -102,6 +103,9 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt) checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong) checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) + checkExceptionInExpression[ArithmeticException]( + UnaryMinus(Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType)), + "overflow") Seq("true", "false").foreach { failOnError => withSQLConf(SQLConf.ANSI_ENABLED.key -> failOnError) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 07552a510b90..769de3352889 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -52,7 +52,10 @@ object DataTypeTestUtils { /** * Instances of all [[NumericType]]s and [[CalendarIntervalType]] */ - val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType + val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal ++ Set( + CalendarIntervalType, + DayTimeIntervalType, + YearMonthIntervalType) /** * All the types that support ordering diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4f64de4ae875..4a74bfdac9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import java.time.{Duration, Period} import java.util.Locale import org.apache.hadoop.io.{LongWritable, Text} @@ -2375,4 +2376,12 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { assert(e2.getCause.isInstanceOf[RuntimeException]) assert(e2.getCause.getMessage == "hello") } + + test("negate year-month and day-time intervals") { + import testImplicits._ + val df = Seq((Period.ofMonths(10), Duration.ofDays(10))) + .toDF("year-month", "day-time") + val negated = df.select(-$"year-month", -$"day-time") + checkAnswer(negated, Row(Period.ofMonths(-10), Duration.ofDays(-10))) + } } From 5bb4b7c1d9642d921617ddceae0f3eddd873d9b0 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Tue, 9 Mar 2021 21:40:42 +0300 Subject: [PATCH 2/7] Support +/- --- .../sql/catalyst/expressions/arithmetic.scala | 18 ++++++++ .../ArithmeticExpressionSuite.scala | 42 +++++++++++++++++-- .../spark/sql/ColumnExpressionSuite.scala | 14 ++++--- 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 2c111eb8dca8..942f397bc93e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -180,6 +180,11 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { def calendarIntervalMethod: String = sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") + /** Name of the function for this expression on [[DayTimeIntervalType]] and + * [[YearMonthIntervalType]] types. */ + def intervalMethod: String = + sys.error("BinaryArithmetics must override either intervalMethod or genCode") + // Name of the function for the exact version of this expression in [[Math]]. // If the option "spark.sql.ansi.enabled" is enabled and there is corresponding // function in [[Math]], the exact function will be called instead of evaluation with [[symbol]]. @@ -192,6 +197,9 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { case CalendarIntervalType => val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)") + case DayTimeIntervalType | YearMonthIntervalType => + val mathClass = classOf[Math].getName + defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${intervalMethod}($eval1, $eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -264,6 +272,7 @@ case class Add( override def decimalMethod: String = "$plus" override def calendarIntervalMethod: String = if (failOnError) "addExact" else "add" + override def intervalMethod: String = "addExact" private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError) @@ -274,6 +283,10 @@ case class Add( case CalendarIntervalType => IntervalUtils.add( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) + case DayTimeIntervalType => + Math.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long]) + case YearMonthIntervalType => + Math.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int]) case _ => numeric.plus(input1, input2) } @@ -303,6 +316,7 @@ case class Subtract( override def decimalMethod: String = "$minus" override def calendarIntervalMethod: String = if (failOnError) "subtractExact" else "subtract" + override def intervalMethod: String = "subtractExact" private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError) @@ -313,6 +327,10 @@ case class Subtract( case CalendarIntervalType => IntervalUtils.subtract( input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval]) + case DayTimeIntervalType => + Math.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long]) + case YearMonthIntervalType => + Math.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int]) case _ => numeric.minus(input1, input2) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index a91b5e3f889b..9341ab3267f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import java.time.Period +import java.time.{Duration, Period} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow @@ -103,9 +103,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt) checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong) checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) - checkExceptionInExpression[ArithmeticException]( - UnaryMinus(Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType)), - "overflow") Seq("true", "false").foreach { failOnError => withSQLConf(SQLConf.ANSI_ENABLED.key -> failOnError) { @@ -580,4 +577,41 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } } + + test("SPARK-34677: exact add and subtract of day-time and year-month intervals") { + checkExceptionInExpression[ArithmeticException]( + UnaryMinus(Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType)), + "overflow") + Seq(true, false).foreach { failOnError => + checkExceptionInExpression[ArithmeticException]( + Subtract( + Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType), + Literal.create(Period.ofMonths(10), YearMonthIntervalType), + failOnError + ), + "overflow") + checkExceptionInExpression[ArithmeticException]( + Add( + Literal.create(Period.ofMonths(Int.MaxValue), YearMonthIntervalType), + Literal.create(Period.ofMonths(10), YearMonthIntervalType), + failOnError + ), + "overflow") + + checkExceptionInExpression[ArithmeticException]( + Subtract( + Literal.create(Duration.ofDays(-106751991), DayTimeIntervalType), + Literal.create(Duration.ofDays(10), DayTimeIntervalType), + failOnError + ), + "overflow") + checkExceptionInExpression[ArithmeticException]( + Add( + Literal.create(Duration.ofDays(106751991), DayTimeIntervalType), + Literal.create(Duration.ofDays(10), DayTimeIntervalType), + failOnError + ), + "overflow") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4a74bfdac9c4..24c2b732e753 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -2377,11 +2377,15 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { assert(e2.getCause.getMessage == "hello") } - test("negate year-month and day-time intervals") { + test("SPARK-34677: negate/add/subtract year-month and day-time intervals") { import testImplicits._ - val df = Seq((Period.ofMonths(10), Duration.ofDays(10))) - .toDF("year-month", "day-time") - val negated = df.select(-$"year-month", -$"day-time") - checkAnswer(negated, Row(Period.ofMonths(-10), Duration.ofDays(-10))) + val df = Seq((Period.ofMonths(10), Duration.ofDays(10), Period.ofMonths(1), Duration.ofDays(1))) + .toDF("year-month-A", "day-time-A", "year-month-B", "day-time-B") + val negatedDF = df.select(-$"year-month-A", -$"day-time-A") + checkAnswer(negatedDF, Row(Period.ofMonths(-10), Duration.ofDays(-10))) + val sumDF = df.select($"year-month-A" + $"year-month-B", $"day-time-A" + $"day-time-B") + checkAnswer(sumDF, Row(Period.ofMonths(11), Duration.ofDays(11))) + val subDF = df.select($"year-month-A" - $"year-month-B", $"day-time-A" - $"day-time-B") + checkAnswer(subDF, Row(Period.ofMonths(9), Duration.ofDays(9))) } } From 18b1db30ed4681752827b3f72cda8083aa50cf48 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 10 Mar 2021 00:12:55 +0300 Subject: [PATCH 3/7] Regen expected error message --- .../analysis/ExpressionTypeCheckingSuite.scala | 4 ++-- .../sql-tests/results/ansi/literals.sql.out | 18 +++++++++--------- .../sql-tests/results/literals.sql.out | 18 +++++++++--------- .../native/windowFrameCoercion.sql.out | 6 +++--- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 46634c93148b..ee560ea4bea8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -78,9 +78,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(BitwiseXor(Symbol("intField"), Symbol("booleanField"))) assertError(Add(Symbol("booleanField"), Symbol("booleanField")), - "requires (numeric or interval) type") + "requires (numeric or interval or daytimeinterval or yearmonthinterval) type") assertError(Subtract(Symbol("booleanField"), Symbol("booleanField")), - "requires (numeric or interval) type") + "requires (numeric or interval or daytimeinterval or yearmonthinterval) type") assertError(Multiply(Symbol("booleanField"), Symbol("booleanField")), "requires numeric type") assertError(Divide(Symbol("booleanField"), Symbol("booleanField")), "requires (double or decimal) type") diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out index ea74bb7175e9..1c290a0f3d88 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out @@ -436,7 +436,7 @@ select +date '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 +cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 -- !query @@ -445,7 +445,7 @@ select +timestamp '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 +cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 -- !query @@ -462,7 +462,7 @@ select +map(1, 2) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'map(1, 2)' is of map type.; line 1 pos 7 +cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'map(1, 2)' is of map type.; line 1 pos 7 -- !query @@ -471,7 +471,7 @@ select +array(1,2) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'array(1, 2)' is of array type.; line 1 pos 7 +cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'array(1, 2)' is of array type.; line 1 pos 7 -- !query @@ -480,7 +480,7 @@ select +named_struct('a', 1, 'b', 'spark') struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct type.; line 1 pos 7 +cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct type.; line 1 pos 7 -- !query @@ -489,7 +489,7 @@ select +X'1' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'01'' is of binary type.; line 1 pos 7 +cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'01'' is of binary type.; line 1 pos 7 -- !query @@ -498,7 +498,7 @@ select -date '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 +cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 -- !query @@ -507,7 +507,7 @@ select -timestamp '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 +cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 -- !query @@ -516,4 +516,4 @@ select -x'2379ACFe' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7 +cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out index ea74bb7175e9..1c290a0f3d88 100644 --- a/sql/core/src/test/resources/sql-tests/results/literals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -436,7 +436,7 @@ select +date '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 +cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 -- !query @@ -445,7 +445,7 @@ select +timestamp '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 +cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 -- !query @@ -462,7 +462,7 @@ select +map(1, 2) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'map(1, 2)' is of map type.; line 1 pos 7 +cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'map(1, 2)' is of map type.; line 1 pos 7 -- !query @@ -471,7 +471,7 @@ select +array(1,2) struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'array(1, 2)' is of array type.; line 1 pos 7 +cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'array(1, 2)' is of array type.; line 1 pos 7 -- !query @@ -480,7 +480,7 @@ select +named_struct('a', 1, 'b', 'spark') struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct type.; line 1 pos 7 +cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct type.; line 1 pos 7 -- !query @@ -489,7 +489,7 @@ select +X'1' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'01'' is of binary type.; line 1 pos 7 +cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'01'' is of binary type.; line 1 pos 7 -- !query @@ -498,7 +498,7 @@ select -date '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 +cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7 -- !query @@ -507,7 +507,7 @@ select -timestamp '1999-01-01' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 +cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7 -- !query @@ -516,4 +516,4 @@ select -x'2379ACFe' struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7 +cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out index 71ef82d48bdd..1520d807a150 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out @@ -168,7 +168,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWE struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or interval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21 -- !query @@ -177,7 +177,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BET struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BINARY) FOLLOWING' due to data type mismatch: The data type of the upper bound 'binary' does not match the expected data type '(numeric or interval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BINARY) FOLLOWING' due to data type mismatch: The data type of the upper bound 'binary' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21 -- !query @@ -186,7 +186,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETW struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or interval)'.; line 1 pos 21 +cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21 -- !query From 0a202f3490bbb84e19dc041d0c73ef248304498a Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 10 Mar 2021 10:05:04 +0300 Subject: [PATCH 4/7] Reuse exactMathMethod() --- .../spark/sql/catalyst/expressions/arithmetic.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 942f397bc93e..ec0d45e9221c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -180,11 +180,6 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { def calendarIntervalMethod: String = sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") - /** Name of the function for this expression on [[DayTimeIntervalType]] and - * [[YearMonthIntervalType]] types. */ - def intervalMethod: String = - sys.error("BinaryArithmetics must override either intervalMethod or genCode") - // Name of the function for the exact version of this expression in [[Math]]. // If the option "spark.sql.ansi.enabled" is enabled and there is corresponding // function in [[Math]], the exact function will be called instead of evaluation with [[symbol]]. @@ -198,8 +193,11 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$") defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)") case DayTimeIntervalType | YearMonthIntervalType => + assert(exactMathMethod.isDefined, + s"The expression '$nodeName' must override the exactMathMethod() method " + + s"if it is supposed to operate over interval types.") val mathClass = classOf[Math].getName - defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${intervalMethod}($eval1, $eval2)") + defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${exactMathMethod.get}($eval1, $eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -272,7 +270,6 @@ case class Add( override def decimalMethod: String = "$plus" override def calendarIntervalMethod: String = if (failOnError) "addExact" else "add" - override def intervalMethod: String = "addExact" private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError) @@ -316,7 +313,6 @@ case class Subtract( override def decimalMethod: String = "$minus" override def calendarIntervalMethod: String = if (failOnError) "subtractExact" else "subtract" - override def intervalMethod: String = "subtractExact" private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError) From fc4a76db4d3f4f2db20a2111ee31e9f0e3b50427 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Wed, 10 Mar 2021 10:15:14 +0300 Subject: [PATCH 5/7] Remove 's' --- .../org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ec0d45e9221c..59831dae21e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -195,7 +195,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { case DayTimeIntervalType | YearMonthIntervalType => assert(exactMathMethod.isDefined, s"The expression '$nodeName' must override the exactMathMethod() method " + - s"if it is supposed to operate over interval types.") + "if it is supposed to operate over interval types.") val mathClass = classOf[Math].getName defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${exactMathMethod.get}($eval1, $eval2)") // byte and short are casted into int when add, minus, times or divide From 9a086f332f2582754538768831a8fec374aae426 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Thu, 11 Mar 2021 10:04:26 +0300 Subject: [PATCH 6/7] Move UnaryMinus inside of foreach --- .../catalyst/expressions/ArithmeticExpressionSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 9341ab3267f2..ca97418e0d87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -579,10 +579,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } test("SPARK-34677: exact add and subtract of day-time and year-month intervals") { - checkExceptionInExpression[ArithmeticException]( - UnaryMinus(Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType)), - "overflow") Seq(true, false).foreach { failOnError => + checkExceptionInExpression[ArithmeticException]( + UnaryMinus( + Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType), + failOnError), + "overflow") checkExceptionInExpression[ArithmeticException]( Subtract( Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType), From 3994b6976fdf5e97bf47a88d8d176392319e946e Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Thu, 11 Mar 2021 10:05:47 +0300 Subject: [PATCH 7/7] sumDF -> addDF --- .../scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 24c2b732e753..fac510502c0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -2383,8 +2383,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .toDF("year-month-A", "day-time-A", "year-month-B", "day-time-B") val negatedDF = df.select(-$"year-month-A", -$"day-time-A") checkAnswer(negatedDF, Row(Period.ofMonths(-10), Duration.ofDays(-10))) - val sumDF = df.select($"year-month-A" + $"year-month-B", $"day-time-A" + $"day-time-B") - checkAnswer(sumDF, Row(Period.ofMonths(11), Duration.ofDays(11))) + val addDF = df.select($"year-month-A" + $"year-month-B", $"day-time-A" + $"day-time-B") + checkAnswer(addDF, Row(Period.ofMonths(11), Duration.ofDays(11))) val subDF = df.select($"year-month-A" - $"year-month-B", $"day-time-A" - $"day-time-B") checkAnswer(subDF, Row(Period.ofMonths(9), Duration.ofDays(9))) }