diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index fc09938a43a8..e1a9191cc5a8 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -52,7 +52,7 @@ private[kafka010] object KafkaWriter extends Logging {
         s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
         s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
     } else {
-      Literal(topic.get, StringType)
+      Literal.create(topic.get, StringType)
     }
   ).dataType match {
     case StringType => // good
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 2bcbb92f1a46..34d252886ffb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -40,9 +40,10 @@ import org.json4s.JsonAST._
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types._
+import org.apache.spark.util.Utils
 
 object Literal {
   val TrueLiteral: Literal = Literal(true, BooleanType)
@@ -196,6 +197,47 @@ object Literal {
     case other => throw new RuntimeException(s"no default for type $dataType")
   }
+
+  private[expressions] def validateLiteralValue(value: Any, dataType: DataType): Unit = {
+    def doValidate(v: Any, dataType: DataType): Boolean = dataType match {
+      case _ if v == null => true
+      case BooleanType => v.isInstanceOf[Boolean]
+      case ByteType => v.isInstanceOf[Byte]
+      case ShortType => v.isInstanceOf[Short]
+      case IntegerType | DateType => v.isInstanceOf[Int]
+      case LongType | TimestampType => v.isInstanceOf[Long]
+      case FloatType => v.isInstanceOf[Float]
+      case DoubleType => v.isInstanceOf[Double]
+      case _: DecimalType => v.isInstanceOf[Decimal]
+      case CalendarIntervalType => v.isInstanceOf[CalendarInterval]
+      case BinaryType => v.isInstanceOf[Array[Byte]]
+      case StringType => v.isInstanceOf[UTF8String]
+      case st: StructType =>
+        v.isInstanceOf[InternalRow] && {
+          val row = v.asInstanceOf[InternalRow]
+          st.fields.map(_.dataType).zipWithIndex.forall {
+            case (dt, i) => doValidate(row.get(i, dt), dt)
+          }
+        }
+      case at: ArrayType =>
+        v.isInstanceOf[ArrayData] && {
+          val ar = v.asInstanceOf[ArrayData]
+          ar.numElements() == 0 || doValidate(ar.get(0, at.elementType), at.elementType)
+        }
+      case mt: MapType =>
+        v.isInstanceOf[MapData] && {
+          val map = v.asInstanceOf[MapData]
+          doValidate(map.keyArray(), ArrayType(mt.keyType)) &&
+            doValidate(map.valueArray(), ArrayType(mt.valueType))
+        }
+      case ObjectType(cls) => cls.isInstance(v)
+      case udt: UserDefinedType[_] => doValidate(v, udt.sqlType)
+      case _ => false
+    }
+    require(doValidate(value, dataType),
+      s"Literal must have a corresponding value to ${dataType.catalogString}, " +
+      s"but class ${Utils.getSimpleName(value.getClass)} found.")
+  }
 }
 
 /**
@@ -240,6 +282,8 @@ object DecimalLiteral {
  */
 case class Literal (value: Any, dataType: DataType) extends LeafExpression {
 
+  Literal.validateLiteralValue(value, dataType)
+
   override def foldable: Boolean = true
   override def nullable: Boolean = value == null
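Note (illustrative, not part of the patch): the new check means the bare `Literal(value, dataType)` constructor only accepts values that are already in Catalyst's internal representation, which is why the Kafka writer above moves to `Literal.create`. A minimal sketch of the difference, assuming a build with this change applied:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// Literal.create converts external Scala values (here a java.lang.String) into
// the internal representation before building the Literal, so this still works.
val viaCreate = Literal.create("topic-a", StringType)

// The constructor now requires the internal form up front: a UTF8String for
// StringType. Passing a plain String would fail the new require() with roughly:
//   requirement failed: Literal must have a corresponding value to string,
//   but class String found.
val viaConstructor = Literal(UTF8String.fromString("topic-a"), StringType)
```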
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 0eba1c537d67..0b168d060ef6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -742,7 +742,7 @@ class TypeCoercionSuite extends AnalysisTest {
     val nullLit = Literal.create(null, NullType)
     val floatNullLit = Literal.create(null, FloatType)
     val floatLit = Literal.create(1.0f, FloatType)
-    val timestampLit = Literal.create("2017-04-12", TimestampType)
+    val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
     val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
     val tsArrayLit = Literal(Array(new Timestamp(System.currentTimeMillis())))
     val strArrayLit = Literal(Array("c"))
@@ -793,11 +793,11 @@ class TypeCoercionSuite extends AnalysisTest {
     ruleTest(TypeCoercion.FunctionArgumentConversion,
       CreateArray(Literal(1.0)
         :: Literal(1)
-        :: Literal.create(1.0, FloatType)
+        :: Literal.create(1.0f, FloatType)
         :: Nil),
       CreateArray(Literal(1.0)
         :: Cast(Literal(1), DoubleType)
-        :: Cast(Literal.create(1.0, FloatType), DoubleType)
+        :: Cast(Literal.create(1.0f, FloatType), DoubleType)
         :: Nil))
 
     ruleTest(TypeCoercion.FunctionArgumentConversion,
@@ -834,23 +834,23 @@ class TypeCoercionSuite extends AnalysisTest {
     ruleTest(TypeCoercion.FunctionArgumentConversion,
       CreateMap(Literal(1)
         :: Literal("a")
-        :: Literal.create(2.0, FloatType)
+        :: Literal.create(2.0f, FloatType)
         :: Literal("b")
         :: Nil),
       CreateMap(Cast(Literal(1), FloatType)
         :: Literal("a")
-        :: Literal.create(2.0, FloatType)
+        :: Literal.create(2.0f, FloatType)
         :: Literal("b")
         :: Nil))
     ruleTest(TypeCoercion.FunctionArgumentConversion,
       CreateMap(Literal.create(null, DecimalType(5, 3))
         :: Literal("a")
-        :: Literal.create(2.0, FloatType)
+        :: Literal.create(2.0f, FloatType)
         :: Literal("b")
         :: Nil),
       CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
         :: Literal("a")
-        :: Literal.create(2.0, FloatType).cast(DoubleType)
+        :: Literal.create(2.0f, FloatType).cast(DoubleType)
         :: Literal("b")
         :: Nil))
     // type coercion for map values
@@ -895,11 +895,11 @@ class TypeCoercionSuite extends AnalysisTest {
     ruleTest(TypeCoercion.FunctionArgumentConversion,
       operator(Literal(1.0)
         :: Literal(1)
-        :: Literal.create(1.0, FloatType)
+        :: Literal.create(1.0f, FloatType)
         :: Nil),
       operator(Literal(1.0)
         :: Cast(Literal(1), DoubleType)
-        :: Cast(Literal.create(1.0, FloatType), DoubleType)
+        :: Cast(Literal.create(1.0f, FloatType), DoubleType)
         :: Nil))
     ruleTest(TypeCoercion.FunctionArgumentConversion,
       operator(Literal(1L)
@@ -966,7 +966,7 @@
     val falseLit = Literal.create(false, BooleanType)
     val stringLit = Literal.create("c", StringType)
     val floatLit = Literal.create(1.0f, FloatType)
-    val timestampLit = Literal.create("2017-04-12", TimestampType)
+    val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
     val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000"))
 
     ruleTest(rule,
@@ -1016,14 +1016,16 @@
       CaseKeyWhen(Literal(true),
         Seq(Literal(1), Literal("a")))
     )
     ruleTest(TypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
       CaseWhen(Seq((Literal(true), Literal(1.2))),
-        Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
+        Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Literal(1.2))),
+        Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DoubleType))
     )
     ruleTest(TypeCoercion.CaseWhenCoercion,
-      CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
+      CaseWhen(Seq((Literal(true), Literal(100L))),
+        Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))),
       CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
-        Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
+        Cast(Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2)), DecimalType(22, 2)))
     )
   }
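Aside (illustrative only, assuming this patch is applied): every TypeCoercionSuite fix above follows the same rule. `Literal.create` converts the external value it is given, but it does not parse strings or widen numbers to the declared type, so the value's class now has to line up with that type:

```scala
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{DecimalType, FloatType, TimestampType}

// Values whose external class matches the declared type convert cleanly:
Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
Literal.create(BigDecimal.valueOf(1), DecimalType(7, 2))
Literal.create(1.0f, FloatType)

// The old forms paired a String with timestamp, an Int with decimal(7,2) and a
// Double with float; the new constructor check rejects all three:
// Literal.create("2017-04-12", TimestampType)
// Literal.create(1, DecimalType(7, 2))
// Literal.create(1.0, FloatType)
```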
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 81ab7d690396..b124bbe17170 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -623,7 +623,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
 
   test("SPARK-21513: to_json support map[string, struct] to json") {
     val schema = MapType(StringType, StructType(StructField("a", IntegerType) :: Nil))
-    val input = Literal.create(ArrayBasedMapData(Map("test" -> InternalRow(1))), schema)
+    val input = Literal(
+      ArrayBasedMapData(Map(UTF8String.fromString("test") -> InternalRow(1))), schema)
     checkEvaluation(
       StructsToJson(Map.empty, input),
       """{"test":{"a":1}}"""
@@ -633,7 +634,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
   test("SPARK-21513: to_json support map[struct, struct] to json") {
     val schema = MapType(StructType(StructField("a", IntegerType) :: Nil),
       StructType(StructField("b", IntegerType) :: Nil))
-    val input = Literal.create(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
+    val input = Literal(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
     checkEvaluation(
       StructsToJson(Map.empty, input),
       """{"[1]":{"b":2}}"""
@@ -642,7 +643,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
 
   test("SPARK-21513: to_json support map[string, integer] to json") {
     val schema = MapType(StringType, IntegerType)
-    val input = Literal.create(ArrayBasedMapData(Map("a" -> 1)), schema)
+    val input = Literal(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)), schema)
     checkEvaluation(
       StructsToJson(Map.empty, input),
       """{"a":1}"""
@@ -651,17 +652,18 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
 
   test("to_json - array with maps") {
     val inputSchema = ArrayType(MapType(StringType, IntegerType))
-    val input = new GenericArrayData(ArrayBasedMapData(
-      Map("a" -> 1)) :: ArrayBasedMapData(Map("b" -> 2)) :: Nil)
+    val input = new GenericArrayData(
+      ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) ::
+        ArrayBasedMapData(Map(UTF8String.fromString("b") -> 2)) :: Nil)
     val output = """[{"a":1},{"b":2}]"""
     checkEvaluation(
-      StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
+      StructsToJson(Map.empty, Literal(input, inputSchema), gmtId),
       output)
   }
 
   test("to_json - array with single map") {
     val inputSchema = ArrayType(MapType(StringType, IntegerType))
-    val input = new GenericArrayData(ArrayBasedMapData(Map("a" -> 1)) :: Nil)
+    val input = new GenericArrayData(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)) :: Nil)
     val output = """[{"a":1}]"""
     checkEvaluation(
       StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
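For context (illustrative only, with this patch): `ArrayBasedMapData` is an internal value, so any keys and values it carries must be internal too, e.g. `UTF8String` rather than `java.lang.String` for string keys. That is why the JSON tests now build their map inputs this way:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

val schema = MapType(StringType, IntegerType)

// Internal map data with internal keys, paired with the plain constructor:
val ok = Literal(ArrayBasedMapData(Map(UTF8String.fromString("a") -> 1)), schema)

// The old form built internal MapData out of java.lang.String keys and wrapped
// it with Literal.create; the new validation walks into the map and rejects it:
// Literal.create(ArrayBasedMapData(Map("a" -> 1)), schema)
```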
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index 6e07f7a59b73..8818d0135b29 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -107,8 +109,8 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val nullLit = Literal.create(null, NullType)
     val floatNullLit = Literal.create(null, FloatType)
     val floatLit = Literal.create(1.01f, FloatType)
-    val timestampLit = Literal.create("2017-04-12", TimestampType)
-    val decimalLit = Literal.create(10.2, DecimalType(20, 2))
+    val timestampLit = Literal.create(Timestamp.valueOf("2017-04-12 00:00:00"), TimestampType)
+    val decimalLit = Literal.create(BigDecimal.valueOf(10.2), DecimalType(20, 2))
 
     assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType)
     assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala
index cc2e2a993d62..f2696849d775 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SortOrderExpressionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
 import java.util.TimeZone
 
 import org.apache.spark.SparkFunSuite
@@ -32,9 +32,9 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val b2 = Literal.create(true, BooleanType)
     val i1 = Literal.create(20132983, IntegerType)
     val i2 = Literal.create(-20132983, IntegerType)
-    val l1 = Literal.create(20132983, LongType)
-    val l2 = Literal.create(-20132983, LongType)
-    val millis = 1524954911000L;
+    val l1 = Literal.create(20132983L, LongType)
+    val l2 = Literal.create(-20132983L, LongType)
+    val millis = 1524954911000L
     // Explicitly choose a time zone, since Date objects can create different values depending on
     // local time zone of the machine on which the test is running
     val oldDefaultTZ = TimeZone.getDefault
@@ -57,7 +57,7 @@ class SortOrderExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     val dec1 = Literal(Decimal(20132983L, 10, 2))
     val dec2 = Literal(Decimal(20132983L, 19, 2))
     val dec3 = Literal(Decimal(20132983L, 21, 2))
-    val list1 = Literal(List(1, 2), ArrayType(IntegerType))
+    val list1 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
     val nullVal = Literal.create(null, IntegerType)
 
     checkEvaluation(SortPrefix(SortOrder(b1, Ascending)), 0L)
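Side note (illustrative only, with this patch): the Long and array fixes above are the same story. The value must really be of the declared type, and Scala collections belong with `Literal.create`, which converts them to internal `ArrayData` first:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, IntegerType, LongType}

// An Int paired with LongType no longer slips through; write the value as a Long.
// Literal.create(20132983, LongType)   // would now fail: Int value vs. bigint type
Literal.create(20132983L, LongType)

// The plain constructor expects internal ArrayData for array types, so a Scala
// collection goes through Literal.create, which converts it first.
// Literal(List(1, 2), ArrayType(IntegerType))   // would now fail: List vs. array<int>
Literal.create(Seq(1, 2), ArrayType(IntegerType))
```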
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
index d46135c02bc0..d202c2f271d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
@@ -105,9 +105,9 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva
   }
 
   test("parse sql expression for duration in microseconds - long") {
-    val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType)))
+    val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2L << 52, LongType)))
     assert(dur.isInstanceOf[Long])
-    assert(dur === (2 << 52))
+    assert(dur === (2L << 52))
   }
 
   test("parse sql expression for duration in microseconds - invalid interval") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index 2420ba513f28..294fce8e9a10 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -232,11 +232,14 @@ class PercentileSuite extends SparkFunSuite {
       BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType)
 
     invalidDataTypes.foreach { dataType =>
-      val percentage = Literal(0.5, dataType)
+      val percentage = Literal.default(dataType)
       val percentile4 = new Percentile(child, percentage)
-      assertEqual(percentile4.checkInputDataTypes(),
-        TypeCheckFailure(s"argument 2 requires double type, however, " +
-          s"'0.5' is of ${dataType.simpleString} type."))
+      val checkResult = percentile4.checkInputDataTypes()
+      assert(checkResult.isFailure)
+      Seq("argument 2 requires double type, however, ",
+        s"is of ${dataType.simpleString} type.").foreach { errMsg =>
+        assert(checkResult.asInstanceOf[TypeCheckFailure].message.contains(errMsg))
+      }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 93447a52097c..04a944242435 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -210,13 +210,13 @@ case class AnalyzeColumnCommand(
     def struct(exprs: Expression*): CreateNamedStruct = CreateStruct(exprs.map { expr =>
       expr.transformUp { case af: AggregateFunction => af.toAggregateExpression() }
     })
-    val one = Literal(1, LongType)
+    val one = Literal(1L, LongType)
 
     // the approximate ndv (num distinct value) should never be larger than the number of rows
     val numNonNulls = if (col.nullable) Count(col) else Count(one)
     val ndv = Least(Seq(HyperLogLogPlusPlus(col, conf.ndvMaxError), numNonNulls))
     val numNulls = Subtract(Count(one), numNonNulls)
-    val defaultSize = Literal(col.dataType.defaultSize, LongType)
+    val defaultSize = Literal(col.dataType.defaultSize.toLong, LongType)
     val nullArray = Literal(null, ArrayType(LongType))
 
     def fixedLenTypeStruct: CreateNamedStruct = {
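As above, an illustration rather than part of the patch: the TimeWindowSuite fix is not only about the new check, since `2 << 52` is Int arithmetic where the shift distance is taken modulo 32, so the old expression was never the value the test name promises. The Percentile and AnalyzeColumn changes stop pairing Int or Double values with mismatched types, which is exactly what the constructor now forbids:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.LongType

// Int shifts mask the shift distance to 5 bits, so 2 << 52 == 2 << 20:
val wrong = 2 << 52    // 2097152, an Int
val right = 2L << 52   // 9007199254740992L, the intended Long

// Likewise, Literal(1, LongType) paired an Int with bigint; the value now has
// to be written as a Long up front.
// Literal(1, LongType)   // would now fail validation
Literal(1L, LongType)
```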
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index b4ad1db20a9e..42dd0024b258 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -228,7 +228,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
   test("join key rewritten") {
     val l = Literal(1L)
     val i = Literal(2)
-    val s = Literal.create(3, ShortType)
+    val s = Literal.create(3.toShort, ShortType)
     val ss = Literal("hello")
 
     assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)
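Finally, the same pattern as the other call sites (illustrative only): the runtime class of the value has to match the declared type, so the Int literal is narrowed to a Short explicitly:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.ShortType

// Literal.create(3, ShortType)   // would now fail: Int value vs. smallint type
Literal.create(3.toShort, ShortType)
```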