80 commits
40ffdef
[SPARK-50250][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 13, 2024
ede05fa
[SPARK-50248][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 13, 2024
6fb1d43
[SPARK-50246][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 13, 2024
898bff2
[SPARK-50245][SQL][TESTS] Extended CollationSuite and added tests whe…
vladanvasi-db Nov 13, 2024
bd94419
[SPARK-50226][SQL] Correct MakeDTInterval and MakeYMInterval to catch…
gotocoding-DB Nov 13, 2024
bc9b259
[SPARK-50066][SQL] Codegen Support for `SchemaOfXml` (by Invoke & Run…
panbingkun Nov 13, 2024
558fc89
[SPARK-49611][SQL][FOLLOW-UP] Make collations TVF consistent and retu…
mihailomilosevic2001 Nov 13, 2024
7b1b450
Revert [SPARK-50215][SQL] Refactored StringType pattern matching in j…
vladanvasi-db Nov 13, 2024
87ad4b4
[SPARK-50139][INFRA][SS][PYTHON] Introduce scripts to re-generate and…
LuciferYang Nov 13, 2024
05508cf
[SPARK-42838][SQL] Assign a name to the error class _LEGACY_ERROR_TEM…
mihailomilosevic2001 Nov 13, 2024
5cc60f4
[SPARK-50300][BUILD] Use mirror host instead of `archive.apache.org`
dongjoon-hyun Nov 13, 2024
33378a6
[SPARK-50304][INFRA] Remove `(any|empty).proto` from RAT exclusion
dongjoon-hyun Nov 14, 2024
891f694
[SPARK-50306][PYTHON][CONNECT] Support Python 3.13 in Spark Connect
HyukjinKwon Nov 14, 2024
2fd4702
[SPARK-49913][SQL] Add check for unique label names in nested labeled…
miland-db Nov 14, 2024
6bee268
[SPARK-50299][BUILD] Upgrade jupiter-interface to 0.13.1 and Junit5 t…
LuciferYang Nov 14, 2024
09d6b32
[SPARK-48755][DOCS][PYTHON][FOLLOWUP] Add PySpark doc for `transformW…
itholic Nov 14, 2024
0b1b676
[SPARK-50092][SQL] Fix PostgreSQL connector behaviour for multidimens…
PetarVasiljevic-DB Nov 14, 2024
aea9e87
[SPARK-50291][PYTHON] Standardize verifySchema parameter of createDat…
xinrong-meng Nov 14, 2024
c1968a1
[SPARK-50216][SQL][TESTS] Update `CollationBenchmark` to invoke `coll…
stevomitric Nov 14, 2024
0aee601
[SPARK-50153][SQL] Add `name` to `RuleExecutor` to make printing `Que…
panbingkun Nov 14, 2024
c2343f7
[SPARK-45265][SQL] Support Hive 4.0 metastore
yaooqinn Nov 14, 2024
e0a83f6
[SPARK-50317][BUILD] Upgrade ORC to 2.0.3
dongjoon-hyun Nov 14, 2024
c90efae
[SPARK-50318][SQL] Add IntervalUtils.makeYearMonthInterval to dedupli…
gotocoding-DB Nov 15, 2024
3237885
[SPARK-50312][SQL] SparkThriftServer createServer parameter passing e…
CuiYanxiang Nov 15, 2024
e615e3f
[SPARK-50049][SQL] Support custom driver metrics in writing to v2 table
cloud-fan Nov 15, 2024
3f5e846
[SPARK-50237][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 15, 2024
cf90271
[MINOR] Fix code style for if/for/while statements
exmy Nov 15, 2024
cc81ed0
[SPARK-50325][SQL] Factor out alias resolution to be reused in the si…
vladimirg-db Nov 15, 2024
d317002
[SPARK-50322][SQL] Fix parameterized identifier in a sub-query
MaxGekk Nov 15, 2024
77e006f
[SPARK-50327][SQL] Factor out function resolution to be reused in the…
vladimirg-db Nov 15, 2024
11e4706
[SPARK-50320][CORE] Make `--remote` an official option by removing `e…
dongjoon-hyun Nov 15, 2024
007c31d
[SPARK-50236][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 15, 2024
281a8e1
[SPARK-50309][DOCS] Document `SQL Pipe` Syntax
dtenedor Nov 15, 2024
b626528
[SPARK-50313][SQL][TESTS] Enable ANSI in SQL *SQLQueryTestSuite by de…
yaooqinn Nov 18, 2024
a01856d
[SPARK-50330][SQL] Add hints to Sort and Window nodes
agubichev Nov 18, 2024
8b2d032
[SPARK-45265][SQL][BUILD][FOLLOWUP] Add `-Xss64m` for Maven testing o…
LuciferYang Nov 18, 2024
05750de
[MINOR][PYTHON][DOCS] Fix the type hint of `histogram_numeric`
zhengruifeng Nov 18, 2024
400a8d3
Revert "[SPARK-49787][SQL] Cast between UDT and other types"
cloud-fan Nov 18, 2024
fa36e8b
[SPARK-50335][PYTHON][DOCS] Refine docstrings for window/aggregation …
zhengruifeng Nov 19, 2024
b61411d
[SPARK-50328][INFRA] Add a separate docker file for SparkR
zhengruifeng Nov 19, 2024
e1477a3
[SPARK-50298][PYTHON][CONNECT] Implement verifySchema parameter of cr…
xinrong-meng Nov 19, 2024
6d47981
[SPARK-50331][INFRA] Add a daily test for PySpark on MacOS
LuciferYang Nov 19, 2024
5a57efd
[SPARK-50313][SQL][TESTS][FOLLOWUP] Restore some tests in *SQLQueryTe…
yaooqinn Nov 19, 2024
b74aa8c
[SPARK-50340][SQL] Unwrap UDT in INSERT input query
cloud-fan Nov 19, 2024
87a5b37
[SPARK-50313][SQL][TESTS][FOLLOWUP] Regenerate golden files for Java 21
LuciferYang Nov 19, 2024
f1b68d8
[SPARK-50315][SQL] Support custom metrics for V1Fallback writes
olaky Nov 19, 2024
19509d0
Revert "[SPARK-49002][SQL] Consistently handle invalid locations in W…
cloud-fan Nov 19, 2024
37497e6
[SPARK-50335][PYTHON][DOCS][FOLLOW-UP] Make percentile doctests more …
zhengruifeng Nov 20, 2024
c149dcb
[SPARK-50352][PYTHON][DOCS] Refine docstrings for window/aggregation …
zhengruifeng Nov 20, 2024
8791767
[SPARK-48344][SQL] Prepare SQL Scripting for addition of Execution Fr…
miland-db Nov 20, 2024
b7cf448
[SPARK-49550][FOLLOWUP][SQL][DOC] Switch Hadoop to 3.4.1 in IsolatedC…
pan3793 Nov 20, 2024
2185f3c
[SPARK-50359][PYTHON] Upgrade PyArrow to 18.0
zhengruifeng Nov 20, 2024
0157778
[SPARK-50358][SQL][TESTS] Update postgres docker image to 17.1
panbingkun Nov 20, 2024
b582dac
[MINOR][DOCS] Fix a HTML/Markdown syntax error in sql-migration-guide.md
yaooqinn Nov 20, 2024
19b8250
[SPARK-50331][INFRA][FOLLOW-UP] Skip Torch/DeepSpeed tests in MacOS P…
zhengruifeng Nov 20, 2024
7a4f3c4
[SPARK-50345][BUILD] Upgrade Kafka to 3.9.0
panbingkun Nov 20, 2024
3151d97
[SPARK-49801][INFRA][FOLLOWUP] Sync pandas version in release environ…
yaooqinn Nov 20, 2024
23f276f
[SPARK-50353][SQL] Refactor ResolveSQLOnFile
mihailoale-db Nov 20, 2024
533b8ca
[SPARK-50363][PYTHON][DOCS] Refine the docstring for datetime functio…
zhengruifeng Nov 20, 2024
81a56df
[SPARK-50362][PYTHON][ML] Skip `CrossValidatorTests` if `torch/torche…
LuciferYang Nov 20, 2024
6ee53da
[SPARK-50258][SQL] Fix output column order changed issue after AQE op…
wangyum Nov 20, 2024
30d0b01
[SPARK-50364][SQL] Implement serialization for LocalDateTime type in …
krm95 Nov 20, 2024
ad46db4
[SPARK-50130][SQL][FOLLOWUP] Make Encoder generation lazy
ueshin Nov 20, 2024
a409199
[SPARK-50376][PYTHON][ML][TESTS] Centralize the dependency check in M…
zhengruifeng Nov 21, 2024
3bc374d
[SPARK-50333][SQL] Codegen Support for `CsvToStructs` (by Invoke & Ru…
panbingkun Nov 21, 2024
95faa02
[SPARK-49490][SQL] Add benchmarks for initCap
mrk-andreev Nov 21, 2024
ee21e6b
[SPARK-50113][CONNECT][PYTHON][TESTS] Add `@remote_only` to check the…
itholic Nov 21, 2024
0f1e410
[SPARK-50016][SQL] Assign appropriate error condition for `_LEGACY_ER…
itholic Nov 21, 2024
b05ef45
[SPARK-50175][SQL] Change collation precedence calculation
stefankandic Nov 21, 2024
fbf255e
[SPARK-50379][SQL] Fix DayTimeIntevalType handling in WindowExecBase
mihailomilosevic2001 Nov 21, 2024
cbb16b9
[MINOR][DOCS] Fix miss semicolon on create table example sql
camilesing Nov 21, 2024
f2de888
[MINOR][DOCS] Remove wrong and ambiguous default statement in datetim…
yaooqinn Nov 21, 2024
229b1b8
[SPARK-50375][BUILD] Upgrade `commons-io` to 2.18.0
panbingkun Nov 21, 2024
136c722
[SPARK-50334][SQL] Extract common logic for reading the descriptor of…
panbingkun Nov 21, 2024
2e1c3dc
[SPARK-50087] Robust handling of boolean expressions in CASE WHEN for…
cloud-fan Nov 21, 2024
2d09ef2
[SPARK-50381][CORE] Support `spark.master.rest.maxThreads`
dongjoon-hyun Nov 21, 2024
69324bd
Merge branch 'master' into pr48820
ueshin Nov 21, 2024
349df78
Fix.
ueshin Nov 21, 2024
1079339
Fix.
ueshin Nov 21, 2024
c6b0651
Fix.
ueshin Nov 22, 2024
Revert "[SPARK-49787][SQL] Cast between UDT and other types"
This reverts commit b6681fb.
cloud-fan committed Nov 18, 2024
commit 400a8d3797bdcc9183576e66e84163e4dc00a662
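For context: this revert restores the pre-SPARK-49787 behavior, in which casting between a UDT and an unrelated type fails analysis (see the updated Python test below). A minimal sketch of the restored behavior, assuming a running SparkSession `spark` and the ML `Vector` type (both assumptions, not part of this PR):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, DoubleType}

// A column backed by a UserDefinedType (the ML vector UDT).
val df = spark.createDataFrame(Seq(Tuple1(Vectors.dense(1.0, 2.0)))).toDF("vec")

// While SPARK-49787 was in, this cast resolved via the UDT's underlying sqlType;
// after the revert it is rejected at analysis time again.
try df.select(col("vec").cast(ArrayType(DoubleType))).collect()
catch { case e: AnalysisException => println(s"cast rejected: ${e.getMessage}") }
```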
16 changes: 5 additions & 11 deletions python/pyspark/sql/tests/test_types.py
@@ -28,6 +28,7 @@
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.errors import (
AnalysisException,
ParseException,
PySparkTypeError,
PySparkValueError,
@@ -1129,17 +1130,10 @@ def test_cast_to_string_with_udt(self):
def test_cast_to_udt_with_udt(self):
row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
result = df.select(F.col("point").cast(PythonOnlyUDT())).collect()
self.assertEqual(
result,
[Row(point=PythonOnlyPoint(1.0, 2.0))],
)

result = df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect()
self.assertEqual(
result,
[Row(python_only_point=ExamplePoint(1.0, 2.0))],
)
with self.assertRaises(AnalysisException):
df.select(F.col("point").cast(PythonOnlyUDT())).collect()
with self.assertRaises(AnalysisException):
df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect()

def test_struct_type(self):
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
@@ -66,10 +66,6 @@ private[sql] object UpCastRule {

case (from: UserDefinedType[_], to: UserDefinedType[_]) if to.acceptsType(from) => true

case (udt: UserDefinedType[_], toType) => canUpCast(udt.sqlType, toType)

case (fromType, udt: UserDefinedType[_]) => canUpCast(fromType, udt.sqlType)

case _ => false
}
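The two deleted cases here were the heart of SPARK-49787 for upcasts: they made a UDT transparent by delegating to its underlying sqlType. An abridged sketch of the removed delegation (non-UDT cases elided; this is not the full rule):

```scala
import org.apache.spark.sql.types._

// Abridged sketch of what the revert removes: a UDT was upcastable
// whenever its underlying sqlType was.
def canUpCastSketch(from: DataType, to: DataType): Boolean = (from, to) match {
  case (f: UserDefinedType[_], t: UserDefinedType[_]) if t.acceptsType(f) => true // kept
  case (udt: UserDefinedType[_], t) => canUpCastSketch(udt.sqlType, t)            // removed
  case (f, udt: UserDefinedType[_]) => canUpCastSketch(f, udt.sqlType)            // removed
  case _ => false
}
```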

@@ -150,10 +150,6 @@ object Cast extends QueryErrorsBase {

case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true

case (udt: UserDefinedType[_], toType) => canAnsiCast(udt.sqlType, toType)

case (fromType, udt: UserDefinedType[_]) => canAnsiCast(fromType, udt.sqlType)

case _ => false
}

@@ -271,10 +267,6 @@ object Cast extends QueryErrorsBase {

case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true

case (udt: UserDefinedType[_], toType) => canCast(udt.sqlType, toType)

case (fromType, udt: UserDefinedType[_]) => canCast(fromType, udt.sqlType)

case _ => false
}
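The same delegating pair is removed from both `canAnsiCast` (above) and `canCast` (here), keeping the two checkers in sync. The effect on type checking, sketched with a hypothetical `PointUDT` (illustration only, not part of this PR) whose sqlType is `ArrayType(DoubleType)`:

```scala
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

// Minimal hypothetical UDT, for illustration only.
class PointUDT extends UserDefinedType[Array[Double]] {
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
  override def serialize(obj: Array[Double]): Any = new GenericArrayData(obj)
  override def deserialize(datum: Any): Array[Double] =
    datum.asInstanceOf[ArrayData].toDoubleArray()
  override def userClass: Class[Array[Double]] = classOf[Array[Double]]
}

val pointUDT = new PointUDT()
Cast.canCast(pointUDT, ArrayType(DoubleType)) // false after the revert: no sqlType fallback
Cast.canCast(ArrayType(DoubleType), pointUDT) // false after the revert
Cast.canCast(pointUDT, pointUDT)              // still true via the acceptsType case
```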

@@ -1131,42 +1123,33 @@ case class Cast(
variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId, zoneId)
})
} else {
from match {
// `castToString` has special handling for `UserDefinedType`
case udt: UserDefinedType[_] if !to.isInstanceOf[StringType] =>
castInternal(udt.sqlType, to)
case _ =>
to match {
case dt if dt == from => identity[Any]
case VariantType => input =>
variant.VariantExpressionEvalUtils.castToVariant(input, from)
case _: StringType => castToString(from)
case BinaryType => castToBinary(from)
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case TimestampNTZType => castToTimestampNTZ(from)
case CalendarIntervalType => castToInterval(from)
case it: DayTimeIntervalType => castToDayTimeInterval(from, it)
case it: YearMonthIntervalType => castToYearMonthInterval(from, it)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
case IntegerType => castToInt(from)
case FloatType => castToFloat(from)
case LongType => castToLong(from)
case DoubleType => castToDouble(from)
case array: ArrayType =>
castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
case udt: UserDefinedType[_] if udt.acceptsType(from) =>
identity[Any]
case udt: UserDefinedType[_] =>
castInternal(from, udt.sqlType)
case _ =>
throw QueryExecutionErrors.cannotCastError(from, to)
}
to match {
case dt if dt == from => identity[Any]
case VariantType => input => variant.VariantExpressionEvalUtils.castToVariant(input, from)
case _: StringType => castToString(from)
case BinaryType => castToBinary(from)
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case TimestampNTZType => castToTimestampNTZ(from)
case CalendarIntervalType => castToInterval(from)
case it: DayTimeIntervalType => castToDayTimeInterval(from, it)
case it: YearMonthIntervalType => castToYearMonthInterval(from, it)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
case IntegerType => castToInt(from)
case FloatType => castToFloat(from)
case LongType => castToLong(from)
case DoubleType => castToDouble(from)
case array: ArrayType =>
castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
case udt: UserDefinedType[_] if udt.acceptsType(from) =>
identity[Any]
case _: UserDefinedType[_] =>
throw QueryExecutionErrors.cannotCastError(from, to)
}
}
}
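With the outer `from match` gone, the interpreted path dispatches purely on the target type: a UDT target either accepts the source as-is (identity) or the cast is an error, and a UDT source is no longer unwrapped to its sqlType (casts to string keep their own UDT handling, as the retained `test_cast_to_string_with_udt` shows). A small sketch of the restored outcome, reusing the hypothetical `PointUDT` from above:

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.{ArrayType, DoubleType}

val arrLit = Literal.create(Array(1.0, 2.0), ArrayType(DoubleType))
// No `case (fromType, udt)` fallback anymore, so the cast fails type
// checking and the expression never resolves:
assert(!Cast(arrLit, new PointUDT()).resolved)
```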
@@ -1228,64 +1211,54 @@ case class Cast(
private[this] def nullSafeCastFunction(
from: DataType,
to: DataType,
ctx: CodegenContext): CastFunction = {
from match {
// `castToStringCode` has special handling for `UserDefinedType`
case udt: UserDefinedType[_] if !to.isInstanceOf[StringType] =>
nullSafeCastFunction(udt.sqlType, to, ctx)
case _ =>
to match {

case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;"
case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;"
case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
val tmp = ctx.freshVariable("tmp", classOf[Object])
val dataTypeArg = ctx.addReferenceObj("dataType", to)
val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val failOnError = evalMode != EvalMode.TRY
val cls = classOf[variant.VariantGet].getName
code"""
Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg);
if ($tmp == null) {
$evNull = true;
} else {
$evPrim = (${CodeGenerator.boxedType(to)})$tmp;
}
"""
case VariantType =>
val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$")
val fromArg = ctx.addReferenceObj("from", from)
(c, evPrim, evNull) => code"$evPrim = $cls.castToVariant($c, $fromArg);"
case _: StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim)
case BinaryType => castToBinaryCode(from)
case DateType => castToDateCode(from, ctx)
case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
case TimestampType => castToTimestampCode(from, ctx)
case TimestampNTZType => castToTimestampNTZCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from)
case it: DayTimeIntervalType => castToDayTimeIntervalCode(from, it)
case it: YearMonthIntervalType => castToYearMonthIntervalCode(from, it)
case BooleanType => castToBooleanCode(from, ctx)
case ByteType => castToByteCode(from, ctx)
case ShortType => castToShortCode(from, ctx)
case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from, ctx)
case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from, ctx)

case array: ArrayType =>
castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
case udt: UserDefinedType[_] if udt.acceptsType(from) =>
(c, evPrim, evNull) => code"$evPrim = $c;"
case udt: UserDefinedType[_] =>
nullSafeCastFunction(from, udt.sqlType, ctx)
case _ =>
throw QueryExecutionErrors.cannotCastError(from, to)
ctx: CodegenContext): CastFunction = to match {

case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;"
case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;"
case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
val tmp = ctx.freshVariable("tmp", classOf[Object])
val dataTypeArg = ctx.addReferenceObj("dataType", to)
val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName)
val failOnError = evalMode != EvalMode.TRY
val cls = classOf[variant.VariantGet].getName
code"""
Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg);
if ($tmp == null) {
$evNull = true;
} else {
$evPrim = (${CodeGenerator.boxedType(to)})$tmp;
}
"""
case VariantType =>
val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$")
val fromArg = ctx.addReferenceObj("from", from)
(c, evPrim, evNull) => code"$evPrim = $cls.castToVariant($c, $fromArg);"
case _: StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim)
case BinaryType => castToBinaryCode(from)
case DateType => castToDateCode(from, ctx)
case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
case TimestampType => castToTimestampCode(from, ctx)
case TimestampNTZType => castToTimestampNTZCode(from, ctx)
case CalendarIntervalType => castToIntervalCode(from)
case it: DayTimeIntervalType => castToDayTimeIntervalCode(from, it)
case it: YearMonthIntervalType => castToYearMonthIntervalCode(from, it)
case BooleanType => castToBooleanCode(from, ctx)
case ByteType => castToByteCode(from, ctx)
case ShortType => castToShortCode(from, ctx)
case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from, ctx)
case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from, ctx)

case array: ArrayType =>
castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
case udt: UserDefinedType[_] if udt.acceptsType(from) =>
(c, evPrim, evNull) => code"$evPrim = $c;"
case _: UserDefinedType[_] =>
throw QueryExecutionErrors.cannotCastError(from, to)
}

// Since we need to cast input expressions recursively inside ComplexTypes, such as Map's
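The codegen path mirrors the interpreted one: the same `from match` wrapper is flattened away and dispatch is again purely on the target type. A `CastFunction` maps the input, result, and null code variables to a Java fragment; a sketch of that shape using the surviving UDT identity case (the real `CastFunction` alias is private to `Cast`, so this only mirrors it):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.codegen.{Block, ExprValue}

// Sketch: mirrors Cast's private CastFunction type alias.
type CastFunctionSketch = (ExprValue, ExprValue, ExprValue) => Block

// The kept identity case for an accepting UDT just copies the input slot:
val udtIdentity: CastFunctionSketch = (c, evPrim, evNull) => code"$evPrim = $c;"
```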
@@ -441,53 +441,47 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression {
override def eval(input: InternalRow): Any = value

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
def gen(ctx: CodegenContext, ev: ExprCode, dataType: DataType): ExprCode = {
val javaType = CodeGenerator.javaType(dataType)
if (value == null) {
ExprCode.forNullValue(dataType)
} else {
def toExprCode(code: String): ExprCode = {
ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
}

dataType match {
case BooleanType | IntegerType | DateType | _: YearMonthIntervalType =>
toExprCode(value.toString)
case FloatType =>
value.asInstanceOf[Float] match {
case v if v.isNaN =>
toExprCode("Float.NaN")
case Float.PositiveInfinity =>
toExprCode("Float.POSITIVE_INFINITY")
case Float.NegativeInfinity =>
toExprCode("Float.NEGATIVE_INFINITY")
case _ =>
toExprCode(s"${value}F")
}
case DoubleType =>
value.asInstanceOf[Double] match {
case v if v.isNaN =>
toExprCode("Double.NaN")
case Double.PositiveInfinity =>
toExprCode("Double.POSITIVE_INFINITY")
case Double.NegativeInfinity =>
toExprCode("Double.NEGATIVE_INFINITY")
case _ =>
toExprCode(s"${value}D")
}
case ByteType | ShortType =>
ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
case TimestampType | TimestampNTZType | LongType | _: DayTimeIntervalType =>
toExprCode(s"${value}L")
case udt: UserDefinedType[_] =>
gen(ctx, ev, udt.sqlType)
case _ =>
val constRef = ctx.addReferenceObj("literal", value, javaType)
ExprCode.forNonNullValue(JavaCode.global(constRef, dataType))
}
val javaType = CodeGenerator.javaType(dataType)
if (value == null) {
ExprCode.forNullValue(dataType)
} else {
def toExprCode(code: String): ExprCode = {
ExprCode.forNonNullValue(JavaCode.literal(code, dataType))
}
dataType match {
case BooleanType | IntegerType | DateType | _: YearMonthIntervalType =>
toExprCode(value.toString)
case FloatType =>
value.asInstanceOf[Float] match {
case v if v.isNaN =>
toExprCode("Float.NaN")
case Float.PositiveInfinity =>
toExprCode("Float.POSITIVE_INFINITY")
case Float.NegativeInfinity =>
toExprCode("Float.NEGATIVE_INFINITY")
case _ =>
toExprCode(s"${value}F")
}
case DoubleType =>
value.asInstanceOf[Double] match {
case v if v.isNaN =>
toExprCode("Double.NaN")
case Double.PositiveInfinity =>
toExprCode("Double.POSITIVE_INFINITY")
case Double.NegativeInfinity =>
toExprCode("Double.NEGATIVE_INFINITY")
case _ =>
toExprCode(s"${value}D")
}
case ByteType | ShortType =>
ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType))
case TimestampType | TimestampNTZType | LongType | _: DayTimeIntervalType =>
toExprCode(s"${value}L")
case _ =>
val constRef = ctx.addReferenceObj("literal", value, javaType)
ExprCode.forNonNullValue(JavaCode.global(constRef, dataType))
}
}
gen(ctx, ev, dataType)
}

override def sql: String = (value, dataType) match {
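In `Literal.doGenCode`, the revert deletes the nested `gen` helper, whose only job was to re-dispatch a UDT-typed literal on its underlying sqlType. Such a literal now falls through to the default branch, so its value is shipped into the generated class via `addReferenceObj` instead of being re-encoded as an inline Java literal. A sketch of the difference, assuming a UDT over IntegerType like the `YearUDT` removed further below:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
Literal(2024).genCode(ctx)                // still inlined as the Java literal `2024`
// Before this revert, gen(ctx, ev, udt.sqlType) inlined the backing Int too;
// now the value is stored via ctx.addReferenceObj and read back at runtime.
Literal(2024, new YearUDT()).genCode(ctx)
```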
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}
import java.time.{Duration, LocalDate, LocalDateTime, Period, Year => JYear}
import java.time.{Duration, LocalDate, LocalDateTime, Period}
import java.time.temporal.ChronoUnit
import java.util.{Calendar, Locale, TimeZone}

@@ -37,7 +37,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes}
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
import org.apache.spark.sql.types.TestUDT._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.UTF8String
@@ -1410,43 +1409,4 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
assert(!Cast(timestampLiteral, TimestampNTZType).resolved)
assert(!Cast(timestampNTZLiteral, TimestampType).resolved)
}

test("SPARK-49787: Cast between UDT and other types") {
val value = new MyDenseVector(Array(1.0, 2.0, -1.0))
val udtType = new MyDenseVectorUDT()
val targetType = ArrayType(DoubleType, containsNull = false)

val serialized = udtType.serialize(value)

checkEvaluation(Cast(new Literal(serialized, udtType), targetType), serialized)
checkEvaluation(Cast(new Literal(serialized, targetType), udtType), serialized)

val year = JYear.parse("2024")
val yearUDTType = new YearUDT()

val yearSerialized = yearUDTType.serialize(year)

checkEvaluation(Cast(new Literal(yearSerialized, yearUDTType), IntegerType), 2024)
checkEvaluation(Cast(new Literal(2024, IntegerType), yearUDTType), yearSerialized)

val yearString = UTF8String.fromString("2024")
checkEvaluation(Cast(new Literal(yearSerialized, yearUDTType), StringType), yearString)
checkEvaluation(Cast(new Literal(yearString, StringType), yearUDTType), yearSerialized)
}
}

private[sql] class YearUDT extends UserDefinedType[JYear] {
override def sqlType: DataType = IntegerType

override def serialize(obj: JYear): Int = {
obj.getValue
}

def deserialize(datum: Any): JYear = datum match {
case value: Int => JYear.of(value)
}

override def userClass: Class[JYear] = classOf[JYear]

private[spark] override def asNullable: YearUDT = this
}
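Finally, the SPARK-49787 regression test and its `YearUDT` helper are dropped. Under the restored rules the casts it exercised no longer resolve at all; a hedged sketch of the corresponding negative checks (reusing `YearUDT` as defined above, which this revert removes):

```scala
import java.time.{Year => JYear}
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.IntegerType

val yearUDT = new YearUDT()
// Int <-> UDT casts fail type checking again once the sqlType fallback is gone:
assert(!Cast(Literal(2024, IntegerType), yearUDT).resolved)
assert(!Cast(new Literal(yearUDT.serialize(JYear.parse("2024")), yearUDT), IntegerType).resolved)
```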