diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 7ee522226e3e..d982e1f19da0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -372,6 +372,11 @@ object JsonInferSchema { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => DoubleType + // This branch is only used by `SchemaOfVariant.mergeSchema` because `JsonInferSchema` never + // produces `FloatType`. + case (FloatType, _: DecimalType) | (_: DecimalType, FloatType) => + DoubleType + case (t1: DecimalType, t2: DecimalType) => val scale = math.max(t1.scale, t2.scale) val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index 73abf8074e8c..a758fa84f6fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant import java.time.{LocalDateTime, ZoneId, ZoneOffset} +import scala.collection.mutable import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SparkFunSuite, SparkRuntimeException} @@ -860,4 +861,50 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { StructType.fromDDL("c ARRAY,b MAP,a STRUCT")) check(struct, """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""") } + + test("schema_of_variant - schema merge") { + val nul = Literal(null, StringType) + val boolean = 
Literal.default(BooleanType) + val long = Literal.default(LongType) + val string = Literal.default(StringType) + val double = Literal.default(DoubleType) + val date = Literal.default(DateType) + val timestamp = Literal.default(TimestampType) + val timestampNtz = Literal.default(TimestampNTZType) + val float = Literal.default(FloatType) + val binary = Literal.default(BinaryType) + val decimal = Literal(Decimal("123.456"), DecimalType(6, 3)) + val array1 = Literal(Array(0L)) + val array2 = Literal(Array(0.0)) + val struct1 = Literal.default(StructType.fromDDL("a string")) + val struct2 = Literal.default(StructType.fromDDL("a boolean, b bigint")) + val inputs = Seq(nul, boolean, long, string, double, date, timestamp, timestampNtz, float, + binary, decimal, array1, array2, struct1, struct2) + + val results = mutable.HashMap.empty[(Literal, Literal), String] + for (i <- inputs) { + val inputType = if (i.value == null) "VOID" else i.dataType.sql + results.put((nul, i), inputType) + results.put((i, i), inputType) + } + results.put((long, double), "DOUBLE") + results.put((long, float), "FLOAT") + results.put((long, decimal), "DECIMAL(23,3)") + results.put((double, float), "DOUBLE") + results.put((double, decimal), "DOUBLE") + results.put((date, timestamp), "TIMESTAMP") + results.put((date, timestampNtz), "TIMESTAMP_NTZ") + results.put((timestamp, timestampNtz), "TIMESTAMP") + results.put((float, decimal), "DOUBLE") + results.put((array1, array2), "ARRAY<DOUBLE>") + results.put((struct1, struct2), "STRUCT<a: VARIANT, b: BIGINT>") + + for (i1 <- inputs) { + for (i2 <- inputs) { + val expected = results.getOrElse((i1, i2), results.getOrElse((i2, i1), "VARIANT")) + val array = CreateArray(Seq(Cast(i1, VariantType), Cast(i2, VariantType))) + checkEvaluation(SchemaOfVariant(Cast(array, VariantType)).replacement, s"ARRAY<$expected>") + } + } + } }