diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index b832cd4416a9..65fec22219eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -80,8 +80,12 @@ object CollationTypeCasts extends TypeCoercionRule { case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | - _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask | _: StringReplace | - _: StringTranslate | _: StringTrim | _: StringTrimLeft | _: StringTrimRight) => + _: Coalesce | _: ArrayContains | _: ArrayExcept | _: ConcatWs | _: Mask | _: StringReplace | + _: StringTranslate | _: StringTrim | _: StringTrimLeft | _: StringTrimRight | + _: ArrayIntersect | _: ArrayPosition | _: ArrayRemove | _: ArrayUnion | _: ArraysOverlap | + _: Contains | _: EndsWith | _: EqualNullSafe | _: EqualTo | _: FindInSet | _: GreaterThan | + _: GreaterThanOrEqual | _: LessThan | _: LessThanOrEqual | _: StartsWith | _: StringInstr | + _: ToNumber | _: TryToNumber) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index fd2e302deb99..673f9397bb03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -212,7 +212,7 @@ object TimeWindow { * that we can use `window` in SQL. 
*/ def parseExpression(expr: Expression): Long = expr match { - case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) + case NonNullLiteral(s, _: StringType) => getIntervalInMicroSeconds(s.toString) case IntegerLiteral(i) => i.toLong case NonNullLiteral(l, LongType) => l.toString.toLong case _ => throw QueryCompilationErrors.invalidLiteralForWindowDurationError() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 167c02c0bafc..1bfa11d67af6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -196,7 +196,7 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) private val defaultElementType: DataType = { if (useStringTypeWhenEmpty) { - StringType + SQLConf.get.defaultStringType } else { NullType } @@ -354,7 +354,7 @@ case class MapFromArrays(left: Expression, right: Expression) case object NamePlaceholder extends LeafExpression with Unevaluable { override lazy val resolved: Boolean = false override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "NamePlaceholder" override def toString: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala index 026272a0f2d8..fd942ba60de4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeExpression.scala @@ -30,7 +30,18 @@ private[sql] abstract class DataTypeExpression(val dataType: DataType) { 
} private[sql] case object BooleanTypeExpression extends DataTypeExpression(BooleanType) -private[sql] case object StringTypeExpression extends DataTypeExpression(StringType) +private[sql] case object StringTypeExpression { + /** + * Enables matching against expressions whose data type is any StringType instance: + * {{{ + * case Cast(child @ StringTypeExpression(), NumericType) => + * ... + * }}} + */ + def unapply(e: Expression): Boolean = { + e.dataType.isInstanceOf[StringType] + } +} private[sql] case object TimestampTypeExpression extends DataTypeExpression(TimestampType) private[sql] case object DateTypeExpression extends DataTypeExpression(DateType) private[sql] case object ByteTypeExpression extends DataTypeExpression(ByteType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index d34fd554e7bd..0bf3cc5689e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -1871,6 +1871,84 @@ class CollationSQLExpressionsSuite }) } + test("DateAdd expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = s"""select date_add(collate('2016-07-30', '${collationName}'), 1)""" + // Result & data type check + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2016-07-31" + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + }) + } + + test("DateSub expression with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + val query = s"""select date_sub(collate('2016-07-30', '${collationName}'), 1)""" + // Result & data type check + val testQuery = sql(query) + val dataType = DateType + val expectedResult = "2016-07-29" + 
assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, Row(Date.valueOf(expectedResult))) + }) + } + + test("WindowTime and TimeWindow expressions with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val query = + s"""SELECT window_time(window) + | FROM (SELECT a, window, count(*) as cnt FROM VALUES + |('A1', '2021-01-01 00:00:00'), + |('A1', '2021-01-01 00:04:30'), + |('A1', '2021-01-01 00:06:00'), + |('A2', '2021-01-01 00:01:00') AS tab(a, b) + |GROUP by a, window(b, '5 minutes') ORDER BY a, window.start); + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = TimestampType + val expectedResults = + Seq("2021-01-01 00:04:59.999999", + "2021-01-01 00:09:59.999999", + "2021-01-01 00:04:59.999999") + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, expectedResults.map(ts => Row(Timestamp.valueOf(ts)))) + } + }) + } + + test("SessionWindow expressions with collation") { + // Supported collations + testSuppCollations.foreach(collationName => { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val query = + s"""SELECT count(*) as cnt + | FROM VALUES + |('A1', '2021-01-01 00:00:00'), + |('A1', '2021-01-01 00:04:30'), + |('A1', '2021-01-01 00:10:00'), + |('A2', '2021-01-01 00:01:00'), + |('A2', '2021-01-01 00:04:30') AS tab(a, b) + |GROUP BY a, + |session_window(b, CASE WHEN a = 'A1' THEN '5 minutes' ELSE '1 minutes' END) + |ORDER BY a, session_window.start; + |""".stripMargin + // Result & data type check + val testQuery = sql(query) + val dataType = LongType + val expectedResults = Seq(2, 1, 1, 1) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + checkAnswer(testQuery, expectedResults.map(Row(_))) + } + }) + } + test("ConvertTimezone expression with collation") { // Supported collations 
testSuppCollations.foreach(collationName => {