Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,11 @@ object JsonInferSchema {
case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
DoubleType

// This branch is only used by `SchemaOfVariant.mergeSchema` because `JsonInferSchema` never
// produces `FloatType`.
case (FloatType, _: DecimalType) | (_: DecimalType, FloatType) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a standard for this? These two are not compatible, because float is approximate but decimal is not.

Anyway, I'm fine with this as we already did it for double type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not aware of an existing standard, but I think we can follow the type resolution rules in the add operator:
image

Based on this result, I think it may be better to change the result of float x decimal into double.

Changing decimal into float/double may indeed lose precision, but I think it is a reasonable approach for schema inference.

DoubleType

case (t1: DecimalType, t2: DecimalType) =>
val scale = math.max(t1.scale, t2.scale)
val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant

import java.time.{LocalDateTime, ZoneId, ZoneOffset}

import scala.collection.mutable
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
Expand Down Expand Up @@ -860,4 +861,50 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,a STRUCT<i: INT>"))
check(struct, """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""")
}

// Checks the element type that `schema_of_variant` reports for a two-element variant array,
// for every ordered pair of the sample literals below. Pairs without an explicit expectation
// (in either order) are expected to merge into "VARIANT".
test("schema_of_variant - schema merge") {
  val nullLit = Literal(null, StringType)
  val boolLit = Literal.default(BooleanType)
  val longLit = Literal.default(LongType)
  val stringLit = Literal.default(StringType)
  val doubleLit = Literal.default(DoubleType)
  val dateLit = Literal.default(DateType)
  val tsLit = Literal.default(TimestampType)
  val tsNtzLit = Literal.default(TimestampNTZType)
  val floatLit = Literal.default(FloatType)
  val binaryLit = Literal.default(BinaryType)
  val decimalLit = Literal(Decimal("123.456"), DecimalType(6, 3))
  val longArrayLit = Literal(Array(0L))
  val doubleArrayLit = Literal(Array(0.0))
  val structLitA = Literal.default(StructType.fromDDL("a string"))
  val structLitB = Literal.default(StructType.fromDDL("a boolean, b bigint"))
  val samples = Seq(
    nullLit, boolLit, longLit, stringLit, doubleLit, dateLit, tsLit, tsNtzLit, floatLit,
    binaryLit, decimalLit, longArrayLit, doubleArrayLit, structLitA, structLitB)

  // Merging a sample with null, or with itself, is expected to keep the sample's own type
  // (reported as "VOID" when the sample's value is null).
  val selfAndNullMerges = samples.flatMap { sample =>
    val typeName = if (sample.value == null) "VOID" else sample.dataType.sql
    Seq((nullLit, sample) -> typeName, (sample, sample) -> typeName)
  }

  // Pairs of distinct types that are expected to merge into something other than "VARIANT".
  val crossTypeMerges = Seq(
    (longLit, doubleLit) -> "DOUBLE",
    (longLit, floatLit) -> "FLOAT",
    (longLit, decimalLit) -> "DECIMAL(23,3)",
    (doubleLit, floatLit) -> "DOUBLE",
    (doubleLit, decimalLit) -> "DOUBLE",
    (dateLit, tsLit) -> "TIMESTAMP",
    (dateLit, tsNtzLit) -> "TIMESTAMP_NTZ",
    (tsLit, tsNtzLit) -> "TIMESTAMP",
    (floatLit, decimalLit) -> "DOUBLE",
    (longArrayLit, doubleArrayLit) -> "ARRAY<DOUBLE>",
    (structLitA, structLitB) -> "STRUCT<a: VARIANT, b: BIGINT>")

  val expectedByPair = mutable.HashMap.empty[(Literal, Literal), String]
  expectedByPair ++= selfAndNullMerges
  expectedByPair ++= crossTypeMerges

  for (left <- samples; right <- samples) {
    // Expectations are stored for one ordering only, so try the pair both ways before
    // falling back to "VARIANT".
    val expected = expectedByPair.get((left, right))
      .orElse(expectedByPair.get((right, left)))
      .getOrElse("VARIANT")
    val pairArray = CreateArray(Seq(Cast(left, VariantType), Cast(right, VariantType)))
    checkEvaluation(
      SchemaOfVariant(Cast(pairArray, VariantType)).replacement,
      s"ARRAY<$expected>")
  }
}
}